[WIP] cosmos3: add generator model scaffold (+reasoner implementation)#121
[WIP] cosmos3: add generator model scaffold (+reasoner implementation)#121merceod wants to merge 32 commits into
Conversation
Config, weight loader, dual-pathway DiT parameter structure, Cosmos3Model with prefill/image_gen graph walks, submodule stubs, registry entry, and Nano serving config.
Dual-pathway MoT attention (QK-norm, 3D interleaved mRoPE, GQA), patchify/unpatchify, timestep embedding, and the per-domain action heads. Load the transformer in bf16 from the diffusers shard index, raising on any unfilled parameter, with the timestep MLP kept in fp32. Add config/loader/ shape structural tests.
mRoPE position ids, chat-template tokenization, and a single-image t2i pipeline over the DiT, the UniPC scheduler, and the Wan VAE, with CPU unit tests and a GPU integration test.
The text-conditioning tower's K/V doesn't depend on the denoise timestep, so it only needs to run once. Wire the DiT submodule to prefill it into the paged cache, then run only the generation tower each step, re-reading that frozen K/V (conditional and unconditional prompts kept in separate cache labels for guidance). Adds prefill/denoise entry points on the transformer and a GPU test vs the fused text-to-image pipeline: bit-exact with an in-process sdpa cache, ~37 dB image PSNR through the FlashInfer paged cache.
Generalize the fused pipeline and packing from single-frame images to multi-frame video. tokenize_prompt now emits the video system prompt plus the duration and video-resolution sentences; build_static_inputs takes a has_image_condition flag so image-to-video anchors a clean frame 0 while the rest denoise. The pipeline encodes the conditioning frame through the Wan VAE and blends it with noise, matching the diffusers Cosmos3OmniPipeline. The transformer needs no changes: its forward, denoise_step and prefill_und were already shape-general, and the generation attention is non-causal in every mode (video conditioning rides on the noisy-frame indices, not a mask). The engine submodule just threads num_frames and the conditioning flag. Rename t2i_pipeline.py -> pipeline.py (Cosmos3Pipeline) now that it covers all modes. Output is bit-for-bit identical to diffusers on t2v and i2v, and the run-text-tower-once cache path stays exact across frames.
NSagan271
left a comment
There was a problem hiding this comment.
A few comments, I'll hold off until the implementation is done before making more
| return self.linear_2(self.act(self.linear_1(sample))) | ||
|
|
||
|
|
||
| class Cosmos3MLP(nn.Module): |
There was a problem hiding this comment.
(future note): will eventually want to use the ParallelGatedMLP (and eventually use parallel fused qkv for the attention if possible---it looks like it should be possible)
| ) | ||
|
|
||
|
|
||
| class Cosmos3Pipeline: |
There was a problem hiding this comment.
Is this class just used for testing (I don't see it being used outside of the tests)? If so, can it be moved to the tests/ folder?
| out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, enable_gqa=True) | ||
| return out.transpose(1, 2).squeeze(0).flatten(-2, -1) | ||
|
|
||
| def forward( |
There was a problem hiding this comment.
Just to verify: is this forward and its downstream callers unused in production (and everything instead delegated to forward_und / forward_gen across multiple graph walks)? If so, can you make a comment to that extent? Otherwise, we should think about optimizing this branch in the future
Extend the generator to joint video+action: domain-aware action projections plus the action mRoPE band, the forward/inverse-dynamics and policy conditioning layouts, joint packing, and the cache-once engine walk. Inverse-dynamics on the av example reproduces the reference action output (MSE 5e-5); the fused and engine paths agree exactly.
A submodule with a data-dependent or one-shot forward can set disable_torch_compile; the kv-cache and stateless engines then skip compiling it. CUDA graph capture is unaffected.
Tokenize the prompt into conditional + unconditional ids (chat template + resolution sentence), thread the request's size/guidance/seed into the denoise step metadata, and return the decoded frame as PNG. The DiT and VAE nodes skip torch.compile.
…hey share the same noised tokens and differ only in text conditioning and rotary positions, so they pack into a single FlashInfer batch; a flag falls back to the sequential two-forward path.
…l requests are generating at once their guidance branches pack into a single FlashInfer plan and forward, so the per-step matmuls and attention run once for the whole batch instead of once per request. Each request keeps its own latents, timestep, positions and scheduler and stays isolated from the others.
Both guidance branches run in one captured forward per denoise step and the multistep scheduler step stays eager afterwards, which roughly halves text-to-image latency. A submodule can name a velocity-only method to capture, run the non-capturable tail in a post-replay hook, and opt out of the post-replay seq_len advance so frozen-prefix denoise loops keep re-reading the same prefix; the combined classifier-free-guidance plan now reuses a persistent FlashInfer wrapper under capture.
COSMOS3_DISABLE_CUDA_GRAPH=1 makes get_cuda_graph_configs return nothing, so the denoise loop runs eagerly. Handy as an escape hatch if graph capture misbehaves on a given driver, and to A/B the captured vs eager path on the same build.
…mage and video generation loops now stop at each request's own step count via a check_stop on the DiT node, rather than a single count fixed when the graph is built, so one graph serves image (50 steps), video (35) and any requested count up to an upper bound. The lone extra step the loop dispatches before that stop takes effect is a no-op, so it can't run the scheduler past a request's last timestep.
…alk reuses the image denoise loop and VAE decode but emits an encoded mp4 instead of a single frame; a Cosmos3 video adapter, request type, and route wire it up, and the per-request frame count and step count default by mode (image vs video) and stop the loop at each request's own length. Image-to-video is recognized but rejected for now, since its conditioning frame still needs to be VAE-encoded on the worker.
Route the conditioning frame to the worker for /v1/videos image-to-video. A new prefill_cond walk hands the DiT node the input image, which it VAE-encodes (reusing the decoder's VAE) into the clean latent frame-0 anchor that seeds the denoise loop, matching the fused pipeline's i2v latent prep; the anchor stays fixed through the loop since its predicted velocity is zero. The video adapter now resolves the request image and passes it in as an image input instead of rejecting image-to-video.
Wire the action path end to end for HTTP. The conditioning video (or image) is VAE-encoded on the worker via a conditioned prefill walk, and the first denoise iteration builds the joint video+action latents from the per-mode masks (conditioning frames/action clean, the rest noise) instead of expecting them pre-supplied. The action loop now emits the predicted action by reusing the looped action_latents edge name — a loop's terminal output is matched into the section by name, so the previous standalone name produced no output. Action prompts are chat-templated without the image/video system prompt or resolution/duration sentences, matching the references, and load_video reads the decode device from its argument like load_image. Served inverse-dynamics on the av_0 clip reproduces the reference action within tolerance.
Forward-dynamics conditions on a first frame plus a clean action chunk and predicts the resulting video, so it runs the joint video+action denoise but emits a decoded video instead of the action. Add an action_video_gen walk whose loop body is the same joint denoise (selected when action_mode is forward_dynamics) and whose terminal output routes the predicted video latents to the VAE decoder; the decoder emits video for it. Served single-chunk forward-dynamics on the AgiBotWorld example matches the reference first chunk.
Concurrent diffusion requests were serialized: the engine capped each node's max batch size to its largest captured CUDA-graph batch size, and the cosmos3 denoise step is captured only at batch size 1 (for single-request latency), so the cap forced one request per forward. Add an opt-in CudaGraphConfig.caps_eager_batch_size — when False the captured sizes are an acceleration subset, not a batch ceiling: the engine honors the submodule's max_batch_size and replays a graph only when the batch size was captured, otherwise it runs the eager batched forward. cosmos3 sets it on the image generation capture, disables speculative scheduling on the denoise loops so concurrent requests group into one batched step (like the BAGEL image loop), and extends can_batch/forward_batched to the video walk. Throughput now scales with concurrency instead of staying flat.
NSagan271
left a comment
There was a problem hiding this comment.
First pass: mixture of nitpicks, notes for future PRs / cleaning up the code, and conceptual questions.
| # keep a non-capturable step (e.g. a multistep scheduler) out of the | ||
| # graph. Runs with REAL request ids, the original ``inputs``, and the | ||
| # cloned captured outputs, so it can finish each request's step. | ||
| if hasattr(submodule, "postprocess_captured"): |
There was a problem hiding this comment.
Can just add an empty postprocess_captured to the submodule base class with an appropriate doc strong, so that model authors can easily use this paradigm.
| und_seq = layer.forward_und(und_seq, cos, sin, cache_handle) | ||
| cache_handle.advance_seq_lens() | ||
|
|
||
| def denoise_step( |
There was a problem hiding this comment.
Shouldn't this also have inputs for sound? (and same for the batched CFG version)
| return x.detach().to(torch.float32).cpu().numpy().tobytes() | ||
| raise ValueError(f"Unsupported modality for Cosmos3: {modality!r}") | ||
|
|
||
| def load_video(self, filepath: str, device: str): |
There was a problem hiding this comment.
The accessing self.device in the base implementation is almost definitely a typo; it should be accessing the passed in device argument to be consistent with the other loader functions. So, I would fix the typo in the base implementation instead of just overriding it here.
There was a problem hiding this comment.
Nit: I would move some of the functions here (the ones that are only used in submodules) to components/
| return cond, uncond | ||
|
|
||
|
|
||
| def tokenize_t2i_prompt( |
There was a problem hiding this comment.
This function appears to be unused? (I might be wrong)
| PREFILL_COND_WALK = "prefill_cond" | ||
| # Action inverse-dynamics conditions on a full video rather than a single frame, | ||
| # so it gets its own conditioned prefill that takes the video among its inputs. | ||
| PREFILL_COND_VIDEO_WALK = "prefill_cond_video" |
There was a problem hiding this comment.
Nit: these strings are reused between cosmos3_model.py and submodules.py, considering move these to a constants.py and importing them in the two locations.
| static["mse_gen_indexes"] = static["vision_mse_loss_indexes"] - static["und_len"] | ||
| return static | ||
|
|
||
| def _new_scheduler(self, num_inference_steps: int, device, flow_shift=None): |
There was a problem hiding this comment.
Not for this PR, but if this per-request schedule logic is something that will be reused across several dffusion models, we might want to create an engine-level abstraction.
| act = act[:action_chunk] | ||
| clean_action[:, :, :raw_action_dim] = act[:, :raw_action_dim] | ||
|
|
||
| self._req[fwd_info.request_id] = { |
There was a problem hiding this comment.
(for a future PR) There is a lot of stuff going into per-request submodule state that I don't really like from a design perspective. I think this may be the cleanest and most performant way to do this right now (having the masks and latents be persistent signals would increase SHM traffic, and the putting all the metadata in the step_metadata would increase the size of ZMQ packets), but it, e.g., would disable PD disaggregation for such a node. Maybe it would feel cleaner to have some engine-level support for per-request state.
There was a problem hiding this comment.
I also think that (probably in a next PR), this state can be refactored to be lighter. For instance, there are some parameters that are already present in the step metdata; the actual input_ids for cond and uncond look like they can pass through the regular ARNodeInputs tensor inputs (though, e.g., the mrope IDs do need to be persistent state), some masks are just inverses of each other, etc.
| ) | ||
| return ARNodeInputs(input_seq_len=cond["und_len"]) | ||
|
|
||
| def _encode_conditioning(self, image, height, width, num_frames, device): |
There was a problem hiding this comment.
Does it make sense to have this VAE encoder be its own graph node (and its internals subject to torch.compile, separate scheduling, etc.)?
| def _forward_image_gen(self, cm, st, latents, time_index, **kwargs) -> dict: | ||
| scheduler = st["scheduler"] | ||
| step_index = int(time_index.reshape(-1)[0].item()) | ||
| if step_index >= len(scheduler.timesteps): |
There was a problem hiding this comment.
I think the cleanest way to avoid going on step over is to have prepare_inputs return None, and then port over the skipped_rids logic from the StatelessEngine to KVCacheEngine. I believe this forward-pass-internal if statement is not compatible with batching and cuda graphs.
|
Something to consider (I haven't traced through the code thoroughly enough to see if this is possible or even helpful, but I think it fits better with our paradigms) is to add a node for the VAE processing of the engine inputs that runs in parallel (or sequential if needed) with the DiT prefill, and have the cond_latents / cond_video_latents enter into the generation loop as graph edges via the persist signal path instead of through the submodule's persistent state. For requests that don't have video / image inputs, the conductor can set the tensor info as an empty list. My sense from benchmarking BAGEL I2T is that having a separate node for the VAE (if the compute is significant) could lead to latency / throughput improvements, but I'm of course not sure if that insight generalizes to Cosmos. |
The denoise-step CUDA graph was only captured for one square resolution, so requests at the usual generation sizes ran the eager path. Capture a graph per generation tier (320x192 / 832x480 / 1280x720), overridable with COSMOS3_GEN_CAPTURE_RES. The graph runner now keeps a capture per resolution for a walk and dispatches each request to the graph matching its own shape rather than the first one declared, so several fixed-shape captures coexist. Served graph output is identical to the eager path; the win is largest where the step is launch-bound (~2.5x at 320x192) and tapers as it grows compute-bound.
When several requests finish in the same step, one request can be cleaned up (its result already returned) while a later chunk for it is still in the output queue. Decrementing the per-request counter then raised a KeyError that aborted the whole drain and dropped the other requests' chunks too. Guard the lookup the same way new_result_tensors already does and skip the late chunk.
Concurrent action requests at the same generation walk now share a single joint video+action denoise forward, the way image and video requests already do. A new batched denoise packs each request's [video | action] tokens -- one branch when guidance is off (the guidance-scale-1 inverse/forward-dynamics and base policy case), both branches with classifier-free guidance -- and the batched attention plan routes each request to its own cache pages. The per-request masks, the joint scheduler step (now factored into a shared helper), and the domain-aware action projection run per request, so one batch can mix modes and embodiments. Adds CPU and GPU tests that check the batched output reproduces the per-request path and stays isolated across requests.
| while not self.output_queue.empty(): | ||
| result: ResultChunk = self.output_queue.get() | ||
| # A request can be cleaned up (its result already returned) while a | ||
| # late chunk is still in the queue -- common when several requests |
There was a problem hiding this comment.
I'm not sure if this should be happening---I think it's correct to have the guard here (I'd change the log to warning instead of debug though), but I'd check to see if this is due to, e.g., a VAE decoder accidentally being double-triggered?
torch.compile the inner denoise compute with fullgraph=False so the FlashInfer attention stays an opaque break and only the bandwidth-bound pointwise ops fuse; the compiled kernels then bake into the per-resolution image graphs at capture, so graphs and compile stack. t2i bs=1 over HTTP drops to ~0.92/1.84/3.64s at 256/480/720p (~1.2-1.25x over graphs alone) with no image or action-golden quality change vs the fused reference (480p/50-step engine PSNR 39.3 either way). On by default; COSMOS3_DISABLE_COMPILE_DENOISE=1 falls back to the eager step. The engine-cache and action suites pin the eager step for their bit-exact mechanism checks.
The base get_submodule signature gained an autocast_dtype hint; thread it through Cosmos3 for parity. Cosmos3 already casts the meta module to bf16 before to_empty, so params land in the checkpoint dtype and the hint is a no-op here, but the engine manager now passes it by keyword.
The default pool (max_num_pages 2048 x page_size 128) pre-allocates ~38 GB of paged K/V for the 36-layer DiT regardless of the request, which OOMs larger video on an 80 GB card. One bs=1 720p x 189-frame request needs only ~692 pages across both CFG branches, so 1024 pages cover single-request video at every tier plus image batching and free ~19 GB for activations.
The diffusion generation tower recomputes all of its K/V every denoise step and only reuses the small frozen text prefix, so the paged-cache write the autoregressive path needs is wasted work here. With COSMOS3_DENSE_FA3 set, gather the prefix K/V and run one varlen FlashAttention-3 pass over [prefix | generation] per guidance branch, bypassing the paged write+read; falls back to the paged path otherwise and under CUDA-graph capture. Adds a dense-vs-paged PSNR parity check.
The serving engine cast the Wan VAE to bf16, but its 3D convolutions are several times slower in bf16 than fp32 on this cuDNN and the reference pipeline decodes in fp32; restore fp32 for the decode. Also quantize to uint8 in the decoder so only 8-bit frames cross the worker boundary instead of a 4x-larger fp32 tensor.
The conditioning encode repeat-padded the frame across the whole clip and VAE-encoded all of it, but only latent frame 0 is ever used. Encode just that frame (the Wan VAE produces it as a standalone anchor, bit-identical) and run the encode in fp32 outside autocast like the decoder, which is much faster on this cuDNN.
Gated by COSMOS3_COMPILE_VAE (default off). The Wan VAE decode is 3D-conv bound and runs once per request at request-specific shapes, so it isn't CUDA-graphed; torch.compile fuses the pointwise epilogues around the convs. Keeps the fp32, autocast-off decode. Trims a few percent off video latency (the decode is a larger slice there) and narrows the higher-resolution image gap; the first request per resolution pays a one-time trace. Adds a PSNR A/B test against the eager decode for both image and video.
Config, weight loader, dual-pathway DiT parameter structure, Cosmos3Model with prefill/image_gen graph walks, submodule stubs, registry entry, and Nano serving config.
What does this PR do?
This PR integrates cosmos3 series of models (cosmos3-nano, DROID-nano, and super 64B) toM*.
How was it tested?
So farr, we have two load-bearing CPU gates passing. (1) config_roundtrip: Cosmos3Config.from_pretrained reproduces every field (dims, mRoPE [24,20,20]/θ5e6/margin 15000, timestep_scale 0.001, two_way, use_moe, action/sound flags, ...). (2) loader_key_coverage: the backbone's state_dict keys exactly cover the checkpoint: 814 index keys − lm_head.weight = 813 = backbone params, drop-list is exactly {lm_head.weight}, zero missing, zero unexpected.
More H200 GPU tests are underway as integration progresses.
Checklist
ruff check .passes