[WIP] cosmos3: add generator model scaffold (+reasoner implementation)#121

Open
merceod wants to merge 32 commits into main from cosmos3_integration

Conversation

@merceod merceod commented Jun 14, 2026

Collaborator

Config, weight loader, dual-pathway DiT parameter structure, Cosmos3Model with prefill/image_gen graph walks, submodule stubs, registry entry, and Nano serving config.

What does this PR do?

This PR integrates the cosmos3 series of models (cosmos3-nano, DROID-nano, and super 64B) into M*.

How was it tested?

So far, two load-bearing CPU gates pass. (1) config_roundtrip: Cosmos3Config.from_pretrained reproduces every field (dims, mRoPE [24,20,20]/θ5e6/margin 15000, timestep_scale 0.001, two_way, use_moe, action/sound flags, ...). (2) loader_key_coverage: the backbone's state_dict keys exactly cover the checkpoint: 814 index keys − lm_head.weight = 813 = backbone params; the drop-list is exactly {lm_head.weight}, with zero missing and zero unexpected keys.
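
A minimal sketch of what a key-coverage gate of this kind checks, using toy key names. The helper `check_key_coverage` and the sample keys are illustrative assumptions, not the test added in this PR:

```python
def check_key_coverage(checkpoint_keys, model_keys, drop_list):
    """Compare checkpoint keys (minus the drop-list) against model parameter keys.

    Returns (missing, unexpected):
      missing    - model keys the checkpoint cannot fill
      unexpected - checkpoint keys with no matching model parameter
    """
    ckpt = set(checkpoint_keys) - set(drop_list)
    model = set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    return missing, unexpected


# Toy stand-ins for the shard index and the backbone state_dict (hypothetical names).
checkpoint_keys = ["lm_head.weight", "layers.0.attn.q.weight", "layers.0.mlp.up.weight"]
model_keys = ["layers.0.attn.q.weight", "layers.0.mlp.up.weight"]

missing, unexpected = check_key_coverage(checkpoint_keys, model_keys, {"lm_head.weight"})
# The gate passes only when both lists are empty.
assert missing == [] and unexpected == []
```

With the numbers reported above, the same check would see 814 index keys, drop one, and match the remaining 813 against the backbone parameters.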

More H200 GPU tests are underway as integration progresses.

Checklist

  • ruff check . passes
  • Added or updated tests / docs where relevant

@merceod merceod changed the title from "[WIP] cosmos3: add generator model scaffold" to "[WIP] cosmos3: add generator model scaffold (+reasoner implementation)" Jun 14, 2026
merceod added 2 commits June 14, 2026 04:31
Dual-pathway MoT attention (QK-norm, 3D interleaved mRoPE, GQA),
patchify/unpatchify, timestep embedding, and the per-domain action heads.
Load the transformer in bf16 from the diffusers shard index, raising on any
unfilled parameter, with the timestep MLP kept in fp32. Add config/loader/
shape structural tests.
mRoPE position ids, chat-template tokenization, and a single-image t2i
pipeline over the DiT, the UniPC scheduler, and the Wan VAE, with CPU unit
tests and a GPU integration test.
@merceod merceod requested a review from NSagan271 June 14, 2026 04:53
merceod added 2 commits June 14, 2026 06:01
The text-conditioning tower's K/V doesn't depend on the denoise timestep, so it
only needs to run once. Wire the DiT submodule to prefill it into the paged
cache, then run only the generation tower each step, re-reading that frozen K/V
(conditional and unconditional prompts kept in separate cache labels for
guidance). Adds prefill/denoise entry points on the transformer and a GPU test
vs the fused text-to-image pipeline: bit-exact with an in-process sdpa cache,
~37 dB image PSNR through the FlashInfer paged cache.
Generalize the fused pipeline and packing from single-frame images to
multi-frame video. tokenize_prompt now emits the video system prompt plus
the duration and video-resolution sentences; build_static_inputs takes a
has_image_condition flag so image-to-video anchors a clean frame 0 while the
rest denoise. The pipeline encodes the conditioning frame through the Wan VAE
and blends it with noise, matching the diffusers Cosmos3OmniPipeline.

The transformer needs no changes: its forward, denoise_step and prefill_und
were already shape-general, and the generation attention is non-causal in
every mode (video conditioning rides on the noisy-frame indices, not a mask).
The engine submodule just threads num_frames and the conditioning flag.

Rename t2i_pipeline.py -> pipeline.py (Cosmos3Pipeline) now that it covers
all modes. Output is bit-for-bit identical to diffusers on t2v and i2v, and
the run-text-tower-once cache path stays exact across frames.

@NSagan271 NSagan271 left a comment

Collaborator

A few comments; I'll hold off on making more until the implementation is done.

return self.linear_2(self.act(self.linear_1(sample)))


class Cosmos3MLP(nn.Module):

Collaborator

(future note): will eventually want to use the ParallelGatedMLP (and eventually use parallel fused qkv for the attention if possible---it looks like it should be possible)

)


class Cosmos3Pipeline:

Collaborator

Is this class just used for testing (I don't see it being used outside of the tests)? If so, can it be moved to the tests/ folder?

out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, enable_gqa=True)
return out.transpose(1, 2).squeeze(0).flatten(-2, -1)

def forward(

Collaborator

Just to verify: are this forward and its downstream callers unused in production (with everything instead delegated to forward_und / forward_gen across multiple graph walks)? If so, can you add a comment to that effect? Otherwise, we should think about optimizing this branch in the future.

merceod added 13 commits June 14, 2026 08:31
Extend the generator to joint video+action: domain-aware action
projections plus the action mRoPE band, the forward/inverse-dynamics
and policy conditioning layouts, joint packing, and the cache-once
engine walk. Inverse-dynamics on the av example reproduces the
reference action output (MSE 5e-5); the fused and engine paths agree
exactly.
A submodule with a data-dependent or one-shot forward can set
disable_torch_compile; the kv-cache and stateless engines then skip compiling it.
CUDA graph capture is unaffected.
Tokenize the prompt into conditional + unconditional ids (chat template +
resolution sentence), thread the request's size/guidance/seed into the denoise
step metadata, and return the decoded frame as PNG. The DiT and VAE nodes skip
torch.compile.
…hey share the same noised tokens and differ only in text conditioning and rotary positions, so they pack into a single FlashInfer batch; a flag falls back to the sequential two-forward path.
…l requests are generating at once their guidance branches pack into a single FlashInfer plan and forward, so the per-step matmuls and attention run once for the whole batch instead of once per request. Each request keeps its own latents, timestep, positions and scheduler and stays isolated from the others.
Both guidance branches run in one captured forward per denoise step and the multistep scheduler step stays eager afterwards, which roughly halves text-to-image latency. A submodule can name a velocity-only method to capture, run the non-capturable tail in a post-replay hook, and opt out of the post-replay seq_len advance so frozen-prefix denoise loops keep re-reading the same prefix; the combined classifier-free-guidance plan now reuses a persistent FlashInfer wrapper under capture.
COSMOS3_DISABLE_CUDA_GRAPH=1 makes get_cuda_graph_configs return nothing, so the denoise loop runs eagerly. Handy as an escape hatch if graph capture misbehaves on a given driver, and to A/B the captured vs eager path on the same build.
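
A minimal sketch of how an environment-variable escape hatch like this could be wired. The helper `select_cuda_graph_configs` and its arguments are hypothetical; the real `get_cuda_graph_configs` signature is not shown in this conversation:

```python
import os


def cuda_graphs_disabled(env=None):
    """True when COSMOS3_DISABLE_CUDA_GRAPH is set to "1"."""
    env = os.environ if env is None else env
    return env.get("COSMOS3_DISABLE_CUDA_GRAPH", "0") == "1"


def select_cuda_graph_configs(declared_configs, env=None):
    """Return no capture configs when the escape hatch is set (hypothetical helper).

    An empty list means nothing is captured, so the denoise loop runs eagerly.
    """
    if cuda_graphs_disabled(env):
        return []
    return list(declared_configs)
```

Passing `env` explicitly keeps the gate easy to A/B in tests without mutating the process environment.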
…mage and video generation loops now stop at each request's own step count via a check_stop on the DiT node, rather than a single count fixed when the graph is built, so one graph serves image (50 steps), video (35) and any requested count up to an upper bound. The lone extra step the loop dispatches before that stop takes effect is a no-op, so it can't run the scheduler past a request's last timestep.
…alk reuses the image denoise loop and VAE decode but emits an encoded mp4 instead of a single frame; a Cosmos3 video adapter, request type, and route wire it up, and the per-request frame count and step count default by mode (image vs video) and stop the loop at each request's own length. Image-to-video is recognized but rejected for now, since its conditioning frame still needs to be VAE-encoded on the worker.
Route the conditioning frame to the worker for /v1/videos image-to-video. A new
prefill_cond walk hands the DiT node the input image, which it VAE-encodes
(reusing the decoder's VAE) into the clean latent frame-0 anchor that seeds the
denoise loop, matching the fused pipeline's i2v latent prep; the anchor stays
fixed through the loop since its predicted velocity is zero. The video adapter
now resolves the request image and passes it in as an image input instead of
rejecting image-to-video.
Wire the action path end to end for HTTP. The conditioning video (or image) is
VAE-encoded on the worker via a conditioned prefill walk, and the first
denoise iteration builds the joint video+action latents from the per-mode masks
(conditioning frames/action clean, the rest noise) instead of expecting them
pre-supplied. The action loop now emits the predicted action by reusing the
looped action_latents edge name — a loop's terminal output is matched into the
section by name, so the previous standalone name produced no output. Action
prompts are chat-templated without the image/video system prompt or
resolution/duration sentences, matching the references, and load_video reads the
decode device from its argument like load_image. Served inverse-dynamics on the
av_0 clip reproduces the reference action within tolerance.
Forward-dynamics conditions on a first frame plus a clean action chunk and
predicts the resulting video, so it runs the joint video+action denoise but
emits a decoded video instead of the action. Add an action_video_gen walk whose
loop body is the same joint denoise (selected when action_mode is
forward_dynamics) and whose terminal output routes the predicted video latents
to the VAE decoder; the decoder emits video for it. Served single-chunk
forward-dynamics on the AgiBotWorld example matches the reference first chunk.
Concurrent diffusion requests were serialized: the engine capped each node's max
batch size to its largest captured CUDA-graph batch size, and the cosmos3 denoise
step is captured only at batch size 1 (for single-request latency), so the cap
forced one request per forward. Add an opt-in CudaGraphConfig.caps_eager_batch_size
— when False the captured sizes are an acceleration subset, not a batch ceiling:
the engine honors the submodule's max_batch_size and replays a graph only when the
batch size was captured, otherwise it runs the eager batched forward. cosmos3 sets
it on the image generation capture, disables speculative scheduling on the denoise
loops so concurrent requests group into one batched step (like the BAGEL image
loop), and extends can_batch/forward_batched to the video walk. Throughput now
scales with concurrency instead of staying flat.
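
The dispatch rule described in this commit can be sketched as follows. `GraphConfigSketch` is a hypothetical stand-in that models only the two fields discussed here; it is not the engine's `CudaGraphConfig`:

```python
from dataclasses import dataclass


@dataclass
class GraphConfigSketch:
    """Hypothetical stand-in for the graph config, reduced to two fields."""
    captured_batch_sizes: tuple
    caps_eager_batch_size: bool = True  # previous behavior: captured sizes are a ceiling


def effective_max_batch(cfg, submodule_max_batch_size):
    """Largest batch the engine will form for this node."""
    if cfg.caps_eager_batch_size:
        return min(submodule_max_batch_size, max(cfg.captured_batch_sizes))
    return submodule_max_batch_size


def choose_path(cfg, batch_size):
    """Replay a graph only for captured batch sizes; otherwise run eagerly."""
    return "graph" if batch_size in cfg.captured_batch_sizes else "eager"


cfg = GraphConfigSketch(captured_batch_sizes=(1,), caps_eager_batch_size=False)
assert effective_max_batch(cfg, 8) == 8   # concurrency is no longer capped at 1
assert choose_path(cfg, 1) == "graph"     # single request keeps the captured path
assert choose_path(cfg, 4) == "eager"     # batched step falls back to eager
```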

@NSagan271 NSagan271 left a comment

Collaborator

First pass: mixture of nitpicks, notes for future PRs / cleaning up the code, and conceptual questions.

# keep a non-capturable step (e.g. a multistep scheduler) out of the
# graph. Runs with REAL request ids, the original ``inputs``, and the
# cloned captured outputs, so it can finish each request's step.
if hasattr(submodule, "postprocess_captured"):

Collaborator

We can just add an empty postprocess_captured to the submodule base class with an appropriate docstring, so that model authors can easily use this paradigm.
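
A minimal sketch of that suggestion. The class names and the hook's argument list are assumptions for illustration; the real base class and call site may differ:

```python
class SubmoduleBase:
    """Hypothetical base class for engine submodules."""

    def postprocess_captured(self, request_ids, inputs, captured_outputs):
        """Finish a step after CUDA-graph replay.

        Override to run work that cannot be captured (e.g. a multistep
        scheduler). The default is a no-op that returns the outputs unchanged.
        """
        return captured_outputs


class SchedulerTailSubmodule(SubmoduleBase):
    """Toy override: halving stands in for an eager scheduler step."""

    def postprocess_captured(self, request_ids, inputs, captured_outputs):
        return {rid: out * 0.5 for rid, out in zip(request_ids, captured_outputs)}


def run_after_replay(submodule, request_ids, inputs, captured_outputs):
    # With a base-class default, the caller no longer needs a hasattr check.
    return submodule.postprocess_captured(request_ids, inputs, captured_outputs)
```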

und_seq = layer.forward_und(und_seq, cos, sin, cache_handle)
cache_handle.advance_seq_lens()

def denoise_step(

Collaborator

Shouldn't this also have inputs for sound? (and same for the batched CFG version)

return x.detach().to(torch.float32).cpu().numpy().tobytes()
raise ValueError(f"Unsupported modality for Cosmos3: {modality!r}")

def load_video(self, filepath: str, device: str):

Collaborator

Accessing self.device in the base implementation is almost certainly a typo; it should use the passed-in device argument, consistent with the other loader functions. So I would fix the typo in the base implementation instead of just overriding it here.

Collaborator

Nit: I would move some of the functions here (the ones that are only used in submodules) to components/

return cond, uncond


def tokenize_t2i_prompt(

Collaborator

This function appears to be unused? (I might be wrong)

PREFILL_COND_WALK = "prefill_cond"
# Action inverse-dynamics conditions on a full video rather than a single frame,
# so it gets its own conditioned prefill that takes the video among its inputs.
PREFILL_COND_VIDEO_WALK = "prefill_cond_video"

Collaborator

Nit: these strings are reused between cosmos3_model.py and submodules.py; consider moving them to a constants.py and importing them in both places.

static["mse_gen_indexes"] = static["vision_mse_loss_indexes"] - static["und_len"]
return static

def _new_scheduler(self, num_inference_steps: int, device, flow_shift=None):

Collaborator

Not for this PR, but if this per-request schedule logic is something that will be reused across several diffusion models, we might want to create an engine-level abstraction.

act = act[:action_chunk]
clean_action[:, :, :raw_action_dim] = act[:, :raw_action_dim]

self._req[fwd_info.request_id] = {

Collaborator

(for a future PR) There is a lot of stuff going into per-request submodule state that I don't really like from a design perspective. I think this may be the cleanest and most performant way to do this right now (having the masks and latents be persistent signals would increase SHM traffic, and putting all the metadata in the step_metadata would increase the size of ZMQ packets), but it would, e.g., disable PD disaggregation for such a node. Maybe it would feel cleaner to have some engine-level support for per-request state.

Collaborator

I also think that (probably in a follow-up PR) this state can be refactored to be lighter. For instance, some parameters are already present in the step metadata; the actual input_ids for cond and uncond look like they can pass through the regular ARNodeInputs tensor inputs (though, e.g., the mRoPE IDs do need to be persistent state); some masks are just inverses of each other; etc.

Comment thread mstar/model/cosmos3/submodules.py Outdated
)
return ARNodeInputs(input_seq_len=cond["und_len"])

def _encode_conditioning(self, image, height, width, num_frames, device):

Collaborator

Does it make sense to have this VAE encoder be its own graph node (and its internals subject to torch.compile, separate scheduling, etc.)?

def _forward_image_gen(self, cm, st, latents, time_index, **kwargs) -> dict:
scheduler = st["scheduler"]
step_index = int(time_index.reshape(-1)[0].item())
if step_index >= len(scheduler.timesteps):

Collaborator

I think the cleanest way to avoid going one step over is to have prepare_inputs return None, and then port the skipped_rids logic from the StatelessEngine over to the KVCacheEngine. I believe this forward-pass-internal if statement is not compatible with batching and CUDA graphs.

@NSagan271

Collaborator

Something to consider (I haven't traced through the code thoroughly enough to see if this is possible or even helpful, but I think it fits better with our paradigms) is to add a node for the VAE processing of the engine inputs that runs in parallel (or sequentially if needed) with the DiT prefill, and have the cond_latents / cond_video_latents enter the generation loop as graph edges via the persist signal path instead of through the submodule's persistent state. For requests that don't have video / image inputs, the conductor can set the tensor info to an empty list.

My sense from benchmarking BAGEL I2T is that having a separate node for the VAE (if the compute is significant) could lead to latency / throughput improvements, but I'm of course not sure if that insight generalizes to Cosmos.

merceod added 3 commits June 15, 2026 01:42
The denoise-step CUDA graph was only captured for one square resolution, so
requests at the usual generation sizes ran the eager path. Capture a graph per
generation tier (320x192 / 832x480 / 1280x720), overridable with
COSMOS3_GEN_CAPTURE_RES. The graph runner now keeps a capture per resolution for
a walk and dispatches each request to the graph matching its own shape rather
than the first one declared, so several fixed-shape captures coexist. Served
graph output is identical to the eager path; the win is largest where the step
is launch-bound (~2.5x at 320x192) and tapers as it grows compute-bound.
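
A minimal sketch of the per-resolution dispatch idea: captures are keyed by walk and shape, and a request with no matching capture falls back to eager. `GraphRegistry` and the string placeholders for graphs are hypothetical, not the graph runner in this PR:

```python
class GraphRegistry:
    """Hypothetical registry holding one captured graph per (walk, width, height)."""

    def __init__(self):
        self._graphs = {}

    def register(self, walk, width, height, graph):
        self._graphs[(walk, width, height)] = graph

    def lookup(self, walk, width, height):
        """Return the capture matching this request's shape, or None for eager."""
        return self._graphs.get((walk, width, height))


registry = GraphRegistry()
# The three generation tiers named in the commit message (width x height).
for width, height in [(320, 192), (832, 480), (1280, 720)]:
    registry.register("image_gen", width, height, f"graph-{width}x{height}")

assert registry.lookup("image_gen", 832, 480) == "graph-832x480"
assert registry.lookup("image_gen", 512, 512) is None  # no capture: run eagerly
```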
When several requests finish in the same step, one request can be cleaned
up (its result already returned) while a later chunk for it is still in the
output queue. Decrementing the per-request counter then raised a KeyError
that aborted the whole drain and dropped the other requests' chunks too.
Guard the lookup the same way new_result_tensors already does and skip the
late chunk.
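
A minimal sketch of the guarded drain, assuming a plain `queue.Queue` of `(request_id, chunk)` tuples and a dict of per-request pending counts. These names are illustrative, not the engine's actual structures:

```python
import queue


def drain(output_queue, pending_chunks):
    """Drain the queue, skipping chunks whose request was already cleaned up."""
    delivered, skipped = [], []
    while not output_queue.empty():
        request_id, chunk = output_queue.get()
        if request_id not in pending_chunks:
            # Late chunk for a finished request: skip it instead of raising
            # KeyError, so the other requests' chunks are still drained.
            skipped.append((request_id, chunk))
            continue
        pending_chunks[request_id] -= 1
        delivered.append((request_id, chunk))
        if pending_chunks[request_id] == 0:
            del pending_chunks[request_id]
    return delivered, skipped
```

Returning the skipped chunks makes it easy to log them, which helps tell a benign late chunk apart from a node that was triggered twice.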
Concurrent action requests at the same generation walk now share a single
joint video+action denoise forward, the way image and video requests already
do. A new batched denoise packs each request's [video | action] tokens -- one
branch when guidance is off (the guidance-scale-1 inverse/forward-dynamics and
base policy case), both branches with classifier-free guidance -- and the
batched attention plan routes each request to its own cache pages. The
per-request masks, the joint scheduler step (now factored into a shared
helper), and the domain-aware action projection run per request, so one batch
can mix modes and embodiments. Adds CPU and GPU tests that check the batched
output reproduces the per-request path and stays isolated across requests.
while not self.output_queue.empty():
result: ResultChunk = self.output_queue.get()
# A request can be cleaned up (its result already returned) while a
# late chunk is still in the queue -- common when several requests

Collaborator

I'm not sure if this should be happening---I think it's correct to have the guard here (I'd change the log to warning instead of debug though), but I'd check to see if this is due to, e.g., a VAE decoder accidentally being double-triggered?

merceod added 3 commits June 15, 2026 07:03
torch.compile the inner denoise compute with fullgraph=False so the FlashInfer
attention stays an opaque break and only the bandwidth-bound pointwise ops fuse;
the compiled kernels then bake into the per-resolution image graphs at capture, so
graphs and compile stack. t2i bs=1 over HTTP drops to ~0.92/1.84/3.64s at
256/480/720p (~1.2-1.25x over graphs alone) with no image or action-golden quality
change vs the fused reference (480p/50-step engine PSNR 39.3 either way). On by
default; COSMOS3_DISABLE_COMPILE_DENOISE=1 falls back to the eager step. The
engine-cache and action suites pin the eager step for their bit-exact mechanism
checks.
merceod added 8 commits June 16, 2026 05:40
The base get_submodule signature gained an autocast_dtype hint; thread it
through Cosmos3 for parity. Cosmos3 already casts the meta module to bf16
before to_empty, so params land in the checkpoint dtype and the hint is a
no-op here, but the engine manager now passes it by keyword.
The default pool (max_num_pages 2048 x page_size 128) pre-allocates ~38 GB
of paged K/V for the 36-layer DiT regardless of the request, which OOMs
larger video on an 80 GB card. One bs=1 720p x 189-frame request needs only
~692 pages across both CFG branches, so 1024 pages cover single-request
video at every tier plus image batching and free ~19 GB for activations.
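
The arithmetic behind these figures can be checked with a short sketch. It assumes the pool size scales linearly with the page count (page size, layer count, and dtype held fixed) and uses the approximate 38 GB figure quoted above:

```python
DEFAULT_PAGES = 2048
DEFAULT_POOL_GB = 38.0  # approximate figure from the commit message

# Assumption: memory is linear in the number of pages.
GB_PER_PAGE = DEFAULT_POOL_GB / DEFAULT_PAGES


def pool_gb(num_pages):
    """Approximate pre-allocated K/V pool size for a given page count."""
    return num_pages * GB_PER_PAGE


reduced_gb = pool_gb(1024)               # 19.0 GB for the halved pool
freed_gb = DEFAULT_POOL_GB - reduced_gb  # 19.0 GB returned to activations
request_gb = pool_gb(692)                # roughly 12.8 GB for one 720p x 189-frame request

assert 692 <= 1024  # the single-request case fits in the reduced pool
```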
The diffusion generation tower recomputes all of its K/V every denoise step
and only reuses the small frozen text prefix, so the paged-cache write the
autoregressive path needs is wasted work here. With COSMOS3_DENSE_FA3 set,
gather the prefix K/V and run one varlen FlashAttention-3 pass over
[prefix | generation] per guidance branch, bypassing the paged write+read;
falls back to the paged path otherwise and under CUDA-graph capture. Adds a
dense-vs-paged PSNR parity check.
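
For reference, PSNR between two 8-bit images is 10 * log10(255^2 / MSE). A minimal pure-Python sketch of such a parity metric follows; the function name and the flat pixel lists are illustrative assumptions, not the test added in this PR:

```python
import math


def psnr(reference, candidate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((r - c) ** 2 for r, c in zip(reference, candidate)) / len(reference)
    if mse == 0:
        return math.inf  # identical outputs
    return 10.0 * math.log10(max_value ** 2 / mse)
```

A parity check would then assert that the PSNR between the dense and paged outputs stays above a chosen threshold.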
The serving engine cast the Wan VAE to bf16, but its 3D convolutions are
several times slower in bf16 than fp32 on this cuDNN and the reference
pipeline decodes in fp32; restore fp32 for the decode. Also quantize to uint8
in the decoder so only 8-bit frames cross the worker boundary instead of a
4x-larger fp32 tensor.
The conditioning encode repeat-padded the frame across the whole clip and
VAE-encoded all of it, but only latent frame 0 is ever used. Encode just that
frame (the Wan VAE produces it as a standalone anchor, bit-identical) and run
the encode in fp32 outside autocast like the decoder, which is much faster on
this cuDNN.
Gated by COSMOS3_COMPILE_VAE (default off). The Wan VAE decode is 3D-conv bound and runs once per request at request-specific shapes, so it isn't CUDA-graphed; torch.compile fuses the pointwise epilogues around the convs. Keeps the fp32, autocast-off decode. Trims a few percent off video latency (the decode is a larger slice there) and narrows the higher-resolution image gap; the first request per resolution pays a one-time trace. Adds a PSNR A/B test against the eager decode for both image and video.
2 participants