[WIP] cosmos3: add generator model scaffold (+reasoner implementation) #121
Open
merceod wants to merge 38 commits into main from cosmos3_integration
38 commits
55da2b3 cosmos3: add generator model scaffold (merceod)
ab517a9 cosmos3: implement the DiT forward and weight loading (merceod)
20ab57a cosmos3: add text-to-image packing and pipeline (merceod)
70301f9 cosmos3: run the text tower once and reuse its KV across denoise steps (merceod)
ad0f2d4 cosmos3: extend generation to video (t2v / i2v) (merceod)
2d6287e cosmos3: robot action generation (dynamics + policy) (merceod)
e644b59 engine: let submodules opt out of torch.compile (merceod)
0cf0bd1 cosmos3: text-to-image over the OpenAI /v1/images endpoint (merceod)
419f3ed Run both guidance branches in one batched forward per denoise step. T… (merceod)
da39bd1 Batch concurrent image requests through one denoise step. When severa… (merceod)
5397169 Capture the image denoise step as a CUDA graph (merceod)
57dd722 Add an env switch to disable the cosmos3 denoise CUDA graph (merceod)
660ad1c Run the cosmos3 denoise loop for a per-request number of steps. The i… (merceod)
7a5dfb8 Serve text-to-video over /v1/videos/generations. A video generation w… (merceod)
acc4ad8 Serve image-to-video for cosmos3 (merceod)
adc2b15 Serve action inverse-dynamics over /generate (merceod)
8dbb393 Serve action forward-dynamics (predict video) over /generate (merceod)
c81480d Batch concurrent image/video denoise steps in serving (merceod)
91ee526 Capture image denoise CUDA graphs at the standard generation sizes (merceod)
269cd13 Drop late result chunks for already-finished requests (merceod)
74996f1 Batch concurrent action requests in one denoise step (merceod)
159164f Compile the image denoise step and fold it into the CUDA graphs (merceod)
bcc9da3 Add Cosmos3 serving benchmark scripts (merceod)
233d1ed Capture the image denoise step at batched sizes for concurrent requests (merceod)
f93f9d4 Merge remote-tracking branch 'origin/main' into cosmos3_integration (merceod)
b08f668 Accept autocast_dtype in Cosmos3 get_submodule (merceod)
dc04bc4 Right-size the Cosmos3 KV cache pool (merceod)
fa0d304 Use dense FlashAttention-3 for the Cosmos3 generation attention (merceod)
c7de256 Decode the Cosmos3 VAE in fp32 and return 8-bit frames (merceod)
2f65203 Encode the image-to-video conditioning frame once, in fp32 (merceod)
d4dd70f Optionally torch.compile the Wan VAE decode (merceod)
0a81f3a Encode served videos at CRF 18 (merceod)
dec2219 Add tensor parallelism to the Cosmos3 DiT (merceod)
021e850 Register the Cosmos3 Super variant (merceod)
dbbfa2d Align cosmos3 served encoding and prompts with the reference pipeline (merceod)
d265caf Keep the cosmos3 timestep embedder in fp32 outside the engine autocast (merceod)
b62ab0a Apply the reference guidance interval and flow shift to text-to-image (merceod)
6dfdc42 Skip the denoise CUDA-graph for odd-latent-size resolutions (merceod)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| model: "cosmos3" | ||
| # Sequence-length hint for the scheduler. The conductor only asserts its | ||
| # presence; the real per-request capacity is the KV pool below. | ||
| max_seq_len: 8192 | ||
| # KV pool sizing. The default (max_num_pages 2048 x page_size 128) pre-allocates | ||
| # ~38 GB of paged K/V for the 36-layer DiT regardless of the workload, which | ||
| # OOMs larger video on an 80 GB card. A bs=1 720p x 189-frame request needs only | ||
| # ~692 pages across both CFG branches (images take a few dozen), so 1024 pages | ||
| # (~19 GB) cover single-request video at every tier plus image batching and free | ||
| # ~19 GB for activations. | ||
| kv_cache: | ||
|   max_num_pages: 1024 | ||
| node_groups: | ||
|   - node_names: ["dit"] | ||
|     ranks: [0] | ||
|   - node_names: ["vae_decoder"] | ||
|     ranks: [0] | ||
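The sizing in the config comment above can be sanity-checked with a few lines of arithmetic. This is a rough sketch, not the engine's allocator: the 36 layers, the 2048 and 1024 page counts, and the page size of 128 come from the comment, while `kv_heads=8`, `head_dim=128`, and 2-byte (bf16) elements are assumptions, picked because they roughly reproduce the quoted ~38 GB and ~19 GB figures.

```python
def kv_pool_bytes(num_pages, page_size, num_layers, kv_heads, head_dim, dtype_bytes=2):
    """Bytes pre-allocated for paged K and V across all layers."""
    tokens = num_pages * page_size
    # Factor 2 covers the separate K and V tensors.
    per_token_per_layer = 2 * kv_heads * head_dim * dtype_bytes
    return tokens * num_layers * per_token_per_layer


# From the config comment: 36-layer DiT, page_size 128.
# Assumed (not stated in the source): 8 KV heads, head_dim 128, bf16.
default_pool = kv_pool_bytes(2048, 128, 36, kv_heads=8, head_dim=128)
resized_pool = kv_pool_bytes(1024, 128, 36, kv_heads=8, head_dim=128)

print(f"default pool: {default_pool / 1e9:.1f} GB")  # default pool: 38.7 GB
print(f"resized pool: {resized_pool / 1e9:.1f} GB")  # resized pool: 19.3 GB
```

Halving `max_num_pages` halves the pre-allocation whatever the true head configuration is, since the pool size is linear in the page count.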
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| model: "cosmos3" | ||
| # Sequence-length hint for the scheduler (see cosmos3_nano.yaml). | ||
| max_seq_len: 8192 | ||
| # Per-rank KV pool. Under tensor parallelism the KV heads shard across ranks, so | ||
| # each rank's pages hold half the heads — 1024 pages leave ample headroom. | ||
| kv_cache: | ||
|   max_num_pages: 1024 | ||
| # The DiT runs tensor-parallel across two ranks (attention heads + MLP | ||
| # intermediate shard; the residual stream stays full and the out/down | ||
| # projections all-reduce). The VAE decoder is small and runs un-sharded on | ||
| # rank 0; the DiT's final latents are replicated, so the decoder reads them | ||
| # directly. | ||
| node_groups: | ||
|   - node_names: ["dit"] | ||
|     ranks: [0, 1] | ||
|     tp_size: 2 | ||
|   - node_names: ["vae_decoder"] | ||
|     ranks: [0] | ||
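The sharding scheme described in the comment above (MLP intermediate dimension split across ranks, residual stream kept full, down projection all-reduced) can be illustrated with a toy single-process simulation. This is a minimal sketch in plain Python, not the PR's implementation: the function names are hypothetical, ReLU stands in for whatever activation the DiT uses, and the all-reduce is simulated by summing per-rank partial outputs.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def mlp_full(x, w_up, w_down):
    """Unsharded reference: x @ w_up -> ReLU -> @ w_down."""
    return matmul(relu(matmul(x, w_up)), w_down)


def mlp_tensor_parallel(x, w_up, w_down, tp_size):
    """Each simulated rank owns a slice of the intermediate dimension."""
    inter = len(w_down)
    assert inter % tp_size == 0, "intermediate size must divide evenly"
    shard = inter // tp_size
    partials = []
    for rank in range(tp_size):
        lo, hi = rank * shard, (rank + 1) * shard
        w_up_r = [row[lo:hi] for row in w_up]  # column slice of the up projection
        w_down_r = w_down[lo:hi]               # matching row slice of the down projection
        partials.append(matmul(relu(matmul(x, w_up_r)), w_down_r))
    # Simulated all-reduce: sum the partial outputs so every rank would
    # end up holding the full residual-stream result.
    rows, cols = len(x), len(w_down[0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


x = [[1, -2, 3], [0, 4, -1]]
w_up = [[1, 0, -1, 2], [2, 1, 0, -1], [0, -1, 1, 1]]
w_down = [[1, 2, 0], [0, 1, -1], [3, 0, 1], [-1, 1, 2]]

full = mlp_full(x, w_up, w_down)
sharded = mlp_tensor_parallel(x, w_up, w_down, tp_size=2)
print(full == sharded)
```

The split is exact because the activation is elementwise, so each rank can apply it to its own slice without communication; only the final sum needs the all-reduce.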
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| model: "cosmos3_super" | ||
| # Sequence-length hint for the scheduler (see cosmos3_nano.yaml). | ||
| max_seq_len: 8192 | ||
| # Per-rank KV pool. Super is 64 layers (vs Nano's 36) but the KV heads (8) shard | ||
| # across the 4 TP ranks, so per-rank KV stays modest; 1024 pages is ample on the | ||
| # 143 GB H200s. | ||
| kv_cache: | ||
|   max_num_pages: 1024 | ||
| # Super (64B) is unviable on one GPU (~128 GB in bf16), so the DiT runs | ||
| # tensor-parallel across 4 ranks. The VAE decoder is small and runs un-sharded | ||
| # on rank 0 (the DiT's final latents are replicated, so it reads them directly). | ||
| node_groups: | ||
|   - node_names: ["dit"] | ||
|     ranks: [0, 1, 2, 3] | ||
|     tp_size: 4 | ||
|   - node_names: ["vae_decoder"] | ||
|     ranks: [0] | ||
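The figures in the Super config comments follow from simple arithmetic, sketched below. The parameter count (64B, read here as 64e9), the 8 KV heads, and the 4 TP ranks come from the comments. The per-rank weight figure assumes every parameter shards evenly, which is an idealization: norms and other replicated tensors would push the real number somewhat higher.

```python
def bf16_weight_gb(num_params):
    """Weight footprint in GB at 2 bytes per parameter."""
    return num_params * 2 / 1e9


NUM_PARAMS = 64_000_000_000  # "64B" from the config comment
TP_SIZE = 4
KV_HEADS = 8

total_gb = bf16_weight_gb(NUM_PARAMS)
per_rank_gb = total_gb / TP_SIZE           # idealized: assumes perfectly even sharding
kv_heads_per_rank = KV_HEADS // TP_SIZE    # KV heads held by each rank

print(total_gb, per_rank_gb, kv_heads_per_rank)  # 128.0 32.0 2
```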
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| """/v1/videos/generations (text-to-video and image-to-video) handler.""" | ||
| | ||
| from __future__ import annotations | ||
| | ||
| import base64 | ||
| | ||
| from starlette.concurrency import run_in_threadpool | ||
| | ||
| from mstar.api_server.openai._util import now, rid | ||
| | ||
| | ||
| async def create_videos(api, model_name, adapter, req):  # noqa: ARG001 | ||
|     args = adapter.video_to_request(req, api.upload_dir) | ||
|     request_id = rid("vid") | ||
| | ||
|     api.submit_request( | ||
|         text=args.text, | ||
|         file_paths=args.file_paths, | ||
|         input_modalities=args.input_modalities, | ||
|         output_modalities=["video"], | ||
|         model_kwargs=args.model_kwargs, | ||
|         streaming=False, | ||
|         request_id=request_id, | ||
|     ) | ||
| | ||
|     chunks = await run_in_threadpool(api.collect_results, request_id) | ||
|     # Each video chunk is an mp4 (H.264); return it base64-encoded, mirroring the | ||
|     # image endpoint's b64_json shape. | ||
|     data = [ | ||
|         {"b64_json": base64.b64encode(c.data).decode("ascii"), "url": None} | ||
|         for c in chunks | ||
|         if c.modality == "video" | ||
|     ] | ||
|     return {"created": now(), "data": data} | ||
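On the client side, the response shape returned by the handler above (`created` plus a `data` list of `b64_json` entries) can be unpacked as follows. This is a hedged sketch: `decode_videos` and `save_videos` are hypothetical helper names, and the request payload is omitted on purpose because the fields accepted by `adapter.video_to_request` are not shown in this diff.

```python
import base64


def decode_videos(response):
    """Return raw mp4 bytes for each entry in a /v1/videos/generations response."""
    return [
        base64.b64decode(item["b64_json"])
        for item in response.get("data", [])
        if item.get("b64_json")
    ]


def save_videos(response, prefix):
    """Write each decoded video to '<prefix>_<index>.mp4' and return the paths."""
    paths = []
    for i, blob in enumerate(decode_videos(response)):
        path = f"{prefix}_{i}.mp4"
        with open(path, "wb") as f:
            f.write(blob)
        paths.append(path)
    return paths


# Stand-in response with the same shape the handler returns; the payload
# bytes are a placeholder, not a real mp4.
fake_response = {
    "created": 0,
    "data": [{"b64_json": base64.b64encode(b"not-a-real-mp4").decode("ascii"), "url": None}],
}
videos = decode_videos(fake_response)
print(len(videos), len(videos[0]))
```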
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| """Apples-to-apples t2i latency client — hits the OpenAI /v1/images/generations | ||
| endpoint that BOTH our mstar server and vLLM-Omni (`vllm serve --omni`) expose, with | ||
| an identical payload, and reports client-side wall latency (warmup + median of N). | ||
| | ||
| Same scope on both engines (client-side end-to-end incl. HTTP + b64 PNG), same config | ||
| (tiers, steps, guidance, seed, prompt). Run once per server (different --port/--model). | ||
| | ||
| python bench_t2i_oai.py --port 8000 --model nvidia/Cosmos3-Nano --tag vllm | ||
| python bench_t2i_oai.py --port 8100 --model cosmos3_nano --tag ours | ||
| """ | ||
| import argparse | ||
| import base64 | ||
| import json | ||
| import statistics | ||
| import time | ||
| import urllib.request | ||
| | ||
| ap = argparse.ArgumentParser() | ||
| ap.add_argument("--port", type=int, required=True) | ||
| ap.add_argument("--model", default="nvidia/Cosmos3-Nano") | ||
| ap.add_argument("--sizes", default="320x192,832x480,1280x720")  # 256p/480p/720p tiers | ||
| ap.add_argument("--steps", type=int, default=50) | ||
| ap.add_argument("--gs", type=float, default=6.0) | ||
| ap.add_argument("--seed", type=int, default=0) | ||
| ap.add_argument("--rounds", type=int, default=5) | ||
| ap.add_argument("--warmup", type=int, default=2) | ||
| ap.add_argument("--tag", default="run") | ||
| ap.add_argument("--save", default="")  # optional PNG path prefix | ||
| args = ap.parse_args() | ||
| | ||
| PROMPT = "A red cube resting on a polished wooden table, soft daylight." | ||
| NEG = "blurry, distorted, low quality" | ||
| URL = f"http://localhost:{args.port}/v1/images/generations" | ||
| | ||
| | ||
| def one(size): | ||
|     body = json.dumps({ | ||
|         "model": args.model, "prompt": PROMPT, "negative_prompt": NEG, | ||
|         "size": size, "n": 1, "response_format": "b64_json", | ||
|         "num_inference_steps": args.steps, "guidance_scale": args.gs, "seed": args.seed, | ||
|     }).encode() | ||
|     req = urllib.request.Request(URL, data=body, headers={"Content-Type": "application/json"}) | ||
|     t0 = time.perf_counter() | ||
|     with urllib.request.urlopen(req, timeout=1200) as r: | ||
|         payload = json.load(r) | ||
|     dt = time.perf_counter() - t0 | ||
|     b64 = payload["data"][0]["b64_json"] | ||
|     return dt, b64 | ||
| | ||
| | ||
| print(f"=== {args.tag} port={args.port} model={args.model} steps={args.steps} gs={args.gs} seed={args.seed} ===", flush=True) | ||
| for size in args.sizes.split(","): | ||
|     try: | ||
|         for _ in range(args.warmup): | ||
|             one(size) | ||
|         ts = [] | ||
|         last_b64 = None | ||
|         for _ in range(args.rounds): | ||
|             dt, last_b64 = one(size) | ||
|             ts.append(dt) | ||
|         ts.sort() | ||
|         med = statistics.median(ts) | ||
|         print(f" {size:9s} median {med:.3f}s min {ts[0]:.3f} max {ts[-1]:.3f} (n={args.rounds})", flush=True) | ||
|         if args.save and last_b64: | ||
|             with open(f"{args.save}_{size}.png", "wb") as f: | ||
|                 f.write(base64.b64decode(last_b64)) | ||
|     except Exception as e:  # noqa: BLE001 | ||
|         print(f" {size:9s} ERROR {type(e).__name__}: {str(e)[:120]}", flush=True) | ||
| print("DONE", flush=True) | ||
I'm not sure this should be happening. I think it's correct to have the guard here (though I'd change the log from debug to warning), but I'd check whether it's caused by, e.g., a VAE decoder accidentally being double-triggered.
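To make the suggestion concrete, here is a minimal sketch of such a guard that logs at warning level and counts dropped chunks per request, which would help tell a one-off race from a systematically double-triggered stage. The code under review is not shown on this page, so `ResultCollector` and all of its methods are hypothetical names rather than the PR's API.

```python
import logging

logger = logging.getLogger("result_collector")


class ResultCollector:
    """Hypothetical collector that ignores chunks for already-finished requests."""

    def __init__(self):
        self._chunks = {}       # request_id -> list of pending chunks
        self._finished = set()  # request ids whose results were already returned
        self.dropped = {}       # request_id -> number of late chunks dropped

    def add_chunk(self, request_id, chunk):
        """Store a chunk; return False if it arrived after the request finished."""
        if request_id in self._finished:
            count = self.dropped.get(request_id, 0) + 1
            self.dropped[request_id] = count
            # Warning rather than debug: a late chunk may point to a stage
            # (for example a decoder) being triggered twice.
            logger.warning(
                "dropping late chunk for finished request %s (dropped so far: %d)",
                request_id,
                count,
            )
            return False
        self._chunks.setdefault(request_id, []).append(chunk)
        return True

    def finish(self, request_id):
        """Mark the request finished and hand back its collected chunks."""
        self._finished.add(request_id)
        return self._chunks.pop(request_id, [])


collector = ResultCollector()
collector.add_chunk("vid-1", b"frame-data")
results = collector.finish("vid-1")
late_accepted = collector.add_chunk("vid-1", b"duplicate")
```

A per-request drop count above zero on every video request, rather than only under load, would be a hint that the duplicate comes from the pipeline itself and not from timing.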