diff --git a/configs/cosmos3_nano.yaml b/configs/cosmos3_nano.yaml new file mode 100644 index 00000000..27e000e8 --- /dev/null +++ b/configs/cosmos3_nano.yaml @@ -0,0 +1,17 @@ +model: "cosmos3" +# Sequence-length hint for the scheduler. The conductor only asserts its +# presence; the real per-request capacity is the KV pool below. +max_seq_len: 8192 +# KV pool sizing. The default (max_num_pages 2048 x page_size 128) pre-allocates +# ~38 GB of paged K/V for the 36-layer DiT regardless of the workload, which +# OOMs larger video on an 80 GB card. A bs=1 720p x 189-frame request needs only +# ~692 pages across both CFG branches (images take a few dozen), so 1024 pages +# (~19 GB) cover single-request video at every tier plus image batching and free +# ~19 GB for activations. +kv_cache: + max_num_pages: 1024 +node_groups: + - node_names: ["dit"] + ranks: [0] + - node_names: ["vae_decoder"] + ranks: [0] diff --git a/configs/cosmos3_nano_tp2.yaml b/configs/cosmos3_nano_tp2.yaml new file mode 100644 index 00000000..72757b76 --- /dev/null +++ b/configs/cosmos3_nano_tp2.yaml @@ -0,0 +1,18 @@ +model: "cosmos3" +# Sequence-length hint for the scheduler (see cosmos3_nano.yaml). +max_seq_len: 8192 +# Per-rank KV pool. Under tensor parallelism the KV heads shard across ranks, so +# each rank's pages hold half the heads — 1024 pages leave ample headroom. +kv_cache: + max_num_pages: 1024 +# The DiT runs tensor-parallel across two ranks (attention heads + MLP +# intermediate shard; the residual stream stays full and the out/down +# projections all-reduce). The VAE decoder is small and runs un-sharded on +# rank 0; the DiT's final latents are replicated, so the decoder reads them +# directly. +node_groups: + - node_names: ["dit"] + ranks: [0, 1] + tp_size: 2 + - node_names: ["vae_decoder"] + ranks: [0] diff --git a/configs/cosmos3_super_tp4.yaml b/configs/cosmos3_super_tp4.yaml new file mode 100644 index 00000000..b8e3c4f6 --- /dev/null +++ b/configs/cosmos3_super_tp4.yaml @@ -0,0 +1,17 @@ +model: "cosmos3_super" +# Sequence-length hint for the scheduler (see cosmos3_nano.yaml). +max_seq_len: 8192 +# Per-rank KV pool. Super is 64 layers (vs Nano's 36) but the KV heads (8) shard +# across the 4 TP ranks, so per-rank KV stays modest; 1024 pages is ample on the +# 143 GB H200s. +kv_cache: + max_num_pages: 1024 +# Super (64B) is unviable on one GPU (~128 GB in bf16), so the DiT runs +# tensor-parallel across 4 ranks. The VAE decoder is small and runs un-sharded +# on rank 0 (the DiT's final latents are replicated, so it reads them directly). +node_groups: + - node_names: ["dit"] + ranks: [0, 1, 2, 3] + tp_size: 4 + - node_names: ["vae_decoder"] + ranks: [0] diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index c7dabb7b..d15bbf36 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -113,6 +113,17 @@ def get_result_chunks(self)-> list[ResultChunk]: results = [] while not self.output_queue.empty(): result: ResultChunk = self.output_queue.get() + # A request can be cleaned up (its result already returned) while a + # late chunk is still in the queue -- common when several requests + # complete in the same step. Mirror new_result_tensors' guard and + # drop the straggler rather than KeyError, which would otherwise + # abort the whole drain and lose the other requests' chunks. + if result.request_id not in self.per_request_reading_tensors: + logger.debug( + "Late result chunk for cleaned-up request %s, ignoring", + result.request_id, + ) + continue self.per_request_reading_tensors[result.request_id] -= 1 logger.debug( "Data worker reading queue for request %s decreased to length %d", diff --git a/mstar/api_server/openai/adapters.py b/mstar/api_server/openai/adapters.py index 12ea03c5..14bb6fce 100644 --- a/mstar/api_server/openai/adapters.py +++ b/mstar/api_server/openai/adapters.py @@ -35,6 +35,7 @@ ChatCompletionRequest, ImageGenerationRequest, SpeechRequest, + VideoGenerationRequest, ) @@ -164,6 +165,7 @@ class OpenAIAdapter: supports_chat: bool = False # POST /v1/chat/completions supports_speech: bool = False # POST /v1/audio/speech supports_images: bool = False # POST /v1/images/generations and /v1/images/edits + supports_videos: bool = False # POST /v1/videos/generations def chat_to_request(self, req: ChatCompletionRequest, upload_dir: Path) -> SubmitArgs: # noqa: ARG002 # Output modalities vary by model: e.g. Qwen3-Omni speech output also @@ -176,6 +178,9 @@ def speech_to_request(self, req: SpeechRequest, upload_dir: Path) -> SubmitArgs: def image_to_request(self, req: ImageGenerationRequest, upload_dir: Path) -> SubmitArgs: # noqa: ARG002 raise NotImplementedError("image generation is not supported by this model") + def video_to_request(self, req: VideoGenerationRequest, upload_dir: Path) -> SubmitArgs: # noqa: ARG002 + raise NotImplementedError("video generation is not supported by this model") + def image_edit_to_request(self, prompt: str, image_path: str, extra_kwargs: dict) -> SubmitArgs: # noqa: ARG002 raise NotImplementedError("image editing is not supported by this model") @@ -297,12 +302,69 @@ def speech_to_request(self, req: SpeechRequest, upload_dir: Path) -> SubmitArgs: ) +class Cosmos3Adapter(OpenAIAdapter): + """NVIDIA Cosmos3: text-to-image and text/image-to-video generation. + + ``size`` ("WxH") maps to the generation resolution; ``seed`` and any + extra knobs (``guidance_scale``, ``num_inference_steps``, ``negative_prompt``, + and for video ``num_frames`` / ``fps``) pass through via ``extra_body``. + """ + + supports_images = True + supports_videos = True + + def image_to_request(self, req: ImageGenerationRequest, upload_dir: Path) -> SubmitArgs: # noqa: ARG002 + mk = _passthrough(req) + if getattr(req, "size", None): + mk.setdefault("size", req.size) + if getattr(req, "seed", None) is not None: + mk.setdefault("seed", req.seed) + return SubmitArgs( + text=req.prompt, + input_modalities=["text"], + output_modalities=["image"], + model_kwargs=mk, + ) + + def video_to_request(self, req: VideoGenerationRequest, upload_dir: Path) -> SubmitArgs: + mk = _passthrough(req) + if getattr(req, "size", None): + mk.setdefault("size", req.size) + if getattr(req, "seed", None) is not None: + mk.setdefault("seed", req.seed) + # num_frames / fps are first-class video fields (not in extra_body). + if getattr(req, "num_frames", None) is not None: + mk.setdefault("num_frames", req.num_frames) + if getattr(req, "fps", None) is not None: + mk.setdefault("fps", req.fps) + # Image-to-video: the conditioning frame (URL / data URI) is persisted and + # loaded by the worker, which VAE-encodes it into the clean frame-0 anchor. + image = getattr(req, "image", None) + if image: + _, path = media_io.resolve_media_ref(image, upload_dir) + return SubmitArgs( + text=req.prompt, + file_paths={"image": [path]}, + input_modalities=["image", "text"], + output_modalities=["video"], + model_kwargs=mk, + ) + return SubmitArgs( + text=req.prompt, + input_modalities=["text"], + output_modalities=["video"], + model_kwargs=mk, + ) + + # Only models with an OpenAI-standard surface are registered. Action/world-model # models (pi05, vjepa2) are deliberately absent → /v1/* 404s; use /generate. ADAPTER_REGISTRY: dict[str, OpenAIAdapter] = { "bagel": BagelAdapter(), "qwen3_omni": Qwen3OmniAdapter(), "orpheus": OrpheusAdapter(), + "cosmos3": Cosmos3Adapter(), + "cosmos3_super": Cosmos3Adapter(), } diff --git a/mstar/api_server/openai/protocol.py b/mstar/api_server/openai/protocol.py index a4d38d91..ce94a2b1 100644 --- a/mstar/api_server/openai/protocol.py +++ b/mstar/api_server/openai/protocol.py @@ -71,6 +71,28 @@ class ImageGenerationRequest(BaseModel): seed: int | None = None +class VideoGenerationRequest(BaseModel): + """``/v1/videos/generations`` (text-to-video / image-to-video). + + Not an OpenAI-standard surface; modeled on the image endpoint. ``image`` (a + URL or data URI) conditions image-to-video. Extra knobs + (``guidance_scale``, ``num_inference_steps``, ``negative_prompt`` …) flow + through via ``extra_body``. + """ + + model_config = _CFG + + prompt: str + model: str | None = None + n: int | None = 1 + size: str | None = None + response_format: str = "b64_json" + seed: int | None = None + num_frames: int | None = None + fps: float | None = None + image: str | None = None # URL or data URI for image-to-video conditioning + + class ModelCard(BaseModel): id: str object: str = "model" diff --git a/mstar/api_server/openai/router.py b/mstar/api_server/openai/router.py index dbaa1376..d1b1c080 100644 --- a/mstar/api_server/openai/router.py +++ b/mstar/api_server/openai/router.py @@ -12,7 +12,12 @@ from fastapi import APIRouter, Request from fastapi.responses import JSONResponse, StreamingResponse -from mstar.api_server.openai import serving_chat, serving_images, serving_speech +from mstar.api_server.openai import ( + serving_chat, + serving_images, + serving_speech, + serving_videos, +) from mstar.api_server.openai._util import now from mstar.api_server.openai.adapters import get_adapter from mstar.api_server.openai.protocol import ( @@ -21,6 +26,7 @@ ModelCard, ModelList, SpeechRequest, + VideoGenerationRequest, ) router = APIRouter() @@ -113,6 +119,18 @@ async def images_generations(request: ImageGenerationRequest): return JSONResponse(result) +@router.post("/v1/videos/generations") +async def videos_generations(request: VideoGenerationRequest): + api, model_name, adapter, err = _resolve("supports_videos") + if err is not None: + return err + try: + result = await serving_videos.create_videos(api, model_name, adapter, request) + except Exception as e: # noqa: BLE001 + return _error(getattr(e, "status_code", 500), str(getattr(e, "detail", e)), "server_error") + return JSONResponse(result) + + @router.post("/v1/images/edits") async def images_edits(request: Request): # Multipart (image file + prompt + passthrough fields), parsed manually so diff --git a/mstar/api_server/openai/serving_videos.py b/mstar/api_server/openai/serving_videos.py new file mode 100644 index 00000000..a1d58108 --- /dev/null +++ b/mstar/api_server/openai/serving_videos.py @@ -0,0 +1,34 @@ +"""/v1/videos/generations (text-to-video and image-to-video) handler.""" + +from __future__ import annotations + +import base64 + +from starlette.concurrency import run_in_threadpool + +from mstar.api_server.openai._util import now, rid + + +async def create_videos(api, model_name, adapter, req): # noqa: ARG001 + args = adapter.video_to_request(req, api.upload_dir) + request_id = rid("vid") + + api.submit_request( + text=args.text, + file_paths=args.file_paths, + input_modalities=args.input_modalities, + output_modalities=["video"], + model_kwargs=args.model_kwargs, + streaming=False, + request_id=request_id, + ) + + chunks = await run_in_threadpool(api.collect_results, request_id) + # Each video chunk is an mp4 (H.264); return it base64-encoded, mirroring the + # image endpoint's b64_json shape. + data = [ + {"b64_json": base64.b64encode(c.data).decode("ascii"), "url": None} + for c in chunks + if c.modality == "video" + ] + return {"created": now(), "data": data} diff --git a/mstar/benchmark/cosmos3/bench_t2i_oai.py b/mstar/benchmark/cosmos3/bench_t2i_oai.py new file mode 100644 index 00000000..7bf76da9 --- /dev/null +++ b/mstar/benchmark/cosmos3/bench_t2i_oai.py @@ -0,0 +1,69 @@ +"""Apples-to-apples t2i latency client — hits the OpenAI /v1/images/generations +endpoint that BOTH our mstar server and vLLM-Omni (`vllm serve --omni`) expose, with +an identical payload, and reports client-side wall latency (warmup + median of N). + +Same scope on both engines (client-side end-to-end incl. HTTP + b64 PNG), same config +(tiers, steps, guidance, seed, prompt). Run once per server (different --port/--model). + + python bench_t2i_oai.py --port 8000 --model nvidia/Cosmos3-Nano --tag vllm + python bench_t2i_oai.py --port 8100 --model cosmos3_nano --tag ours +""" +import argparse +import base64 +import json +import statistics +import time +import urllib.request + +ap = argparse.ArgumentParser() +ap.add_argument("--port", type=int, required=True) +ap.add_argument("--model", default="nvidia/Cosmos3-Nano") +ap.add_argument("--sizes", default="320x192,832x480,1280x720") # 256p/480p/720p tiers +ap.add_argument("--steps", type=int, default=50) +ap.add_argument("--gs", type=float, default=6.0) +ap.add_argument("--seed", type=int, default=0) +ap.add_argument("--rounds", type=int, default=5) +ap.add_argument("--warmup", type=int, default=2) +ap.add_argument("--tag", default="run") +ap.add_argument("--save", default="") # optional PNG path prefix +args = ap.parse_args() + +PROMPT = "A red cube resting on a polished wooden table, soft daylight." +NEG = "blurry, distorted, low quality" +URL = f"http://localhost:{args.port}/v1/images/generations" + + +def one(size): + body = json.dumps({ + "model": args.model, "prompt": PROMPT, "negative_prompt": NEG, + "size": size, "n": 1, "response_format": "b64_json", + "num_inference_steps": args.steps, "guidance_scale": args.gs, "seed": args.seed, + }).encode() + req = urllib.request.Request(URL, data=body, headers={"Content-Type": "application/json"}) + t0 = time.perf_counter() + with urllib.request.urlopen(req, timeout=1200) as r: + payload = json.load(r) + dt = time.perf_counter() - t0 + b64 = payload["data"][0]["b64_json"] + return dt, b64 + + +print(f"=== {args.tag} port={args.port} model={args.model} steps={args.steps} gs={args.gs} seed={args.seed} ===", flush=True) +for size in args.sizes.split(","): + try: + for _ in range(args.warmup): + one(size) + ts = [] + last_b64 = None + for _ in range(args.rounds): + dt, last_b64 = one(size) + ts.append(dt) + ts.sort() + med = statistics.median(ts) + print(f" {size:9s} median {med:.3f}s min {ts[0]:.3f} max {ts[-1]:.3f} (n={args.rounds})", flush=True) + if args.save and last_b64: + with open(f"{args.save}_{size}.png", "wb") as f: + f.write(base64.b64decode(last_b64)) + except Exception as e: # noqa: BLE001 + print(f" {size:9s} ERROR {type(e).__name__}: {str(e)[:120]}", flush=True) +print("DONE", flush=True) diff --git a/mstar/benchmark/cosmos3/bench_throughput.py b/mstar/benchmark/cosmos3/bench_throughput.py new file mode 100644 index 00000000..0f78e8ff --- /dev/null +++ b/mstar/benchmark/cosmos3/bench_throughput.py @@ -0,0 +1,105 @@ +"""Throughput under load — same-machine concurrency sweep, M* vs vLLM-Omni. + +Both engines expose OpenAI /v1/images/generations; we fire a closed-loop of `bs` +concurrent requests (ThreadPoolExecutor, exactly bs in flight) for bs*rounds total +and report sustained req/s + p50/p95/mean latency. This measures how each engine +handles concurrency: M* batches concurrent requests across its worker, while +vLLM-Omni runs one request at a time at default settings, so its req/s is flat in bs. + + python bench_throughput.py --port 8100 --model cosmos3_nano --tag ours + python bench_throughput.py --port 8000 --model nvidia/Cosmos3-Nano --tag vllm +""" +import argparse +import base64 +import json +import statistics +import time +import urllib.request +from concurrent.futures import ThreadPoolExecutor + +ap = argparse.ArgumentParser() +ap.add_argument("--port", type=int, required=True) +ap.add_argument("--model", default="nvidia/Cosmos3-Nano") +ap.add_argument("--sizes", default="320x192,832x480") # 256p, 480p (720p too slow for a sweep) +ap.add_argument("--bs", default="1,4,8") +ap.add_argument("--steps", type=int, default=50) +ap.add_argument("--gs", type=float, default=6.0) +ap.add_argument("--rounds", type=int, default=5) # measured requests per worker +ap.add_argument("--warmup", type=int, default=2) +ap.add_argument("--tag", default="run") +ap.add_argument("--out", default="") +args = ap.parse_args() + +PROMPT = "A red cube resting on a polished wooden table, soft daylight." +NEG = "blurry, distorted, low quality" +URL = f"http://127.0.0.1:{args.port}/v1/images/generations" + + +def one(size, seed): + body = json.dumps({ + "model": args.model, "prompt": PROMPT, "negative_prompt": NEG, + "size": size, "n": 1, "response_format": "b64_json", + "num_inference_steps": args.steps, "guidance_scale": args.gs, "seed": seed, + }).encode() + req = urllib.request.Request(URL, data=body, headers={"Content-Type": "application/json"}) + t0 = time.perf_counter() + try: + with urllib.request.urlopen(req, timeout=1800) as r: + payload = json.load(r) + dt = time.perf_counter() - t0 + nbytes = len(base64.b64decode(payload["data"][0]["b64_json"])) + return dt, True, nbytes, "" + except Exception as e: # noqa: BLE001 + return time.perf_counter() - t0, False, 0, f"{type(e).__name__}:{str(e)[:90]}" + + +def pct(lats, q): + if not lats: + return float("nan") + s = sorted(lats) + k = (len(s) - 1) * q / 100.0 + lo, hi = int(k), min(int(k) + 1, len(s) - 1) + return s[lo] + (s[hi] - s[lo]) * (k - lo) + + +def run_cell(size, bs): + # warm the server / graph at this size+concurrency (results discarded) + with ThreadPoolExecutor(max_workers=bs) as ex: + list(ex.map(lambda i: one(size, 900000 + i), range(max(args.warmup, bs)))) + n = bs * args.rounds + t0 = time.perf_counter() + with ThreadPoolExecutor(max_workers=bs) as ex: + res = list(ex.map(lambda i: one(size, i), range(n))) + makespan = time.perf_counter() - t0 + oks = [r for r in res if r[1]] + lats = [r[0] for r in oks] + err = next((r[3] for r in res if not r[1]), "") + return { + "size": size, "bs": bs, "n": n, "ok": len(oks), "makespan": makespan, + "thrpt": len(oks) / makespan if makespan > 0 else float("nan"), + "p50": pct(lats, 50), "p95": pct(lats, 95), + "mean": statistics.fmean(lats) if lats else float("nan"), "err": err, + } + + +print(f"=== {args.tag} port={args.port} model={args.model} steps={args.steps} gs={args.gs} ===", flush=True) +cells = [] +for size in args.sizes.split(","): + base_thrpt = None + for bs in [int(x) for x in args.bs.split(",")]: + c = run_cell(size, bs) + cells.append(c) + if bs == 1: + base_thrpt = c["thrpt"] + if c["ok"] == 0: + print(f" {size:9s} bs={bs}: ALL {c['n']} FAILED ({c['err']})", flush=True) + continue + scale = c["thrpt"] / base_thrpt if base_thrpt else float("nan") + tag = "" if c["ok"] == c["n"] else f" ({c['ok']}/{c['n']} ok)" + print(f" {size:9s} bs={bs}: thrpt {c['thrpt']:6.3f} req/s ({scale:4.2f}x bs1) " + f"p50 {c['p50']:6.2f}s p95 {c['p95']:6.2f}s mean {c['mean']:6.2f}s{tag}", flush=True) +if args.out: + with open(args.out, "w") as f: + json.dump(cells, f, indent=2) + print(f"wrote {args.out}", flush=True) +print("DONE", flush=True) diff --git a/mstar/benchmark/cosmos3/reproduce.sh b/mstar/benchmark/cosmos3/reproduce.sh new file mode 100755 index 00000000..779ee59c --- /dev/null +++ b/mstar/benchmark/cosmos3/reproduce.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Reproduce the Cosmos3-Nano serving benchmarks (M* vs vLLM-Omni): t2i / t2v / i2v +# latency and t2i throughput under concurrency. Both engines expose the OpenAI +# /v1/images/generations + /v1/videos APIs, so the client scripts in this dir hit +# both identically (same prompt / tiers / steps / guidance / seed). +# +# Measured on 1x H100 80GB, CUDA 13. Serve one engine per GPU; run them on +# separate GPUs so the bench clients can hit both back-to-back. +# +# Set for your machine before serving: +# SNAP = Cosmos3-Nano HF snapshot dir (hf download nvidia/Cosmos3-Nano) +# MSTAR = this repo checkout +# HF_TOKEN = your Hugging Face token (Cosmos3-Nano is gated) +set -eu + +# -------------------------------------------------------------------------- +# Serve M* (this repo). torch.compile + CUDA graphs are on by default. +# COSMOS3_GEN_CAPTURE_RES bakes a denoise graph per benchmarked resolution; +# COSMOS3_GEN_CAPTURE_BS additionally captures batched (concurrent) denoise +# steps, which the throughput sweep needs to scale past one request. +# usage: serve_mstar +# -------------------------------------------------------------------------- +serve_mstar() { + : "${SNAP:?set SNAP to the Cosmos3-Nano snapshot dir}" + : "${MSTAR:?set MSTAR to the repo checkout}" + local sock upload + sock=$(mktemp -d); upload=$(mktemp -d) + CUDA_VISIBLE_DEVICES="$1" PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + COSMOS3_GEN_CAPTURE_RES=192x320,480x832,720x1280 \ + COSMOS3_GEN_CAPTURE_BS=1,4,8 \ + COSMOS3_NANO_DIR="$SNAP" PYTHONPATH="$MSTAR" \ + python "$MSTAR/mstar/api_server/entrypoint.py" \ + --config "$MSTAR/configs/cosmos3_nano.yaml" \ + --socket-path-prefix "$sock/" --upload-dir "$upload/" \ + --port "$2" --mooncake-port "$(($2 + 1000))" --tensor-comm-protocol SHM +} + +# -------------------------------------------------------------------------- +# Serve vLLM-Omni (baseline). Prebuilt cu13 wheel; same OpenAI API. +# usage: serve_vllm +# -------------------------------------------------------------------------- +serve_vllm() { + CUDA_VISIBLE_DEVICES="$1" \ + vllm serve nvidia/Cosmos3-Nano --omni --no-guardrails \ + --host 0.0.0.0 --port "$2" --init-timeout 1800 +} + +# -------------------------------------------------------------------------- +# Benchmarks. Serve each engine first (e.g. `serve_mstar 0 18300` and +# `serve_vllm 1 8200` in separate shells), then run the clients below. +# Defaults: 256p/480p/720p tiers, 50 steps (t2i), gs 6, seed 0. +# -------------------------------------------------------------------------- +here=$(dirname "$0") +run_benches() { # args: + local mp="$1" vp="$2" + # t2i latency (median of N, per tier) + python "$here/bench_t2i_oai.py" --port "$mp" --model cosmos3_nano --tag mstar + python "$here/bench_t2i_oai.py" --port "$vp" --model nvidia/Cosmos3-Nano --tag vllm + # t2v latency (189 frames, 35 steps) + python "$here/video_bench.py" --engine ours --port "$mp" + python "$here/video_bench.py" --engine vllm --port "$vp" + # i2v latency (same, plus a conditioning frame) + python "$here/video_bench.py" --engine ours --port "$mp" --image cond.jpg + python "$here/video_bench.py" --engine vllm --port "$vp" --image cond.jpg + # t2i throughput under concurrency (bs 1/4/8) + python "$here/bench_throughput.py" --port "$mp" --model cosmos3_nano --tag mstar + python "$here/bench_throughput.py" --port "$vp" --model nvidia/Cosmos3-Nano --tag vllm +} + +case "${1:-}" in + serve-mstar) shift; serve_mstar "$@";; + serve-vllm) shift; serve_vllm "$@";; + bench) shift; run_benches "$@";; + *) echo "usage: $0 {serve-mstar | serve-vllm | bench }";; +esac diff --git a/mstar/benchmark/cosmos3/video_bench.py b/mstar/benchmark/cosmos3/video_bench.py new file mode 100644 index 00000000..d980857b --- /dev/null +++ b/mstar/benchmark/cosmos3/video_bench.py @@ -0,0 +1,106 @@ +"""t2v/i2v latency — engine-aware (the video APIs differ, unlike t2i). + +ours : POST /v1/videos/generations (JSON), response data[0].b64_json = mp4. +vllm : POST /v1/videos/sync (multipart form, via curl to match the recipe), raw mp4. + +Same config on both (tiers, frames, steps, gs, seed, fps); client-side wall, median. +Video gen is slow + fairly deterministic, so few rounds. Reports MP4 byte size as a +sanity check (a real clip is large; a flat/empty one is tiny). + + python video_bench.py --engine ours --port 8100 + python video_bench.py --engine vllm --port 8000 +""" +import argparse +import base64 +import json +import subprocess +import time +import urllib.request + +ap = argparse.ArgumentParser() +ap.add_argument("--engine", choices=["ours", "vllm"], required=True) +ap.add_argument("--port", type=int, required=True) +ap.add_argument("--model", default="nvidia/Cosmos3-Nano") +ap.add_argument("--tiers", default="320x192,832x480,1280x720") +ap.add_argument("--frames", type=int, default=189) +ap.add_argument("--steps", type=int, default=35) +ap.add_argument("--gs", type=float, default=6.0) +ap.add_argument("--fps", type=int, default=24) +ap.add_argument("--seed", type=int, default=0) +ap.add_argument("--rounds", type=int, default=2) +ap.add_argument("--warmup", type=int, default=1) +ap.add_argument("--flow-shift", type=float, default=10.0) +ap.add_argument("--image", default="") # i2v: path to the conditioning frame (else t2v) +args = ap.parse_args() + +PROMPT = "A robot arm is cleaning a plate in the kitchen, smooth natural motion." +NEG = "blurry, distorted, low quality, jittery, deformed" + +# i2v conditioning frame: ours takes a base64 data-url in the JSON body; vLLM takes +# the raw file via multipart input_reference (curl reads args.image directly). +IMG_DATA_URI = None +if args.image: + with open(args.image, "rb") as _f: + IMG_DATA_URI = "data:image/jpeg;base64," + base64.b64encode(_f.read()).decode() + + +def gen_ours(size): + payload = { + "prompt": PROMPT, "negative_prompt": NEG, "size": size, "seed": args.seed, + "guidance_scale": args.gs, "num_inference_steps": args.steps, + "num_frames": args.frames, "fps": args.fps, + } + if IMG_DATA_URI: + payload["image"] = IMG_DATA_URI + body = json.dumps(payload).encode() + req = urllib.request.Request(f"http://127.0.0.1:{args.port}/v1/videos/generations", + data=body, headers={"Content-Type": "application/json"}) + t0 = time.perf_counter() + with urllib.request.urlopen(req, timeout=3600) as r: + out = json.load(r) + dt = time.perf_counter() - t0 + return dt, len(base64.b64decode(out["data"][0]["b64_json"])) + + +def gen_vllm(size): + extra = json.dumps({"use_resolution_template": False, "use_duration_template": False}) + out_mp4 = "/tmp/vbench_vllm.mp4" + cmd = [ + "curl", "-sS", "-X", "POST", f"http://127.0.0.1:{args.port}/v1/videos/sync", + "-H", "Accept: video/mp4", + "-F", f"model={args.model}", "-F", f"prompt={PROMPT}", "-F", f"negative_prompt={NEG}", + "-F", f"size={size}", "-F", f"num_frames={args.frames}", "-F", f"fps={args.fps}", + "-F", f"num_inference_steps={args.steps}", "-F", f"guidance_scale={args.gs}", + "-F", "max_sequence_length=4096", "-F", f"flow_shift={args.flow_shift}", + "-F", f"extra_params={extra}", "-F", f"seed={args.seed}", + ] + if args.image: + cmd += ["-F", f"input_reference=@{args.image};type=image/jpeg"] + cmd += ["-o", out_mp4, "-w", "%{http_code}"] + t0 = time.perf_counter() + res = subprocess.run(cmd, capture_output=True, text=True, timeout=3600) + dt = time.perf_counter() - t0 + code = res.stdout.strip()[-3:] + import os + sz = os.path.getsize(out_mp4) if os.path.exists(out_mp4) else 0 + if code != "200": + raise RuntimeError(f"http {code}, {sz}B") + return dt, sz + + +gen = gen_ours if args.engine == "ours" else gen_vllm +print(f"=== {args.engine} port={args.port} frames={args.frames} steps={args.steps} gs={args.gs} seed={args.seed} ===", flush=True) +for size in args.tiers.split(","): + try: + for _ in range(args.warmup): + gen(size) + ts, sz = [], 0 + for _ in range(args.rounds): + dt, sz = gen(size) + ts.append(dt) + ts.sort() + med = ts[len(ts) // 2] + print(f" {size:9s} median {med:.2f}s min {ts[0]:.2f} max {ts[-1]:.2f} mp4={sz // 1024}KB (n={args.rounds})", flush=True) + except Exception as e: # noqa: BLE001 + print(f" {size:9s} ERROR {type(e).__name__}: {str(e)[:140]}", flush=True) +print("DONE", flush=True) diff --git a/mstar/engine/cache_manager.py b/mstar/engine/cache_manager.py index e46a429d..32e46d77 100644 --- a/mstar/engine/cache_manager.py +++ b/mstar/engine/cache_manager.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass import torch @@ -5,6 +6,18 @@ from mstar.engine.kv_store import KVCacheConfig, KVRequestState, PagedAllocationManager from mstar.utils.flashinfer_utils import FlashInferDecodeWrapper, FlashInferPrefillWrapper +# Run the non-causal generation attention as a dense FlashAttention-3 pass over a +# contiguous [frozen-prefix | fresh] sequence instead of the paged FlashInfer +# prefill. Diffusion recomputes every generation K/V each step (only the tiny +# text prefix is reused), so the paged path's per-step full-buffer K/V write is +# pure overhead here; a dense pass gathers the small prefix, concatenates it with +# the freshly projected K/V, and runs one varlen kernel — which is also the +# faster attention kernel at these shapes. Eager-only (the captured image path +# keeps the paged wrapper). Off unless COSMOS3_DENSE_FA3 is set; read per plan +# (once per denoise step) so it can be toggled for A/B parity checks. +def _dense_gen_attn_enabled() -> bool: + return bool(os.environ.get("COSMOS3_DENSE_FA3")) + @dataclass class _PlanState: @@ -35,6 +48,11 @@ class _PlanState: seq_lens: list[int] | None = None write_store: bool = True custom_pos_advance: list[int] | None = None + # Set when the dense generation-attention path is active for this label: the + # per-segment gather indices + varlen cu_seqlens needed to attend each + # generation segment over its contiguous frozen prefix. None on causal + # (prefill) plans, which keep the paged path. See _build_dense_gen_plan. + dense_gen: dict | None = None class WorkspaceBufferManager: @@ -422,6 +440,10 @@ def _plan_attention_impl( # reader — dropped along with their per-rid GPU construction above. ps.seq_lens = seq_lens ps.write_store = write_store + if _dense_gen_attn_enabled() and not is_causal and not self._cuda_graph_mode: + ps.dense_gen = self._build_dense_gen_plan([effective_label], seq_lens) + else: + ps.dense_gen = None def plan_rope( self, @@ -585,14 +607,43 @@ def plan_attention_batched_cfg( paged_kv_indices = torch.tensor(all_page_indices, dtype=torch.int32) paged_kv_last_page_len = torch.tensor(kv_last_page_lens, dtype=torch.int32) - wrapper = FlashInferPrefillWrapper( - workspace_buffer=self.buffer_manager.get(combined_label), - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - page_size=page_size, - enable_nvtx=self.enable_nvtx, - ) + ps = self._plan_states.get(combined_label) + if self._cuda_graph_mode and ps is not None and ps.wrapper is not None: + # CUDA-graph mode: reuse the persistent wrapper across denoise steps. + # plan() updates its static buffers via .copy_() so the captured + # kernel picks up each step's page table without reallocating. + wrapper = ps.wrapper + elif self._cuda_graph_mode: + # First call under capture: build the persistent wrapper sized for the + # fixed batch (labels x requests) and token budget. + wrapper = FlashInferPrefillWrapper( + workspace_buffer=self.buffer_manager.get(combined_label), + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + batch_size=len(labels) * len(self.request_ids), + max_total_tokens=sum(combined_seq_lens), + max_num_pages=cfg.max_num_pages, + device=self.device, + use_cuda_graph=True, + enable_nvtx=self.enable_nvtx, + ) + ps = _PlanState(wrapper=wrapper) + self._plan_states[combined_label] = ps + else: + # Eager mode: a fresh wrapper each call (the cache manager is rebuilt + # per forward, so there is nothing persistent to reuse). + wrapper = FlashInferPrefillWrapper( + workspace_buffer=self.buffer_manager.get(combined_label), + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + enable_nvtx=self.enable_nvtx, + ) + ps = _PlanState(wrapper=wrapper) + self._plan_states[combined_label] = ps wrapper.plan( qo_indptr=qo_indptr, @@ -602,13 +653,12 @@ def plan_attention_batched_cfg( causal=is_causal, dtype=dtype, ) - - ps = _PlanState( - wrapper=wrapper, - seq_lens=combined_seq_lens, - write_store=write_store, - ) - self._plan_states[combined_label] = ps + ps.seq_lens = combined_seq_lens + ps.write_store = write_store + if _dense_gen_attn_enabled() and not is_causal and not self._cuda_graph_mode: + ps.dense_gen = self._build_dense_gen_plan(labels, seq_lens) + else: + ps.dense_gen = None @torch.compiler.disable def plan_rope_batched_cfg( @@ -688,6 +738,10 @@ def run_attention( label = next(iter(self.active_labels.values())) ps = self._plan_states[label] + + if ps.dense_gen is not None: + return self._run_dense_gen(q, k, v, layer_idx, ps.dense_gen).to(orig_dtype) + assert self.kv_cache is not None and ps.wrapper is not None ps.wrapper.set_kv_cache(self.kv_cache[layer_idx], k, v) @@ -700,6 +754,78 @@ def run_attention( return ps.wrapper.run(q, self.kv_cache[layer_idx]).to(orig_dtype) + def _build_dense_gen_plan(self, labels: list[str], seq_lens: list[int]) -> dict: + """Pre-compute the per-segment gather + varlen layout for the dense + generation-attention path, in the same (label, request) batch order the + generation tokens are packed in. Each segment attends its fresh + generation tokens over its frozen text prefix; the prefix lives in the + pages written at prefill, so we record the page indices to gather it from + (the same across all layers) and the cumulative-sequence-length tensors a + single varlen kernel needs. Built once per denoise step, reused by every + layer's run_attention.""" + cfg = self.kv_cache_config + page_size = cfg.page_size + segs = [] # (prefix_page_indices, prefix_len, gen_len) + cu_q = [0] + cu_k = [0] + max_q = 0 + max_k = 0 + for label in labels: + for i, rid in enumerate(self.request_ids): + state = self._get_state(rid, label) + prefix_len = state.seq_len + gen_len = seq_lens[i] + n_pages = (prefix_len + page_size - 1) // page_size + idx = torch.tensor( + state.page_indices[:n_pages], dtype=torch.long, device=self.device + ) + segs.append((idx, prefix_len, gen_len)) + cu_q.append(cu_q[-1] + gen_len) + cu_k.append(cu_k[-1] + prefix_len + gen_len) + max_q = max(max_q, gen_len) + max_k = max(max_k, prefix_len + gen_len) + return { + "segs": segs, + "cu_q": torch.tensor(cu_q, dtype=torch.int32, device=self.device), + "cu_k": torch.tensor(cu_k, dtype=torch.int32, device=self.device), + "max_q": max_q, + "max_k": max_k, + } + + @torch.compiler.disable + def _run_dense_gen( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_idx: int, dg: dict + ) -> torch.Tensor: + """Dense generation attention: per segment, gather the frozen text-prefix + K/V from the paged cache, concatenate it with this segment's fresh K/V, + and attend non-causally with one FlashAttention-3 varlen kernel. Bypasses + the paged write entirely (the generation K/V is recomputed every step, so + persisting it is wasted work).""" + from fa3_fwd_interface import flash_attn_varlen_func + + cfg = self.kv_cache_config + num_kv_heads, head_dim = cfg.num_kv_heads, cfg.head_dim + kv_layer = self.kv_cache[layer_idx] # [max_pages, 2, page_size, num_kv_heads, head_dim] + + k_parts, v_parts = [], [] + offset = 0 + for idx, prefix_len, gen_len in dg["segs"]: + sub = kv_layer[idx] # [n_pages, 2, page_size, num_kv_heads, head_dim] + k_parts.append(sub[:, 0].reshape(-1, num_kv_heads, head_dim)[:prefix_len]) + k_parts.append(k[offset:offset + gen_len]) + v_parts.append(sub[:, 1].reshape(-1, num_kv_heads, head_dim)[:prefix_len]) + v_parts.append(v[offset:offset + gen_len]) + offset += gen_len + key = torch.cat(k_parts, dim=0) + val = torch.cat(v_parts, dim=0) + if q.dtype != key.dtype: + q = q.to(key.dtype) + + out = flash_attn_varlen_func( + q, key, val, dg["cu_q"], dg["cu_k"], dg["max_q"], dg["max_k"], causal=False, + ) + return out[0] if isinstance(out, tuple) else out + @torch.compiler.disable def apply_rope( self, diff --git a/mstar/engine/cuda_graph_config.py b/mstar/engine/cuda_graph_config.py index 37f19b0d..9c034479 100644 --- a/mstar/engine/cuda_graph_config.py +++ b/mstar/engine/cuda_graph_config.py @@ -25,7 +25,29 @@ def __init__( # StatelessCudaGraphRunner picks its own default). Useful for codec-style # submodules where memory cost per size is high, or for AR walks where a # small subset is enough. - capture_batch_sizes: list[int] | None = None + capture_batch_sizes: list[int] | None = None, + # Method on the submodule to capture. Defaults to ``forward_batched`` (the + # same method the eager batched path uses). Diffusion-style walks that must + # keep a non-capturable tail (e.g. a multistep scheduler step) out of the + # graph capture a velocity-only method here and run the tail in + # ``postprocess_captured`` after replay. + capture_forward_method: str = "forward_batched", + # Whether the runner advances KV seq_lens after replay. True for + # autoregressive walks (each step appends a token). False for frozen-prefix + # denoise loops that re-read a fixed prefix and overwrite the same tail + # pages every step (advancing would grow the prefix and corrupt attention). + advance_seq_lens: bool = True, + # Whether this config's captured batch sizes also cap the engine's max + # (eager) batch size for the walk. Default True keeps the conservative + # behavior: never batch beyond a captured graph size. Set False when the + # captured sizes are only an acceleration subset and the submodule's eager + # batched path can handle larger batches — the engine then honors the + # submodule's max_batch_size and uses a graph only when the exact batch + # size was captured (gated by runner.can_run), falling back to eager + # batched execution otherwise. Needed so a denoise loop that captures a + # graph only at batch size 1 (single-request latency) can still batch + # concurrent requests instead of serializing them. + caps_eager_batch_size: bool = True, ): self.capture_graph_walk = capture_graph_walk self.replay_graph_walks = replay_graph_walks or [capture_graph_walk] @@ -33,6 +55,9 @@ def __init__( self.labels = labels or ["main"] self.compile = compile self.capture_batch_sizes = capture_batch_sizes + self.capture_forward_method = capture_forward_method + self.advance_seq_lens = advance_seq_lens + self.caps_eager_batch_size = caps_eager_batch_size @abstractmethod def get_config_type(self) -> CudaGraphConfigType: @@ -52,7 +77,10 @@ def __init__( requires_cfg: bool = False, labels: list[str] = None, compile: bool = True, - capture_batch_sizes: list[int] | None = None + capture_batch_sizes: list[int] | None = None, + capture_forward_method: str = "forward_batched", + advance_seq_lens: bool = True, + caps_eager_batch_size: bool = True, ): super().__init__( capture_graph_walk=capture_graph_walk, @@ -60,7 +88,10 @@ def __init__( requires_cfg=requires_cfg, labels=labels, compile=compile, - capture_batch_sizes=capture_batch_sizes + capture_batch_sizes=capture_batch_sizes, + capture_forward_method=capture_forward_method, + advance_seq_lens=advance_seq_lens, + caps_eager_batch_size=caps_eager_batch_size, ) self.single_request_inputs = single_request_inputs diff --git a/mstar/engine/cuda_graph_runner.py b/mstar/engine/cuda_graph_runner.py index c1bd5d93..4cd01d88 100644 --- a/mstar/engine/cuda_graph_runner.py +++ b/mstar/engine/cuda_graph_runner.py @@ -546,7 +546,11 @@ def _capture_slots( spec = prepare_slot(slot_idx) dummy_rids_to_free.append(spec.dummy_rids) - forward = submodule.forward_batched + # Usually ``forward_batched`` (the same method the eager batched + # path runs). Diffusion walks override this to a velocity-only + # method so the non-capturable scheduler tail stays out of the + # graph (run later in ``postprocess_captured``). + forward = getattr(submodule, config.capture_forward_method) if config.compile: forward = torch.compile( forward, @@ -778,23 +782,32 @@ def _get_key_for( ) -> CudaGraphKey | None: if not self.graphs: return None - config = self._config_for(graph_walk, requires_cfg) - if config is None: - return None - padded_bs = self._get_padded_batch_size(batch_size, config) - if padded_bs is None: - return None - padded_num_tokens = self._get_padded_num_tokens(num_tokens, padded_bs, config) - if padded_num_tokens is None: - return None - - key = CudaGraphKey( - graph_walk=graph_walk, - requires_cfg=requires_cfg, - bs=padded_bs, - num_tokens=padded_num_tokens, - ) - return key if key in self.graphs else None + # A walk may have several captures (e.g. one per image resolution, each a + # fixed shape with its own token count). Consider every matching config and + # pick the tightest captured (bs, num_tokens) bucket that fits this batch, + # so a request lands on the graph for its own shape rather than the first + # config declared. With a single config this is the same as before. + best: CudaGraphKey | None = None + for config in self.capture_configs: + if graph_walk not in config.replay_graph_walks or config.requires_cfg != requires_cfg: + continue + padded_bs = self._get_padded_batch_size(batch_size, config) + if padded_bs is None: + continue + padded_num_tokens = self._get_padded_num_tokens(num_tokens, padded_bs, config) + if padded_num_tokens is None: + continue + key = CudaGraphKey( + graph_walk=graph_walk, + requires_cfg=requires_cfg, + bs=padded_bs, + num_tokens=padded_num_tokens, + ) + if key in self.graphs and ( + best is None or (key.num_tokens, key.bs) < (best.num_tokens, best.bs) + ): + best = key + return best def _config_for(self, graph_walk: str, requires_cfg: bool) -> CudaGraphConfig | None: for cfg in self.capture_configs: @@ -1372,9 +1385,13 @@ def _run_basic_batched( range_push("gpu_thread.postprocess", synchronize=False) if self.enable_nvtx: range_push("cg.advance_seq_lens", synchronize=False) - for label in config_labels: - static_cm.set_active_label(label) - static_cm.advance_seq_lens() + # Frozen-prefix denoise walks re-read a fixed prefix and overwrite the + # same tail pages every step, so they opt out of the advance (it would + # grow the prefix across steps and corrupt attention). + if graph_data.config.advance_seq_lens: + for label in config_labels: + static_cm.set_active_label(label) + static_cm.advance_seq_lens() if self.enable_nvtx: range_pop(synchronize=False) @@ -1403,6 +1420,18 @@ def _run_basic_batched( if self.enable_nvtx: range_pop(synchronize=False) + # Eager tail for walks that captured only a velocity/raw forward and + # keep a non-capturable step (e.g. a multistep scheduler) out of the + # graph. Runs with REAL request ids, the original ``inputs``, and the + # cloned captured outputs, so it can finish each request's step. + if hasattr(submodule, "postprocess_captured"): + outputs = submodule.postprocess_captured( + request_ids=request_ids, + inputs=inputs, + per_request_info=per_request_info, + outputs=outputs, + ) + success = True return outputs finally: @@ -1586,9 +1615,13 @@ def _run_flashinfer_packed( range_push("gpu_thread.postprocess", synchronize=False) if self.enable_nvtx: range_push("cg.advance_seq_lens", synchronize=False) - for label in config_labels: - static_cm.set_active_label(label) - static_cm.advance_seq_lens() + # Frozen-prefix denoise walks re-read a fixed prefix and overwrite the + # same tail pages every step, so they opt out of the advance (it would + # grow the prefix across steps and corrupt attention). + if graph_data.config.advance_seq_lens: + for label in config_labels: + static_cm.set_active_label(label) + static_cm.advance_seq_lens() if self.enable_nvtx: range_pop(synchronize=False) @@ -1610,6 +1643,18 @@ def _run_flashinfer_packed( if self.enable_nvtx: range_pop(synchronize=False) + # Eager tail for walks that captured only a velocity/raw forward and + # keep a non-capturable step (e.g. a multistep scheduler) out of the + # graph. Runs with REAL request ids, the original ``inputs``, and the + # cloned captured outputs, so it can finish each request's step. + if hasattr(submodule, "postprocess_captured"): + outputs = submodule.postprocess_captured( + request_ids=request_ids, + inputs=inputs, + per_request_info=per_request_info, + outputs=outputs, + ) + success = True return outputs finally: diff --git a/mstar/engine/kv_cache_engine.py b/mstar/engine/kv_cache_engine.py index 8fce9664..80591110 100644 --- a/mstar/engine/kv_cache_engine.py +++ b/mstar/engine/kv_cache_engine.py @@ -215,6 +215,10 @@ def _compile_submodules(self) -> None: for node_name, submodule_mgmt in self.submodule_management.items(): submodule = submodule_mgmt.submodule + if getattr(submodule, "disable_torch_compile", False): + logger.info("KVCacheEngine: torch.compile disabled for %s (submodule opt-out)", node_name) + continue + try: submodule.forward = torch.compile( submodule.forward, @@ -355,7 +359,13 @@ def get_max_batch_size(self, node_name, graph_walk): configs = [ cfg for cfg in runner.capture_configs \ if graph_walk in cfg.replay_graph_walks + and getattr(cfg, "caps_eager_batch_size", True) ] + # Configs that opt out of capping (caps_eager_batch_size=False) capture a + # graph only for an acceleration subset of batch sizes; the eager batched + # path handles larger batches and the runner gates graph replay by exact + # batch size. With no capping config left for this walk, honor the + # submodule's max_batch_size instead of the captured-size ceiling. if not configs: return submod_max_bs max_cuda_graph_bs = max([ diff --git a/mstar/engine/stateless_engine.py b/mstar/engine/stateless_engine.py index 34a5977d..a4679c2f 100644 --- a/mstar/engine/stateless_engine.py +++ b/mstar/engine/stateless_engine.py @@ -515,6 +515,12 @@ def warmup(self) -> None: self._install_piecewise_runner(node_name, submodule) def _apply_torch_compile(self, node_name: str, submodule: NodeSubmodule) -> None: + if getattr(submodule, "disable_torch_compile", False): + logger.info( + "StatelessEngine[%s]: torch.compile disabled for %s (submodule opt-out)", + self.config.name, node_name, + ) + return try: if hasattr(submodule, "forward"): submodule.forward = torch.compile( diff --git a/mstar/model/cosmos3/__init__.py b/mstar/model/cosmos3/__init__.py new file mode 100644 index 00000000..102fdb00 --- /dev/null +++ b/mstar/model/cosmos3/__init__.py @@ -0,0 +1,9 @@ +"""Cosmos3 omni generator model package.""" + +from mstar.model.cosmos3.config import ( + Cosmos3Config, + Cosmos3SchedulerConfig, + Cosmos3VAEConfig, +) + +__all__ = ["Cosmos3Config", "Cosmos3SchedulerConfig", "Cosmos3VAEConfig"] diff --git a/mstar/model/cosmos3/components/__init__.py b/mstar/model/cosmos3/components/__init__.py new file mode 100644 index 00000000..f7622f31 --- /dev/null +++ b/mstar/model/cosmos3/components/__init__.py @@ -0,0 +1 @@ +"""Cosmos3 backbone components.""" diff --git a/mstar/model/cosmos3/components/transformer.py b/mstar/model/cosmos3/components/transformer.py new file mode 100644 index 00000000..e1eabff4 --- /dev/null +++ b/mstar/model/cosmos3/components/transformer.py @@ -0,0 +1,1058 @@ +"""Cosmos3 dual-pathway Mixture-of-Transformers DiT. + +Each decoder layer carries two parameter sets that run side by side: + + * UND (understanding / text-conditioning) pathway — ``to_{q,k,v,out}``, + ``norm_{q,k}``, ``mlp``, ``input_layernorm``, ``post_attention_layernorm``. + Causal self-attention over the text prefix; never attends to GEN tokens. + * GEN (generation / denoiser) pathway — ``add_{q,k,v}_proj``, ``to_add_out``, + ``norm_added_{q,k}``, ``mlp_moe_gen``, ``input_layernorm_moe_gen``, + ``post_attention_layernorm_moe_gen``. Full (non-causal) attention where + GEN queries attend to ``cat([k_und, k_gen])`` / ``cat([v_und, v_gen])``. + +The module mirrors the published diffusers checkpoint layout one-to-one, so the +flat ``layers.N.*`` safetensors keys load with no key remapping beyond dropping +the unused text ``lm_head``. + +UND and GEN run together in one fused pass every denoising step. The attention +and MLP projections are tensor-parallel: with a trivial (world-size-1) comm +group they behave exactly like plain ``nn.Linear``; with a real group the +q/k/v and gate/up projections are column-sharded along the head / intermediate +dim and the out / down projections row-shard their input and all-reduce. +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn.functional as F +from diffusers.models.embeddings import Timesteps +from torch import nn + +from mstar.distributed.communication import TPCommGroup +from mstar.model.components.distributed.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) + + +class RMSNorm(nn.Module): + """Weight-only RMS normalization (no bias). + + Replicates the diffusers ``RMSNorm`` dtype ordering exactly: variance in + fp32, normalize, then round the normalized activations to the (bf16) weight + dtype *before* the weight multiply. Matching this rounding point matters for + tight bf16 parity across 36 layers' worth of norms. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + if self.weight.dtype in (torch.float16, torch.bfloat16): + hidden_states = hidden_states.to(self.weight.dtype) + return hidden_states * self.weight + return (hidden_states * self.weight).to(input_dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + +class Cosmos3RotaryEmbedding(nn.Module): + """3D interleaved mRoPE (``Cosmos3VLTextRotaryEmbedding``). + + ``inv_freq`` is recomputed on the fly from ``rope_theta``/``head_dim`` rather + than registered as a buffer: the model is materialized via ``meta`` + + ``to_empty``, which leaves registered buffers uninitialized. Recompute is + cheap (``head_dim/2`` values, once per forward). + """ + + def __init__(self, head_dim: int, rope_theta: float, rope_axes_dim: tuple[int, int, int]): + super().__init__() + self.head_dim = head_dim + self.rope_theta = rope_theta + self.rope_axes_dim = tuple(rope_axes_dim) + + def apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: + """Reorganize chunked ``[TTT…HHH…WWW]`` frequencies into interleaved + ``[THTHWHTHW…TT]`` (preserves frequency continuity across the 3 grids).""" + freqs_t = freqs[0] + for dim, offset in enumerate((1, 2), start=1): # H, W + length = self.rope_axes_dim[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + def forward( + self, position_ids: torch.Tensor, device: torch.device, dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq = 1.0 / ( + self.rope_theta ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) / self.head_dim) + ) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) # [3,B,N] + inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1).to(device) + position_ids_expanded = position_ids[:, :, None, :].float() # [3,B,1,N] + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3) # [3,B,N,head_dim//2] + freqs = self.apply_interleaved_mrope(freqs) # [B,N,head_dim//2] + emb = torch.cat((freqs, freqs), dim=-1) # [B,N,head_dim] + return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype) + + +class TimestepEmbedder(nn.Module): + """Two-layer MLP over sinusoidal timestep features (``linear_1``/``linear_2``). + + Matches diffusers ``TimestepEmbedding`` (act = SiLU, no cond/post-act). Kept + in fp32 at build time, like diffusers' ``_keep_in_fp32_modules``. + """ + + def __init__(self, in_channels: int, time_embed_dim: int): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + return self.linear_2(self.act(self.linear_1(sample))) + + +class Cosmos3MLP(nn.Module): + """SwiGLU feed-forward (``gate_proj``/``up_proj``/``down_proj``, no bias). + + Tensor-parallel: ``gate_proj``/``up_proj`` are column-sharded along the + intermediate dim and ``down_proj`` row-shards its input and all-reduces. + A trivial comm group (world size 1) makes these plain linears. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + comm_group: TPCommGroup | None = None, + ): + super().__init__() + if comm_group is None: + comm_group = TPCommGroup.trivial() + self.gate_proj = ColumnParallelLinear(comm_group, hidden_size, intermediate_size, bias=False) + self.up_proj = ColumnParallelLinear(comm_group, hidden_size, intermediate_size, bias=False) + self.down_proj = RowParallelLinear(comm_group, intermediate_size, hidden_size, bias=False) + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Cosmos3PackedMoTAttention(nn.Module): + """Dual-pathway packed attention: separate unfused projections + QK-norm for + the understanding (causal) and generation (full) token streams. + + Mirrors diffusers ``Cosmos3AttnProcessor``: QK-norm is applied per-head + *before* RoPE; the UND stream self-attends causally, the GEN stream attends + non-causally to ``cat([und, gen])``. GQA (32 Q / 8 KV heads) is handled by + ``F.scaled_dot_product_attention(enable_gqa=True)``. + """ + + def __init__( + self, + hidden_size: int, + head_dim: int, + num_attention_heads: int, + num_key_value_heads: int, + attention_bias: bool, + rms_norm_eps: float, + comm_group: TPCommGroup | None = None, + ): + super().__init__() + if comm_group is None: + comm_group = TPCommGroup.trivial() + tp_size = comm_group.world_size + if num_attention_heads % tp_size or num_key_value_heads % tp_size: + raise ValueError( + f"TP size {tp_size} must divide both num_attention_heads " + f"({num_attention_heads}) and num_key_value_heads " + f"({num_key_value_heads})" + ) + self.head_dim = head_dim + # Per-rank head counts: TP shards the head dimension, so the q/k/v + # reshapes below operate on this rank's slice of heads. + self.num_attention_heads = num_attention_heads // tp_size + self.num_key_value_heads = num_key_value_heads // tp_size + + q_dim = num_attention_heads * head_dim + kv_dim = num_key_value_heads * head_dim + + # Understanding pathway. + self.to_q = ColumnParallelLinear(comm_group, hidden_size, q_dim, bias=attention_bias) + self.to_k = ColumnParallelLinear(comm_group, hidden_size, kv_dim, bias=attention_bias) + self.to_v = ColumnParallelLinear(comm_group, hidden_size, kv_dim, bias=attention_bias) + self.to_out = RowParallelLinear(comm_group, q_dim, hidden_size, bias=attention_bias) + self.norm_q = RMSNorm(head_dim, eps=rms_norm_eps) + self.norm_k = RMSNorm(head_dim, eps=rms_norm_eps) + + # Generation pathway. + self.add_q_proj = ColumnParallelLinear(comm_group, hidden_size, q_dim, bias=attention_bias) + self.add_k_proj = ColumnParallelLinear(comm_group, hidden_size, kv_dim, bias=attention_bias) + self.add_v_proj = ColumnParallelLinear(comm_group, hidden_size, kv_dim, bias=attention_bias) + self.to_add_out = RowParallelLinear(comm_group, q_dim, hidden_size, bias=attention_bias) + self.norm_added_q = RMSNorm(head_dim, eps=rms_norm_eps) + self.norm_added_k = RMSNorm(head_dim, eps=rms_norm_eps) + + @staticmethod + def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + # x: [N, H, D]; cos/sin: [N, D] -> [N, 1, D] for broadcast over heads. + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + return x * cos + _rotate_half(x) * sin + + def _attend(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool) -> torch.Tensor: + # q: [Nq, Hq, D]; k/v: [Nk, Hkv, D] -> [Nq, Hq*D]. SDPA wants [B, H, S, D]. + q = q.unsqueeze(0).transpose(1, 2) + k = k.unsqueeze(0).transpose(1, 2) + v = v.unsqueeze(0).transpose(1, 2) + out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, enable_gqa=True) + return out.transpose(1, 2).squeeze(0).flatten(-2, -1) + + def forward( + self, + und_seq: torch.Tensor, + gen_seq: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + H, Hkv, D = self.num_attention_heads, self.num_key_value_heads, self.head_dim + + q_und = self.to_q(und_seq).view(-1, H, D) + k_und = self.to_k(und_seq).view(-1, Hkv, D) + v_und = self.to_v(und_seq).view(-1, Hkv, D) + q_gen = self.add_q_proj(gen_seq).view(-1, H, D) + k_gen = self.add_k_proj(gen_seq).view(-1, Hkv, D) + v_gen = self.add_v_proj(gen_seq).view(-1, Hkv, D) + + q_und = self.norm_q(q_und) + k_und = self.norm_k(k_und) + q_gen = self.norm_added_q(q_gen) + k_gen = self.norm_added_k(k_gen) + + cos_und, sin_und, cos_gen, sin_gen = rotary_emb + q_und = self._apply_rope(q_und, cos_und, sin_und) + k_und = self._apply_rope(k_und, cos_und, sin_und) + q_gen = self._apply_rope(q_gen, cos_gen, sin_gen) + k_gen = self._apply_rope(k_gen, cos_gen, sin_gen) + + # UND: causal self-attention over text. + causal_out = self._attend(q_und, k_und, v_und, is_causal=True) + # GEN: full attention over [und | gen]. + all_k = torch.cat([k_und, k_gen], dim=0) + all_v = torch.cat([v_und, v_gen], dim=0) + full_out = self._attend(q_gen, all_k, all_v, is_causal=False) + + return self.to_out(causal_out), self.to_add_out(full_out) + + # ------------------------------------------------------------------ + # Cached-attention variants: the two pathways run in separate passes and + # share their K/V through a paged cache handle instead of in-pass concat. + # The understanding pass writes its K/V (causal); the generation pass reads + # that frozen K/V plus its own (non-causal) — causality is fixed by the + # handle's attention plan, not here. + # ------------------------------------------------------------------ + + def forward_und(self, und_seq: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache_handle) -> torch.Tensor: + H, Hkv, D = self.num_attention_heads, self.num_key_value_heads, self.head_dim + q = self.norm_q(self.to_q(und_seq).view(-1, H, D)) + k = self.norm_k(self.to_k(und_seq).view(-1, Hkv, D)) + v = self.to_v(und_seq).view(-1, Hkv, D) + q = self._apply_rope(q, cos, sin) + k = self._apply_rope(k, cos, sin) + out = cache_handle.run_attention(q=q, k=k, v=v).reshape(-1, H * D) + return self.to_out(out) + + def forward_gen(self, gen_seq: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache_handle) -> torch.Tensor: + H, Hkv, D = self.num_attention_heads, self.num_key_value_heads, self.head_dim + q = self.norm_added_q(self.add_q_proj(gen_seq).view(-1, H, D)) + k = self.norm_added_k(self.add_k_proj(gen_seq).view(-1, Hkv, D)) + v = self.add_v_proj(gen_seq).view(-1, Hkv, D) + q = self._apply_rope(q, cos, sin) + k = self._apply_rope(k, cos, sin) + out = cache_handle.run_attention(q=q, k=k, v=v).reshape(-1, H * D) + return self.to_add_out(out) + + +class Cosmos3MoTDecoderLayer(nn.Module): + """One dual-pathway decoder layer (UND + GEN parameter sets).""" + + def __init__( + self, + hidden_size: int, + head_dim: int, + num_attention_heads: int, + num_key_value_heads: int, + intermediate_size: int, + attention_bias: bool, + rms_norm_eps: float, + comm_group: TPCommGroup | None = None, + ): + super().__init__() + self.self_attn = Cosmos3PackedMoTAttention( + hidden_size=hidden_size, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_bias=attention_bias, + rms_norm_eps=rms_norm_eps, + comm_group=comm_group, + ) + self.mlp = Cosmos3MLP(hidden_size, intermediate_size, comm_group=comm_group) + self.mlp_moe_gen = Cosmos3MLP(hidden_size, intermediate_size, comm_group=comm_group) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm_moe_gen = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm_moe_gen = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + und_seq: torch.Tensor, + gen_seq: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + und_norm = self.input_layernorm(und_seq) + gen_norm = self.input_layernorm_moe_gen(gen_seq) + + und_attn_out, gen_attn_out = self.self_attn(und_norm, gen_norm, rotary_emb) + residual_und = und_seq + und_attn_out + residual_gen = gen_seq + gen_attn_out + + mlp_out_und = self.mlp(self.post_attention_layernorm(residual_und)) + mlp_out_gen = self.mlp_moe_gen(self.post_attention_layernorm_moe_gen(residual_gen)) + + return residual_und + mlp_out_und, residual_gen + mlp_out_gen + + def forward_und(self, und_seq: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache_handle) -> torch.Tensor: + und_norm = self.input_layernorm(und_seq) + attn_out = self.self_attn.forward_und(und_norm, cos, sin, cache_handle) + residual = und_seq + attn_out + return residual + self.mlp(self.post_attention_layernorm(residual)) + + def forward_gen(self, gen_seq: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache_handle) -> torch.Tensor: + gen_norm = self.input_layernorm_moe_gen(gen_seq) + attn_out = self.self_attn.forward_gen(gen_norm, cos, sin, cache_handle) + residual = gen_seq + attn_out + return residual + self.mlp_moe_gen(self.post_attention_layernorm_moe_gen(residual)) + + +class DomainAwareLinear(nn.Module): + """Per-embodiment affine map: one *full* (weight, bias) pair per action + embodiment domain, both looked up from embedding tables keyed by a domain id. + + ``fc`` holds each domain's flattened weight (shape ``[num_domains, + out*in]``, viewed as ``[in, out]`` so the map is ``x @ W`` — note the + weight is stored transposed relative to ``nn.Linear``); ``bias`` holds each + domain's ``[out]`` bias. Matches the checkpoint's + ``action_proj_{in,out}.{fc,bias}.weight`` shapes one-to-one.""" + + def __init__(self, in_features: int, out_features: int, num_domains: int): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.num_domains = num_domains + self.fc = nn.Embedding(num_domains, out_features * in_features) + self.bias = nn.Embedding(num_domains, out_features) + + def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor: + domain_id = domain_id.to(device=x.device, dtype=torch.long).reshape(-1) + weight = self.fc(domain_id).view(domain_id.shape[0], self.in_features, self.out_features) + bias = self.bias(domain_id).view(domain_id.shape[0], self.out_features) + if x.ndim == 2: # [B, in] -> [B, out] + return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias + return torch.bmm(x, weight) + bias.unsqueeze(1) # [B, T, in] -> [B, T, out] + + +class Cosmos3OmniTransformer(nn.Module): + """The full Cosmos3 generator backbone. + + ``state_dict()`` keys reproduce the published ``transformer/`` checkpoint + exactly, except the text ``lm_head`` is intentionally absent: generation + predicts flow velocity through ``proj_out`` and never decodes text logits. + """ + + def __init__(self, config, comm_group: TPCommGroup | None = None): + super().__init__() + self.config = config + h = config.hidden_size + + self.embed_tokens = nn.Embedding(config.vocab_size, h) + self.layers = nn.ModuleList( + Cosmos3MoTDecoderLayer( + hidden_size=h, + head_dim=config.head_dim, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + intermediate_size=config.intermediate_size, + attention_bias=config.attention_bias, + rms_norm_eps=config.rms_norm_eps, + comm_group=comm_group, + ) + for _ in range(config.num_hidden_layers) + ) + self.norm = RMSNorm(h, eps=config.rms_norm_eps) + self.norm_moe_gen = RMSNorm(h, eps=config.rms_norm_eps) + self.rotary_emb = Cosmos3RotaryEmbedding( + head_dim=config.head_dim, + rope_theta=config.rope_theta, + rope_axes_dim=config.rope_axes_dim, + ) + + # Vision latent in/out projections + timestep embedder. + self.proj_in = nn.Linear(config.patch_latent_dim, h, bias=True) + self.proj_out = nn.Linear(h, config.patch_latent_dim, bias=True) + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedder(in_channels=256, time_embed_dim=h) + + # Sound (AVAE-latent) heads. + if config.sound_gen: + if config.sound_dim is None: + raise ValueError("sound_dim must be set when sound_gen is True") + self.audio_proj_in = nn.Linear(config.sound_dim, h, bias=True) + self.audio_proj_out = nn.Linear(h, config.sound_dim, bias=True) + self.audio_modality_embed = nn.Parameter(torch.zeros(h)) + + # Action heads (per-embodiment domain-aware projections). + self.action_dim = config.max_action_dim + if config.action_gen: + self.action_proj_in = DomainAwareLinear( + config.max_action_dim, h, config.num_embodiment_domains + ) + self.action_proj_out = DomainAwareLinear( + h, config.max_action_dim, config.num_embodiment_domains + ) + self.action_modality_embed = nn.Parameter(torch.zeros(h)) + + # ------------------------------------------------------------------ + # Pure-tensor packing/unpacking helpers (ported from diffusers). + # ------------------------------------------------------------------ + + def _apply_timestep_embeds_to_noisy_tokens( + self, + packed_tokens: torch.Tensor, + packed_timestep_embeds: torch.Tensor, + noisy_frame_indexes: list[torch.Tensor], + token_shapes: list[tuple[int, ...]], + ) -> torch.Tensor: + start_noisy_index = 0 + flattened_noisy_frame_indexes: list[torch.Tensor] = [] + for noisy_indexes_i, token_shape_i in zip(noisy_frame_indexes, token_shapes, strict=True): + spatial_numel_i = math.prod(token_shape_i[1:]) + spatial_indexes_i = torch.arange(spatial_numel_i, device=packed_tokens.device) + frame_offsets = (noisy_indexes_i * spatial_numel_i).unsqueeze(-1) + spatial_indexes_i + start_noisy_index + flattened_noisy_frame_indexes.append(frame_offsets.flatten()) + start_noisy_index += token_shape_i[0] * spatial_numel_i + flattened = torch.cat(flattened_noisy_frame_indexes, dim=0).unsqueeze(-1).expand(-1, packed_tokens.shape[1]) + return packed_tokens.scatter_add(dim=0, index=flattened, src=packed_timestep_embeds) + + def _patchify_and_pack_latents( + self, tokens_vision: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int, int]]]: + p = self.config.latent_patch_size + latent_channel = self.config.latent_channel + packed_latent: list[torch.Tensor] = [] + original_latent_shapes: list[tuple[int, int, int]] = [] + for latent in tokens_vision: + latent = latent.squeeze(0) # [C, T, H, W] + _, t_actual, h_actual, w_actual = latent.shape + original_latent_shapes.append((t_actual, h_actual, w_actual)) + h_padded = ((h_actual + p - 1) // p) * p + w_padded = ((w_actual + p - 1) // p) * p + if h_padded != h_actual or w_padded != w_actual: + padded = torch.zeros( + (latent_channel, t_actual, h_padded, w_padded), device=latent.device, dtype=latent.dtype + ) + padded[:, :, :h_actual, :w_actual] = latent + latent = padded + h_patches = h_padded // p + w_patches = w_padded // p + latent = latent.reshape(latent_channel, t_actual, h_patches, p, w_patches, p) + latent = torch.einsum("cthpwq->thwpqc", latent).reshape(-1, p * p * latent_channel) + packed_latent.append(latent) + return torch.cat(packed_latent, dim=0), original_latent_shapes + + def _unpatchify_and_unpack_latents( + self, + packed_mse_preds: torch.Tensor, + token_shapes_vision: list[tuple[int, int, int]], + noisy_frame_indexes_vision: list[torch.Tensor], + original_latent_shapes: list[tuple[int, int, int]], + ) -> list[torch.Tensor]: + p = self.config.latent_patch_size + latent_channel = self.config.latent_channel + unpatchified_latents: list[torch.Tensor] = [] + start_idx = 0 + for token_shape, noisy_frame_indexes, original_shape in zip( + token_shapes_vision, noisy_frame_indexes_vision, original_latent_shapes, strict=True + ): + t_c = token_shape[0] + _, h_orig, w_orig = original_shape + h_padded = ((h_orig + p - 1) // p) * p + w_padded = ((w_orig + p - 1) // p) * p + h_patches = h_padded // p + w_patches = w_padded // p + t_n = len(noisy_frame_indexes) + output_tensor = torch.zeros( + (latent_channel, t_c, h_orig, w_orig), device=packed_mse_preds.device, dtype=packed_mse_preds.dtype + ) + num_patches = t_n * h_patches * w_patches + if num_patches > 0: + end_idx = start_idx + num_patches + latent_patches = packed_mse_preds[start_idx:end_idx] + latent_patches = latent_patches.reshape(t_n, h_patches, w_patches, p, p, latent_channel) + latent = torch.einsum("thwpqc->cthpwq", latent_patches) + latent = latent.reshape(latent_channel, t_n, h_patches * p, w_patches * p) + latent = latent[:, :, :h_orig, :w_orig] + output_tensor[:, noisy_frame_indexes] = latent + start_idx = end_idx + unpatchified_latents.append(output_tensor.unsqueeze(0)) + return unpatchified_latents + + def _pack_sound_latents( + self, tokens_sound: list[torch.Tensor], token_shapes_sound: list[tuple[int, int, int]] + ) -> torch.Tensor: + return torch.cat( + [sound[:, : shape[0]].permute(1, 0) for sound, shape in zip(tokens_sound, token_shapes_sound, strict=True)], + dim=0, + ) + + def _unpack_sound_latents( + self, + packed_preds: torch.Tensor, + token_shapes_sound: list[tuple[int, int, int]], + noisy_frame_indexes_sound: list[torch.Tensor], + ) -> list[torch.Tensor]: + sound_dim = self.config.sound_dim + unpacked: list[torch.Tensor] = [] + start_idx = 0 + for shape, noisy_idxs in zip(token_shapes_sound, noisy_frame_indexes_sound, strict=True): + T = shape[0] + output = torch.zeros((sound_dim, T), device=packed_preds.device, dtype=packed_preds.dtype) + t_n = len(noisy_idxs) + if t_n > 0: + output[:, noisy_idxs] = packed_preds[start_idx : start_idx + t_n].T + start_idx += t_n + unpacked.append(output) + return unpacked + + def _embed_action( + self, + action_latents: torch.Tensor, + action_domain_id: torch.Tensor, + action_timesteps: torch.Tensor, + action_token_shapes: list[tuple[int, int, int]], + action_noisy_frame_indexes: list[torch.Tensor], + target_dtype: torch.dtype, + ) -> torch.Tensor: + """Project action tokens ([1, T, D]) into the hidden space: domain-aware + in-projection + the action modality embedding, then scatter-add the + timestep embedding to the noisy (predicted) action tokens only. Returns + [T, hidden].""" + packed = self.action_proj_in(action_latents, action_domain_id)[0] # [T, hidden] + packed = packed + self.action_modality_embed.to(packed.dtype) + ts = action_timesteps * self.config.timestep_scale + ts_embeds = self.time_embedder(self.time_proj(ts)).to(target_dtype) + return self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed, + packed_timestep_embeds=ts_embeds, + noisy_frame_indexes=action_noisy_frame_indexes, + token_shapes=action_token_shapes, + ) + + def _decode_action( + self, + gen_hidden: torch.Tensor, + action_domain_id: torch.Tensor, + action_token_shapes: list[tuple[int, int, int]], + action_noisy_frame_indexes: list[torch.Tensor], + ) -> torch.Tensor: + """Domain-aware out-projection of the noisy action hidden states back to + action space, scattered into a full [1, T, D] tensor (clean tokens left + zero, matching the velocity mask the scheduler applies).""" + preds = self.action_proj_out(gen_hidden.unsqueeze(0), action_domain_id)[0] # [n_noisy, D] + t_a = action_token_shapes[0][0] + out = preds.new_zeros((t_a, self.action_dim)) + noisy = action_noisy_frame_indexes[0] + if noisy.numel() > 0: + out[noisy] = preds + return out.unsqueeze(0) # [1, T, D] + + # ------------------------------------------------------------------ + # forward: full per-step pass — encode text/vision, run layers, decode velocity. + # ------------------------------------------------------------------ + + def forward( + self, + input_ids: torch.Tensor, + text_indexes: torch.Tensor, + position_ids: torch.Tensor, + und_len: int, + sequence_length: int, + vision_tokens: list[torch.Tensor], + vision_token_shapes: list[tuple[int, int, int]], + vision_sequence_indexes: torch.Tensor, + vision_mse_loss_indexes: torch.Tensor, + vision_timesteps: torch.Tensor, + vision_noisy_frame_indexes: list[torch.Tensor], + sound_tokens: list[torch.Tensor] | None = None, + sound_token_shapes: list[tuple[int, int, int]] | None = None, + sound_sequence_indexes: torch.Tensor | None = None, + sound_mse_loss_indexes: torch.Tensor | None = None, + sound_timesteps: torch.Tensor | None = None, + sound_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_tokens: torch.Tensor | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_sequence_indexes: torch.Tensor | None = None, + action_mse_loss_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_domain_id: torch.Tensor | None = None, + ) -> tuple: + # Returns ``(vision, sound)`` for video/sound generation (diffusers- + # compatible) or ``(vision, action, sound)`` when action tokens are given. + has_sound = sound_tokens is not None and sound_sequence_indexes is not None + has_action = action_tokens is not None and action_sequence_indexes is not None + + # Embed text into the joint hidden_states buffer at its sequence positions. + packed_text_embedding = self.embed_tokens(input_ids) + target_dtype = packed_text_embedding.dtype + hidden_states = packed_text_embedding.new_zeros(size=(sequence_length, self.config.hidden_size)) + hidden_states[text_indexes] = packed_text_embedding + + # Patchify + project vision latents, then scatter-add timestep embeds to noisy frames. + packed_tokens_vision, original_latent_shapes = self._patchify_and_pack_latents(vision_tokens) + packed_tokens_vision = self.proj_in(packed_tokens_vision) + timesteps_vision = vision_timesteps * self.config.timestep_scale + packed_timestep_embeds_vision = self.time_embedder(self.time_proj(timesteps_vision)).to(target_dtype) + packed_tokens_vision = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed_tokens_vision, + packed_timestep_embeds=packed_timestep_embeds_vision, + noisy_frame_indexes=vision_noisy_frame_indexes, + token_shapes=vision_token_shapes, + ) + hidden_states[vision_sequence_indexes] = packed_tokens_vision + + # Pack + project sound latents (all sound frames noisy). + if has_sound: + packed_tokens_sound = self._pack_sound_latents(sound_tokens, sound_token_shapes).to(target_dtype) + packed_tokens_sound = self.audio_proj_in(packed_tokens_sound) + self.audio_modality_embed + timesteps_sound = sound_timesteps * self.config.timestep_scale + packed_timestep_embeds_sound = self.time_embedder(self.time_proj(timesteps_sound)).to(target_dtype) + packed_tokens_sound = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed_tokens_sound, + packed_timestep_embeds=packed_timestep_embeds_sound, + noisy_frame_indexes=sound_noisy_frame_indexes, + token_shapes=sound_token_shapes, + ) + hidden_states[sound_sequence_indexes] = packed_tokens_sound + + # Project + place action tokens (after the vision block in the gen + # sequence): domain-aware in-projection + modality embed, timestep embed + # added only to noisy (predicted) action tokens. + if has_action: + packed_tokens_action = self._embed_action( + action_tokens, action_domain_id, action_timesteps, + action_token_shapes, action_noisy_frame_indexes, target_dtype, + ) + hidden_states[action_sequence_indexes] = packed_tokens_action + + # mRoPE once for the joint sequence, then slice into und/gen halves. + cos, sin = self.rotary_emb( + position_ids=position_ids.unsqueeze(0) if position_ids.ndim == 1 else position_ids.unsqueeze(1), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + cos = cos.squeeze(0) + sin = sin.squeeze(0) + + und_seq = hidden_states[:und_len] + gen_seq = hidden_states[und_len:] + rotary_emb = (cos[:und_len], sin[:und_len], cos[und_len:], sin[und_len:]) + for decoder_layer in self.layers: + und_seq, gen_seq = decoder_layer(und_seq, gen_seq, rotary_emb) + und_out = self.norm(und_seq) + gen_out = self.norm_moe_gen(gen_seq) + last_hidden_state = torch.cat([und_out, gen_out], dim=0) + + # Decode vision velocity from the joint hidden state. + preds_vision_packed = self.proj_out(last_hidden_state[vision_mse_loss_indexes]) + preds_vision = self._unpatchify_and_unpack_latents( + preds_vision_packed, + token_shapes_vision=vision_token_shapes, + noisy_frame_indexes_vision=vision_noisy_frame_indexes, + original_latent_shapes=original_latent_shapes, + ) + + preds_action: torch.Tensor | None = None + if has_action: + preds_action = self._decode_action( + last_hidden_state[action_mse_loss_indexes], + action_domain_id, action_token_shapes, action_noisy_frame_indexes, + ) + + preds_sound: list[torch.Tensor] | None = None + if has_sound: + preds_sound_packed = self.audio_proj_out(last_hidden_state[sound_mse_loss_indexes]) + preds_sound = self._unpack_sound_latents(preds_sound_packed, sound_token_shapes, sound_noisy_frame_indexes) + + # Video/sound generation keeps the diffusers ``(vision, sound)`` return so + # this module is a drop-in for the diffusers transformer; action + # generation additionally returns the predicted action band. + if has_action: + return preds_vision, preds_action, preds_sound + return preds_vision, preds_sound + + # ------------------------------------------------------------------ + # Cache-once engine path: the understanding tower runs once and writes its + # K/V; the generation tower then runs per denoising step, re-reading that + # frozen K/V. Because the text tokens never receive a timestep embedding, + # their K/V is step-independent, so caching it once is exact. ``cache_handle`` + # is a paged attention handle (set_layer_idx / run_attention / advance_seq_lens); + # the attention plan (causal vs not, which label) is configured by the caller. + # ------------------------------------------------------------------ + + def _rotary(self, position_ids: torch.Tensor, device, dtype): + """cos/sin of shape [N, head_dim] for a [3, N] block of 3D mRoPE ids.""" + cos, sin = self.rotary_emb(position_ids.unsqueeze(1), device=device, dtype=dtype) + return cos.squeeze(0), sin.squeeze(0) + + def prefill_und( + self, input_ids: torch.Tensor, position_ids: torch.Tensor, cache_handle + ) -> None: + """Run the understanding tower over the text prefix, writing per-layer K/V + to the cache under the active label and committing the prefix length. + ``position_ids`` are the text segment's 3D mRoPE ids ([3, und_len]).""" + und_seq = self.embed_tokens(input_ids) + cos, sin = self._rotary(position_ids, und_seq.device, und_seq.dtype) + for i, layer in enumerate(self.layers): + cache_handle.set_layer_idx(i) + und_seq = layer.forward_und(und_seq, cos, sin, cache_handle) + cache_handle.advance_seq_lens() + + def denoise_step( + self, + latents: torch.Tensor, + vision_timesteps: torch.Tensor, + position_ids: torch.Tensor, + vision_token_shapes: list[tuple[int, int, int]], + vision_noisy_frame_indexes: list[torch.Tensor], + vision_mse_loss_indexes: torch.Tensor, + cache_handle, + action_latents: torch.Tensor | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_mse_gen_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_domain_id: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """One generation-tower evaluation against the frozen understanding K/V. + + Patchifies ``latents`` ([1, C, T, H, W]), scatter-adds the timestep + embedding to the noisy tokens, runs the generation layers (each reading + the active label's cached understanding K/V plus its own freshly written + K/V), and decodes the flow velocity. ``position_ids`` are the generation + segment's 3D mRoPE ids ([3, num_gen]) — the vision band, then the action + band when present. ``vision_mse_loss_indexes`` / ``action_mse_gen_indexes`` + index into the generation token block. With action, the generation + sequence is ``[vision tokens | action tokens]`` and the call returns + ``(video_velocity, action_velocity)``.""" + has_action = action_latents is not None + packed, original_latent_shapes = self._patchify_and_pack_latents([latents]) + packed = self.proj_in(packed) + target_dtype = packed.dtype + timesteps = vision_timesteps * self.config.timestep_scale + ts_embeds = self.time_embedder(self.time_proj(timesteps)).to(target_dtype) + gen_seq = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed, + packed_timestep_embeds=ts_embeds, + noisy_frame_indexes=vision_noisy_frame_indexes, + token_shapes=vision_token_shapes, + ) + if has_action: + action_seq = self._embed_action( + action_latents, action_domain_id, action_timesteps, + action_token_shapes, action_noisy_frame_indexes, target_dtype, + ) + gen_seq = torch.cat([gen_seq, action_seq], dim=0) + + cos, sin = self._rotary(position_ids, gen_seq.device, gen_seq.dtype) + for i, layer in enumerate(self.layers): + cache_handle.set_layer_idx(i) + gen_seq = layer.forward_gen(gen_seq, cos, sin, cache_handle) + gen_out = self.norm_moe_gen(gen_seq) + preds_packed = self.proj_out(gen_out[vision_mse_loss_indexes]) + preds = self._unpatchify_and_unpack_latents( + preds_packed, + token_shapes_vision=vision_token_shapes, + noisy_frame_indexes_vision=vision_noisy_frame_indexes, + original_latent_shapes=original_latent_shapes, + ) + if not has_action: + return preds[0] + action_pred = self._decode_action( + gen_out[action_mse_gen_indexes], action_domain_id, + action_token_shapes, action_noisy_frame_indexes, + ) + return preds[0], action_pred + + def denoise_step_batched_cfg( + self, + latents: torch.Tensor, + vision_timesteps: torch.Tensor, + position_ids_cond: torch.Tensor, + position_ids_uncond: torch.Tensor, + vision_token_shapes: list[tuple[int, int, int]], + vision_noisy_frame_indexes: list[torch.Tensor], + vision_mse_loss_indexes: torch.Tensor, + cache_handle, + action_latents: torch.Tensor | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_mse_gen_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_domain_id: torch.Tensor | None = None, + ): + """Conditional and unconditional generation in one batched pass. + + The two classifier-free-guidance branches share identical generation + tokens — same latents, same timestep, so the patchified input and its + timestep embedding are built once and repeated. They differ only in (a) + the text-conditioning K/V they attend to (held under two cache labels) + and (b) their rotary positions: the media band starts just after each + branch's text, and the two prompts have different lengths. So pack + ``[cond tokens | uncond tokens]`` into one sequence carrying per-branch + positions, and let the handle's batched plan route each branch to its + own label's pages. Returns the conditional and unconditional results in + the same form as ``denoise_step`` (a velocity, or a (video, action) + pair when action tokens are present).""" + has_action = action_latents is not None + packed, original_latent_shapes = self._patchify_and_pack_latents([latents]) + packed = self.proj_in(packed) + target_dtype = packed.dtype + timesteps = vision_timesteps * self.config.timestep_scale + ts_embeds = self.time_embedder(self.time_proj(timesteps)).to(target_dtype) + gen_seq = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed, + packed_timestep_embeds=ts_embeds, + noisy_frame_indexes=vision_noisy_frame_indexes, + token_shapes=vision_token_shapes, + ) + if has_action: + action_seq = self._embed_action( + action_latents, action_domain_id, action_timesteps, + action_token_shapes, action_noisy_frame_indexes, target_dtype, + ) + gen_seq = torch.cat([gen_seq, action_seq], dim=0) + + n = gen_seq.shape[0] + gen_seq = torch.cat([gen_seq, gen_seq], dim=0) + cos_c, sin_c = self._rotary(position_ids_cond, gen_seq.device, gen_seq.dtype) + cos_u, sin_u = self._rotary(position_ids_uncond, gen_seq.device, gen_seq.dtype) + cos = torch.cat([cos_c, cos_u], dim=0) + sin = torch.cat([sin_c, sin_u], dim=0) + + for i, layer in enumerate(self.layers): + cache_handle.set_layer_idx(i) + gen_seq = layer.forward_gen(gen_seq, cos, sin, cache_handle) + gen_out = self.norm_moe_gen(gen_seq) + + def _decode(out): + preds_packed = self.proj_out(out[vision_mse_loss_indexes]) + preds = self._unpatchify_and_unpack_latents( + preds_packed, + token_shapes_vision=vision_token_shapes, + noisy_frame_indexes_vision=vision_noisy_frame_indexes, + original_latent_shapes=original_latent_shapes, + ) + if not has_action: + return preds[0] + action_pred = self._decode_action( + out[action_mse_gen_indexes], action_domain_id, + action_token_shapes, action_noisy_frame_indexes, + ) + return preds[0], action_pred + + return _decode(gen_out[:n]), _decode(gen_out[n:]) + + def denoise_step_batched(self, requests: list[dict], cache_handle): + """Denoise one step for several requests at once (image / video). + + Each request carries its own latents, timestep, rotary positions (which + differ per request, and per guidance branch) and token layout. Every + request contributes a conditional and an unconditional sequence, packed + as ``[cond r0 | cond r1 | ... | uncond r0 | uncond r1 | ...]`` to match + the order the handle's batched plan lays out its entries. The layers run + once over the whole pack; the cache routes each piece to its own request + and guidance label. Returns one ``(cond_velocity, uncond_velocity)`` pair + per request, in request order. + + Each ``requests`` entry is a dict with: ``latents``, ``vision_timesteps``, + ``position_ids_cond``, ``position_ids_uncond``, ``vision_token_shapes``, + ``vision_noisy_frame_indexes``, ``vision_mse_loss_indexes``.""" + gen_seqs, shapes, cos_cond, sin_cond, cos_uncond, sin_uncond = [], [], [], [], [], [] + for req in requests: + packed, original_latent_shapes = self._patchify_and_pack_latents([req["latents"]]) + packed = self.proj_in(packed) + ts_embeds = self.time_embedder( + self.time_proj(req["vision_timesteps"] * self.config.timestep_scale) + ).to(packed.dtype) + gen_seq = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed, + packed_timestep_embeds=ts_embeds, + noisy_frame_indexes=req["vision_noisy_frame_indexes"], + token_shapes=req["vision_token_shapes"], + ) + gen_seqs.append(gen_seq) + shapes.append(original_latent_shapes) + cc, sc = self._rotary(req["position_ids_cond"], gen_seq.device, gen_seq.dtype) + cu, su = self._rotary(req["position_ids_uncond"], gen_seq.device, gen_seq.dtype) + cos_cond.append(cc) + sin_cond.append(sc) + cos_uncond.append(cu) + sin_uncond.append(su) + + # Conditional block first (all requests), then unconditional block. + all_gen = torch.cat(gen_seqs + gen_seqs, dim=0) + cos = torch.cat(cos_cond + cos_uncond, dim=0) + sin = torch.cat(sin_cond + sin_uncond, dim=0) + for i, layer in enumerate(self.layers): + cache_handle.set_layer_idx(i) + all_gen = layer.forward_gen(all_gen, cos, sin, cache_handle) + gen_out = self.norm_moe_gen(all_gen) + + sizes = [g.shape[0] for g in gen_seqs] + total = sum(sizes) + cond_out, uncond_out = gen_out[:total], gen_out[total:] + + def _decode(out, req, original_latent_shapes): + preds_packed = self.proj_out(out[req["vision_mse_loss_indexes"]]) + preds = self._unpatchify_and_unpack_latents( + preds_packed, + token_shapes_vision=req["vision_token_shapes"], + noisy_frame_indexes_vision=req["vision_noisy_frame_indexes"], + original_latent_shapes=original_latent_shapes, + ) + return preds[0] + + results, off = [], 0 + for i, req in enumerate(requests): + n = sizes[i] + cond_v = _decode(cond_out[off:off + n], req, shapes[i]) + uncond_v = _decode(uncond_out[off:off + n], req, shapes[i]) + off += n + results.append((cond_v, uncond_v)) + return results + + def denoise_step_action_batched(self, requests: list[dict], cache_handle, with_cfg: bool): + """Joint ``[video | action]`` denoise for several action requests at once. + + The action analogue of ``denoise_step_batched``. Each request carries its + own video latents, action latents, per-band timesteps, rotary positions + (per guidance branch), token layout and embodiment domain id; its + generation block is ``[vision tokens | action tokens]``. With classifier- + free guidance every request contributes a conditional and an + unconditional copy, packed ``[cond r0 | ... | cond rN | uncond r0 | ... | + uncond rN]`` to match the handle's batched plan; without guidance (the + guidance-scale-1 forward/inverse-dynamics and base policy case) each + request contributes a single sequence ``[r0 | r1 | ... | rN]``. The layers + run once over the whole pack; the cache routes each piece to its own + request and guidance label. The per-request action projection is + domain-aware, so requests from different embodiments can share the batch. + + Returns one entry per request, in request order: a tuple of branch + results, each a ``(video_velocity, action_velocity)`` pair — one branch + without guidance, ``(conditional, unconditional)`` with. + + Each ``requests`` entry is a dict with: ``latents``, ``action_latents``, + ``vision_timesteps``, ``action_timesteps``, ``position_ids_cond`` + (plus ``position_ids_uncond`` when ``with_cfg``), ``vision_token_shapes``, + ``vision_noisy_frame_indexes``, ``vision_mse_loss_indexes``, + ``action_token_shapes``, ``action_noisy_frame_indexes``, + ``action_mse_gen_indexes``, ``action_domain_id``.""" + gen_seqs, shapes, cos_cond, sin_cond, cos_uncond, sin_uncond = [], [], [], [], [], [] + for req in requests: + packed, original_latent_shapes = self._patchify_and_pack_latents([req["latents"]]) + packed = self.proj_in(packed) + target_dtype = packed.dtype + ts_embeds = self.time_embedder( + self.time_proj(req["vision_timesteps"] * self.config.timestep_scale) + ).to(target_dtype) + gen_seq = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed, + packed_timestep_embeds=ts_embeds, + noisy_frame_indexes=req["vision_noisy_frame_indexes"], + token_shapes=req["vision_token_shapes"], + ) + action_seq = self._embed_action( + req["action_latents"], req["action_domain_id"], req["action_timesteps"], + req["action_token_shapes"], req["action_noisy_frame_indexes"], target_dtype, + ) + gen_seq = torch.cat([gen_seq, action_seq], dim=0) + gen_seqs.append(gen_seq) + shapes.append(original_latent_shapes) + cc, sc = self._rotary(req["position_ids_cond"], gen_seq.device, gen_seq.dtype) + cos_cond.append(cc) + sin_cond.append(sc) + if with_cfg: + cu, su = self._rotary(req["position_ids_uncond"], gen_seq.device, gen_seq.dtype) + cos_uncond.append(cu) + sin_uncond.append(su) + + if with_cfg: + all_gen = torch.cat(gen_seqs + gen_seqs, dim=0) + cos = torch.cat(cos_cond + cos_uncond, dim=0) + sin = torch.cat(sin_cond + sin_uncond, dim=0) + else: + all_gen = torch.cat(gen_seqs, dim=0) + cos = torch.cat(cos_cond, dim=0) + sin = torch.cat(sin_cond, dim=0) + + for i, layer in enumerate(self.layers): + cache_handle.set_layer_idx(i) + all_gen = layer.forward_gen(all_gen, cos, sin, cache_handle) + gen_out = self.norm_moe_gen(all_gen) + + sizes = [g.shape[0] for g in gen_seqs] + total = sum(sizes) + offsets, acc = [], 0 + for n in sizes: + offsets.append(acc) + acc += n + + def _decode(out, req, original_latent_shapes): + preds_packed = self.proj_out(out[req["vision_mse_loss_indexes"]]) + preds = self._unpatchify_and_unpack_latents( + preds_packed, + token_shapes_vision=req["vision_token_shapes"], + noisy_frame_indexes_vision=req["vision_noisy_frame_indexes"], + original_latent_shapes=original_latent_shapes, + ) + action_pred = self._decode_action( + out[req["action_mse_gen_indexes"]], req["action_domain_id"], + req["action_token_shapes"], req["action_noisy_frame_indexes"], + ) + return preds[0], action_pred + + cond_block = gen_out[:total] + uncond_block = gen_out[total:] if with_cfg else None + results = [] + for i, req in enumerate(requests): + o, n = offsets[i], sizes[i] + cond_res = _decode(cond_block[o:o + n], req, shapes[i]) + if with_cfg: + uncond_res = _decode(uncond_block[o:o + n], req, shapes[i]) + results.append((cond_res, uncond_res)) + else: + results.append((cond_res,)) + return results diff --git a/mstar/model/cosmos3/config.py b/mstar/model/cosmos3/config.py new file mode 100644 index 00000000..5a18250f --- /dev/null +++ b/mstar/model/cosmos3/config.py @@ -0,0 +1,192 @@ +"""Configuration for the Cosmos3 omni generator. + +A single ``Cosmos3Config`` describes every Cosmos3 checkpoint (Nano, Super, +Policy-DROID, and the Super task variants). The checkpoints share one +architecture; they differ only in the transformer dimensions +(``num_hidden_layers`` / ``hidden_size`` / ``num_attention_heads`` / +``intermediate_size``) and two capability flags (``sound_gen``, +``action_gen``). + +Values load from a local HF checkpoint directory laid out the diffusers way:: + + /transformer/config.json -> the DiT (dual-pathway MoT) dimensions + /vae/config.json -> AutoencoderKLWan factors + latent stats + /scheduler/scheduler_config.json -> UniPC flow scheduler settings + +Dataclass defaults mirror Cosmos3-Nano so a bare ``Cosmos3Config()`` is a +valid Nano config without any file present. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +def _filtered(cls: type, d: dict[str, Any]) -> dict[str, Any]: + """Keep only the dict entries that name a field on the dataclass ``cls``.""" + names = {f.name for f in cls.__dataclass_fields__.values()} + return {k: v for k, v in d.items() if k in names} + + +@dataclass +class Cosmos3VAEConfig: + """The Wan2.2-TI2V-5B VAE (``AutoencoderKLWan``) parameters we need at the + serving layer. The full VAE module loads from the ``vae/`` subfolder via + diffusers; here we only track the latent geometry and the per-channel + normalization statistics the pipeline applies to/from latent space. + """ + + z_dim: int = 48 + scale_factor_spatial: int = 16 + scale_factor_temporal: int = 4 + # Per-channel latent normalization (length == z_dim). The pipeline maps + # raw VAE latents x -> (x - mean) / std before denoising and inverts it + # before decode. + latents_mean: list[float] = field(default_factory=list) + latents_std: list[float] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "Cosmos3VAEConfig": + return cls(**_filtered(cls, d)) + + +@dataclass +class Cosmos3SchedulerConfig: + """UniPC multistep flow scheduler settings (``scheduler/scheduler_config``). + + The denoise loop drives a diffusers ``UniPCMultistepScheduler`` configured + from these fields; we do not re-implement the bh2 corrector. + """ + + scheduler_type: str = "unipc" + prediction_type: str = "flow_prediction" + predict_x0: bool = True + solver_order: int = 2 + solver_type: str = "bh2" + use_flow_sigmas: bool = True + use_karras_sigmas: bool = True + final_sigmas_type: str = "zero" + num_train_timesteps: int = 1000 + flow_shift: float = 1.0 + sigma_min: float = 0.147 + sigma_max: float = 200.0 + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "Cosmos3SchedulerConfig": + # diffusers stores the flow shift under "flow_shift"; keep the rest by name. + return cls(**_filtered(cls, d)) + + +@dataclass +class Cosmos3Config: + """Cosmos3 generator configuration (one architecture, swappable weights).""" + + # ----- dual-pathway MoT transformer (the DiT) ----- + hidden_size: int = 4096 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + head_dim: int = 128 + intermediate_size: int = 12288 + vocab_size: int = 151936 + rms_norm_eps: float = 1e-6 + attention_bias: bool = False + max_position_embeddings: int = 262144 + + # ----- 3D interleaved mRoPE ----- + rope_theta: float = 5_000_000.0 + rope_axes_dim: tuple[int, int, int] = (24, 20, 20) # rope_scaling.mrope_section + mrope_interleaved: bool = True + unified_3d_mrope_temporal_modality_margin: int = 15000 + unified_3d_mrope_reset_spatial_ids: bool = True + base_fps: int = 24 + enable_fps_modulation: bool = True + + # ----- latent geometry / patchify ----- + latent_channel: int = 48 + latent_patch_size: int = 2 + patch_latent_dim: int = 192 # latent_patch_size**2 * latent_channel + timestep_scale: float = 0.001 + + # ----- attention / norm style ----- + joint_attn_implementation: str = "two_way" # GEN attends [UND|GEN]; UND causal, UND-only + qk_norm_for_diffusion: bool = True + qk_norm_for_text: bool = True + use_moe: bool = True # MoT two-FFN split (mlp / mlp_moe_gen), NOT sparse experts + + # ----- capability flags + modality heads ----- + action_gen: bool = True + max_action_dim: int = 64 + num_embodiment_domains: int = 32 + sound_gen: bool = True + sound_dim: int | None = 64 + sound_latent_fps: float = 25.0 + temporal_compression_factor_sound: int = 1 + video_temporal_causal: bool = False + freeze_und: bool = False + + # ----- default sampling (overridable per request / yaml) ----- + # Number of denoise model evaluations. The per-mode cookbook defaults are + # t2i 50, t2v/i2v 35, action fd/id 30, DROID policy ~4. ``num_inference_steps`` + # is the image default; ``num_inference_steps_video`` is the video default. + # A request may override either; the value is clamped to ``max_inference_steps``. + num_inference_steps: int = 50 + num_inference_steps_video: int = 35 + # Upper bound on the denoise loop's iteration count. The loop is built with + # this many iterations and each request stops early at its own step count, so + # one graph serves any per-request step count up to this cap. + max_inference_steps: int = 100 + # Default frames-per-second for video generation + mp4 playback (overridable + # per request via ``fps``). + fps: float = 24.0 + # Default frame count for a video request that doesn't specify ``num_frames`` + # (the Wan VAE downsamples time by 4, so latent frames = 1 + (n - 1) // 4). + num_frames_video: int = 17 + + # ----- sub-configs ----- + vae: Cosmos3VAEConfig = field(default_factory=Cosmos3VAEConfig) + scheduler: Cosmos3SchedulerConfig = field(default_factory=Cosmos3SchedulerConfig) + + # ----- provenance ----- + local_dir: str = "" + + @classmethod + def from_transformer_dict(cls, d: dict[str, Any]) -> "Cosmos3Config": + """Build from a diffusers ``transformer/config.json`` dict alone. + + Sub-configs are left at their defaults; use ``from_pretrained`` to also + populate VAE/scheduler from their sibling folders. + """ + kwargs = _filtered(cls, d) + rope = d.get("rope_scaling") or {} + if "mrope_section" in rope: + kwargs["rope_axes_dim"] = tuple(rope["mrope_section"]) + if "mrope_interleaved" in rope: + kwargs["mrope_interleaved"] = bool(rope["mrope_interleaved"]) + return cls(**kwargs) + + @classmethod + def from_pretrained(cls, local_dir: str | Path) -> "Cosmos3Config": + """Load from a diffusers-layout checkpoint directory.""" + root = Path(local_dir) + tcfg_path = root / "transformer" / "config.json" + if not tcfg_path.exists(): + raise FileNotFoundError(f"transformer/config.json not found under {root}") + with open(tcfg_path) as f: + cfg = cls.from_transformer_dict(json.load(f)) + cfg.local_dir = str(root) + + vae_path = root / "vae" / "config.json" + if vae_path.exists(): + with open(vae_path) as f: + cfg.vae = Cosmos3VAEConfig.from_dict(json.load(f)) + + sched_path = root / "scheduler" / "scheduler_config.json" + if sched_path.exists(): + with open(sched_path) as f: + cfg.scheduler = Cosmos3SchedulerConfig.from_dict(json.load(f)) + + return cfg diff --git a/mstar/model/cosmos3/cosmos3_model.py b/mstar/model/cosmos3/cosmos3_model.py new file mode 100644 index 00000000..993816f7 --- /dev/null +++ b/mstar/model/cosmos3/cosmos3_model.py @@ -0,0 +1,761 @@ +"""Cosmos3Model: NVIDIA Cosmos3 omni generator on the mstar engine. + +Cosmos3 is a text-conditioned diffusion model: a dual-pathway Mixture-of- +Transformers DiT denoises image/video (and optionally sound) latents, which a +Wan VAE decodes to pixels. An optional action head extends the same backbone to +robot-action generation. + +Nodes (2 for image generation): + dit (kv_cache) - dual-pathway DiT. The understanding (text) + tower prefills the conditioning K/V; the + generation tower runs the denoise loop, reading + that frozen K/V each step (it is timestep- + independent, so caching it once is exact). + vae_decoder (stateless) - Wan VAE: final latents -> pixels. + +Graph walks (image generation): + prefill - the understanding tower runs over the text prompt and writes + its per-layer K/V (causal self-attention over text). + image_gen - an N-step denoising loop. Each iteration the generation tower + attends to [frozen text K/V | current generation tokens], + predicts flow velocity, and applies one scheduler step; the + final latents go to the VAE decoder, which emits the image. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import torch + +from mstar.communication.tensors import NameToTensorList +from mstar.conductor.request_info import ( + CurrentForwardConductorMetadata, + StreamingConnectionState, +) +from mstar.distributed.base import ShardingConfig +from mstar.engine.base import EngineType +from mstar.engine.kv_store import KVCacheConfig +from mstar.graph.base import ( + GraphEdge, + GraphNode, + GraphSection, + Loop, + Sequential, + TensorPointerInfo, +) +from mstar.graph.special_destinations import EMIT_TO_CLIENT +from mstar.model.base import ForwardPassArgs, Model +from mstar.model.cosmos3.config import Cosmos3Config +from mstar.model.cosmos3.submodules import ( + ACTION_GEN_LOOP, + ACTION_VIDEO_GEN_LOOP, + IMAGE_GEN_LOOP, + VIDEO_GEN_LOOP, + Cosmos3DiTSubmodule, + Cosmos3VAEDecoderSubmodule, +) + +logger = logging.getLogger(__name__) + +DIT_NODE = "dit" +VAE_DECODER_NODE = "vae_decoder" + + +class Cosmos3Model(Model): + """NVIDIA Cosmos3 generator implementation.""" + + PREFILL_WALK = "prefill" + PREFILL_COND_WALK = "prefill_cond" + PREFILL_COND_VIDEO_WALK = "prefill_cond_video" + IMAGE_GEN_WALK = "image_gen" + VIDEO_GEN_WALK = "video_gen" + ACTION_GEN_WALK = "action_gen" + ACTION_VIDEO_GEN_WALK = "action_video_gen" + + def __init__( + self, + model_path_hf: str, + cache_dir: str | None = None, + skip_weight_loading: bool = False, + **kwargs, + ): + self.model_path_hf = model_path_hf + self.cache_dir = cache_dir + self.skip_weight_loading = skip_weight_loading + self._yaml_config_overrides: dict = dict(kwargs) + + self._repo_dir: Path | None = None + self.config: Cosmos3Config = self._load_config() + self.tokenizer = self._load_tokenizer() + + self._submodule_cache: dict[str, torch.nn.Module | None] = {} + # The Wan VAE is shared between the DiT submodule (conditioning encode) + # and the decoder submodule, so build it once. + self._vae = None + + # ------------------------------------------------------------------ + # Config + tokenizer + # ------------------------------------------------------------------ + + def _ensure_repo(self) -> Path: + if self._repo_dir is not None: + return self._repo_dir + candidate = Path(self.model_path_hf) + if candidate.exists(): + self._repo_dir = candidate + else: + from huggingface_hub import snapshot_download + + self._repo_dir = Path( + snapshot_download(repo_id=self.model_path_hf, cache_dir=self.cache_dir) + ) + return self._repo_dir + + def _load_config(self) -> Cosmos3Config: + if self.skip_weight_loading: + cfg = Cosmos3Config() + else: + try: + cfg = Cosmos3Config.from_pretrained(self._ensure_repo()) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Could not load Cosmos3 config from %s (%s); using Nano defaults.", + self.model_path_hf, exc, + ) + cfg = Cosmos3Config() + + # Overlay yaml model_kwargs last (so they win over file + defaults). + if self._yaml_config_overrides: + valid = {f.name for f in Cosmos3Config.__dataclass_fields__.values()} + for k, v in self._yaml_config_overrides.items(): + if k in valid: + setattr(cfg, k, v) + else: + logger.warning( + "Cosmos3Model: yaml model_kwargs key %r is not a Cosmos3Config " + "field; ignored.", k, + ) + return cfg + + def _load_tokenizer(self): + if self.skip_weight_loading: + return None + from transformers import AutoTokenizer + + repo = self._ensure_repo() + # The published checkpoint ships the Qwen2 text tokenizer under + # ``text_tokenizer/``; fall back to the repo root for layouts that + # keep the tokenizer files at the top level. + for sub in (repo / "text_tokenizer", repo): + try: + return AutoTokenizer.from_pretrained(str(sub), use_fast=True) + except Exception as exc: # noqa: BLE001 + logger.warning("Cosmos3 tokenizer load from %s failed (%s).", sub, exc) + logger.warning("All Cosmos3 tokenizer sources failed; proceeding without one.") + return None + + # ------------------------------------------------------------------ + # Model ABC: structure + # ------------------------------------------------------------------ + + def get_kv_cache_config(self) -> list[KVCacheConfig]: + return [ + KVCacheConfig( + num_layers=self.config.num_hidden_layers, + num_kv_heads=self.config.num_key_value_heads, + head_dim=self.config.head_dim, + max_seq_len=self.config.max_position_embeddings, + num_qo_heads=self.config.num_attention_heads, + ) + ] + + def get_node_engine_types(self) -> dict[str, EngineType]: + return { + DIT_NODE: EngineType.KV_CACHE, + VAE_DECODER_NODE: EngineType.STATELESS, + } + + def get_default_sharding_config(self) -> ShardingConfig: + # The DiT supports tensor parallelism: per layer the attention heads and + # the MLP intermediate dim shard across ranks, the residual stream stays + # full, and the row-parallel out/down projections all-reduce. Signals + # between nodes stay replicated (empty shard_dim) — the sharding is + # in-module, Megatron-style. The VAE decoder runs un-sharded on one rank. + return ShardingConfig( + groups=[], tp_enabled_nodes={DIT_NODE}, shard_dim={} + ) + + def get_graph_walk_graphs(self) -> dict[str, GraphSection]: + # prefill: the understanding tower runs over the text prompt and writes + # its conditioning K/V. No graph output — completion notifies the + # conductor, and the generation loop reads the K/V from the shared cache. + prefill = GraphNode( + name=DIT_NODE, + input_names=["text_inputs"], + outputs=[], + ) + + # prefill_cond: like prefill, but image-to-video also hands the DiT node + # the conditioning image, which it VAE-encodes into the clean anchor + # latents that seed the denoise loop (stashed on the per-request state). + prefill_cond = GraphNode( + name=DIT_NODE, + input_names=["text_inputs", "image_inputs"], + outputs=[], + ) + + # prefill_cond_video: action inverse-dynamics conditions on a whole video, + # which the DiT VAE-encodes into the clean anchor latents for the loop. + prefill_cond_video = GraphNode( + name=DIT_NODE, + input_names=["text_inputs", "video_inputs"], + outputs=[], + ) + + # image_gen: denoising loop -> VAE decode -> emit image. The loop body + # threads the latents + denoise-step index back to itself each iteration; + # on the final iteration the latents route to the decoder. max_iters is an + # upper bound — each request stops the loop at its own denoise-step count + # (Cosmos3DiTSubmodule.check_stop), so one graph serves image and video + # (and any per-request num_inference_steps) without a rebuild. + # image_gen and video_gen are the same denoise loop + VAE decode; they + # differ only in the emitted modality (one frame vs an encoded clip), so + # the request's output modality selects between them. + def _gen_walk(loop_name: str, emit_name: str, modality: str) -> Sequential: + return Sequential( + [ + Loop( + name=loop_name, + # Disable speculative (async) scheduling on the denoise + # step: with it on, the worker pre-dispatches a single + # request's next step and drains that one request's whole + # loop before others are scheduled, so concurrent requests + # never share a forward. Off, the scheduler groups all + # ready requests at this node into one batched denoise + # step (see can_batch/forward_batched). Mirrors the BAGEL + # image-gen loop nodes. + section=GraphNode( + name=DIT_NODE, + input_names=["latents", "time_index"], + outputs=[ + GraphEdge(next_node=DIT_NODE, name="latents"), + GraphEdge(next_node=DIT_NODE, name="time_index"), + ], + enable_async_scheduling=False, + ), + max_iters=self.config.max_inference_steps, + outputs=[ + GraphEdge(next_node=VAE_DECODER_NODE, name="latents"), + ], + ), + GraphNode( + name=VAE_DECODER_NODE, + input_names=["latents"], + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name=emit_name, + output_modality=modality, + ), + ], + ), + ] + ) + + image_gen = _gen_walk(IMAGE_GEN_LOOP, "image_output", "image") + video_gen = _gen_walk(VIDEO_GEN_LOOP, "video_output", "video") + + # action_gen: like image_gen but the loop body jointly denoises the video + # and action latents (threaded as two self-edges), and the predicted + # action — not a decoded video — is what the request emits. + action_gen = Sequential( + [ + Loop( + name=ACTION_GEN_LOOP, + section=GraphNode( + name=DIT_NODE, + input_names=["latents", "action_latents", "time_index"], + outputs=[ + GraphEdge(next_node=DIT_NODE, name="latents"), + GraphEdge(next_node=DIT_NODE, name="action_latents"), + GraphEdge(next_node=DIT_NODE, name="time_index"), + ], + enable_async_scheduling=False, + ), + max_iters=self.config.max_inference_steps, + # The loop's terminal output is matched into the section by + # name (Loop.__post_init__ filters to the section's own output + # edges), so it must reuse a loop-back name: on the final + # iteration the predicted action latents go to the client + # instead of back into the loop. + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name="action_latents", + output_modality="action", + ), + ], + ), + ] + ) + + # action_video_gen (forward dynamics): the same joint video+action denoise, + # but the action is the clean condition and the predicted video is decoded + # and emitted. The loop's terminal output reuses the "latents" loop-back + # name; on the final iteration the video latents route to the VAE decoder + # instead of back into the loop. + action_video_gen = Sequential( + [ + Loop( + name=ACTION_VIDEO_GEN_LOOP, + section=GraphNode( + name=DIT_NODE, + input_names=["latents", "action_latents", "time_index"], + outputs=[ + GraphEdge(next_node=DIT_NODE, name="latents"), + GraphEdge(next_node=DIT_NODE, name="action_latents"), + GraphEdge(next_node=DIT_NODE, name="time_index"), + ], + enable_async_scheduling=False, + ), + max_iters=self.config.max_inference_steps, + outputs=[ + GraphEdge(next_node=VAE_DECODER_NODE, name="latents"), + ], + ), + GraphNode( + name=VAE_DECODER_NODE, + input_names=["latents"], + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name="video_output", + output_modality="video", + ), + ], + ), + ] + ) + + return { + self.PREFILL_WALK: prefill, + self.PREFILL_COND_WALK: prefill_cond, + self.PREFILL_COND_VIDEO_WALK: prefill_cond_video, + self.ACTION_VIDEO_GEN_WALK: action_video_gen, + self.IMAGE_GEN_WALK: image_gen, + self.VIDEO_GEN_WALK: video_gen, + self.ACTION_GEN_WALK: action_gen, + } + + # ------------------------------------------------------------------ + # Model ABC: I/O + # ------------------------------------------------------------------ + + def process_prompt( + self, + prompt: str | None, + input_modalities: list[str], + output_modalities: list[str], + tensors: NameToTensorList | None = None, + **kwargs, + ) -> NameToTensorList: + if prompt is None: + return {} + if self.tokenizer is None: + # Tokenizer-less fallback used by structural unit tests. + return { + "text_inputs": [ + torch.tensor(list(prompt.encode("utf-8")), dtype=torch.long) + ] + } + # Both the conditional (positive) and unconditional (negative) prompts are + # tokenized up front; the denoiser reads the second only when guidance is + # on. Image/video prompts get the chat template + resolution/duration + # sentences; action prompts are tokenized raw. + from mstar.model.cosmos3.packing import tokenize_prompt + + negative_prompt = kwargs.get("negative_prompt") + p = self._resolve_gen_params(kwargs, input_modalities, output_modalities) + # The chat system prompt and the resolution/duration metadata sentences + # are opt-in, off by default: the model sees the bare user prompt, which + # matches the reference serving pipeline (its system-prompt and + # resolution/duration templates default off too). A request may re-enable + # any of them. Action prompts never use them — they are just the + # chat-templated user text plus the end-of-text + start-of-generation + # markers (matching the NVIDIA action references). + is_action = "action" in output_modalities + allow_templates = not is_action + cond_ids, uncond_ids = tokenize_prompt( + self.tokenizer, prompt, negative_prompt, + num_frames=p["num_frames"], height=p["height"], width=p["width"], fps=p["fps"], + use_system_prompt=allow_templates and bool(kwargs.get("use_system_prompt", False)), + add_resolution_template=allow_templates and bool(kwargs.get("use_resolution_template", False)), + add_duration_template=allow_templates and bool(kwargs.get("use_duration_template", False)), + ) + return { + "text_inputs": [ + torch.tensor(cond_ids, dtype=torch.long), + torch.tensor(uncond_ids, dtype=torch.long), + ] + } + + def postprocess(self, output: torch.Tensor, modality: str) -> bytes: + if modality == "image": + import io + import os + import time + + from PIL import Image + + # The decoder emits 8-bit frames [B, C, T, H, W]; take the first one. + x = output + if x.ndim == 5: + x = x[0, :, 0] + elif x.ndim == 4: + x = x[0] + _prof = os.environ.get("COSMOS3_PROFILE") + _t0 = time.perf_counter() + arr = x.permute(1, 2, 0).cpu().numpy() # H, W, C uint8 + _t1 = time.perf_counter() + buf = io.BytesIO() + # PNG is lossless at every compression level, so the level only trades + # encode time for file size. PIL defaults to 6, which spends ~0.75 s on a + # 720p frame and dominates the serving latency. Level 0 (no deflate) is + # the fastest and matches what the OpenAI image endpoint emits at full + # quality; the decoded pixels are identical regardless. Override with + # COSMOS3_PNG_COMPRESS for A/B. + compress_level = int(os.environ.get("COSMOS3_PNG_COMPRESS", "0")) + Image.fromarray(arr).save(buf, format="PNG", compress_level=compress_level) + if _prof: + print(f"COSMOS3_PROFILE png d2h={1000 * (_t1 - _t0):.1f}ms " + f"encode={1000 * (time.perf_counter() - _t1):.1f}ms bytes={buf.tell()}", flush=True) + return buf.getvalue() + if modality == "video": + import os + import tempfile + + from torchvision.io import write_video + + # The decoder emits 8-bit frames [B, C, T, H, W]; encode all of them as + # an H.264 mp4. The frames already reflect the request fps (it modulates + # the temporal positions during generation); the container plays back + # at the model's default fps. + x = output[0] if output.ndim == 5 else output # [C, T, H, W] uint8 + _prof = os.environ.get("COSMOS3_PROFILE") + import time as _time + _vt0 = _time.perf_counter() + frames = x.permute(1, 2, 3, 0).cpu() # [T, H, W, C] uint8 + _vt1 = _time.perf_counter() + fd, path = tempfile.mkstemp(suffix=".mp4") + os.close(fd) + try: + # CRF 18 keeps the H.264 output near-visually-lossless; libx264 + # otherwise defaults to 23, which is visibly lossier. The "ultrafast" + # preset and multithreading (threads=0) target the same CRF/quality + # but encode several times faster than libx264's default "medium" + # preset, which otherwise dominates the serving latency for a + # many-frame clip. Both are overridable via COSMOS3_X264_PRESET. + write_video( + path, + frames, + fps=self.config.fps, + video_codec="libx264", + options={ + "crf": "18", + "preset": os.environ.get("COSMOS3_X264_PRESET", "ultrafast"), + "threads": "0", + }, + ) + with open(path, "rb") as f: + data = f.read() + if _prof: + print(f"COSMOS3_PROFILE mp4 d2h={1000 * (_vt1 - _vt0):.1f}ms " + f"encode={1000 * (_time.perf_counter() - _vt1):.1f}ms frames={frames.shape[0]} " + f"bytes={len(data)}", flush=True) + return data + finally: + os.remove(path) + if modality == "action": + # The predicted action latents [1, chunk, action_dim] -> [chunk, + # action_dim] float32 bytes. Columns beyond the request's + # raw_action_dim are zero padding (the client keeps the first + # raw_action_dim, the real action width for its embodiment). + x = output[0] if output.ndim == 3 else output + return x.detach().to(torch.float32).cpu().numpy().tobytes() + raise ValueError(f"Unsupported modality for Cosmos3: {modality!r}") + + def load_video(self, filepath: str, device: str): + """Decode a conditioning video to ``[T, C, H, W]`` in ``[0, 1]``. + + Overrides the base implementation, which reads ``self.device`` (this model + does not set one); the data worker passes the decode device explicitly, + exactly as ``load_image`` already receives it.""" + from dataclasses import asdict + + from torchcodec.decoders import VideoDecoder + + from mstar.model.base import TensorAndMetadata + + decoder = VideoDecoder(filepath, device=device) + video = torch.stack([frame for frame in decoder]).float() / 255.0 + return TensorAndMetadata(data=video, metadata=asdict(decoder.metadata)) + + # ------------------------------------------------------------------ + # Model ABC: forward pass orchestration + # ------------------------------------------------------------------ + + def _resolve_gen_params( + self, model_kwargs: dict | None, input_modalities: list[str], output_modalities: list[str], + ) -> dict: + """Resolve the per-request generation knobs (size, steps, guidance, …) + from request ``model_kwargs``, applying defaults. Used by both + ``process_prompt`` (for resolution-aware tokenization) and the forward- + pass metadata, so the two stay consistent.""" + mk = model_kwargs or {} + width = height = 1024 + size = mk.get("size") + if isinstance(size, str) and "x" in size.lower(): + sw, sh = size.lower().split("x", 1) + try: + width, height = int(sw), int(sh) + except ValueError: + pass + # A video request without an explicit frame count gets the video default + # (>1); image requests stay single-frame. + default_frames = ( + self.config.num_frames_video if "video" in (output_modalities or []) else 1 + ) + num_frames = int(mk.get("num_frames", default_frames)) + # The image and video cookbook step counts differ (image 50, video 35); + # default by mode and let the request override. The denoise loop runs this + # many steps and stops early (Cosmos3DiTSubmodule.check_stop), so the value + # is only bounded above by the loop's static max_iters. + default_steps = ( + self.config.num_inference_steps_video if num_frames > 1 + else self.config.num_inference_steps + ) + steps = int(mk.get("num_inference_steps", default_steps)) + steps = max(1, min(steps, self.config.max_inference_steps)) + params = { + "width": int(mk.get("width", width)), + "height": int(mk.get("height", height)), + "num_frames": num_frames, + "fps": float(mk.get("fps", self.config.fps)), + "guidance_scale": float(mk.get("guidance_scale", 6.0)), + "num_inference_steps": steps, + "has_image_condition": "image" in (input_modalities or []), + } + # Text-to-image (single frame, no visual conditioning) follows the + # reference Cosmos3 t2i recipe: classifier-free guidance only on the + # timestep interval [400, 1000] (outside it the denoise step runs the + # conditional branch alone) and flow_shift 3.0. Request kwargs override; + # video / image-conditioned paths keep their own defaults (full CFG, + # scheduler-config flow_shift). + is_t2i = num_frames == 1 and not params["has_image_condition"] + fs = mk.get("flow_shift") + if fs is None and is_t2i: + fs = 3.0 + if fs is not None: + params["flow_shift"] = float(fs) + gi = mk.get("guidance_interval") + if gi is None and is_t2i: + gi = (400.0, 1000.0) + if gi is not None: + params["guidance_interval"] = (float(gi[0]), float(gi[1])) + # Action requests carry a few extra keys straight through (``action`` is + # the clean conditioning action chunk for forward-dynamics). + for k in ("action_mode", "action_chunk_size", "raw_action_dim", "domain_id", + "action_fps", "action"): + if k in mk: + params[k] = mk[k] + return params + + def _step_metadata(self, metadata: CurrentForwardConductorMetadata) -> dict: + md = {"is_prefill": metadata.is_prefill} + md.update(metadata.kwargs) + return md + + def get_initial_forward_pass_args( + self, + partition_name: str, + input_modalities: list[str], + output_modalities: list[str], + input_signals: dict[str, list[TensorPointerInfo]], + model_kwargs: dict | None = None, + ) -> ForwardPassArgs: + params = self._resolve_gen_params(model_kwargs, input_modalities, output_modalities) + # Visual conditioning routes through a conditioned prefill that also feeds + # the DiT the input to VAE-encode: a video (action inverse-dynamics) or an + # image (image-to-video, action policy/forward-dynamics). Fall back to the + # text-only prefill if no conditioning signal actually arrived. + video_cond = "video" in input_modalities and "video_inputs" in input_signals + image_cond = params.get("has_image_condition") and "image_inputs" in input_signals + if video_cond: + walk = self.PREFILL_COND_VIDEO_WALK + elif image_cond: + walk = self.PREFILL_COND_WALK + else: + walk = self.PREFILL_WALK + full_metadata = CurrentForwardConductorMetadata( + input_modalities=input_modalities, + output_modalities=output_modalities, + graph_walk=walk, + is_prefill=True, + kwargs=params, + ) + + inputs: list[GraphEdge] = [] + if "text_inputs" in input_signals: + edge = GraphEdge(next_node=DIT_NODE, name="text_inputs") + edge.tensor_info = input_signals["text_inputs"] + inputs.append(edge) + cond_signal = "video_inputs" if video_cond else ("image_inputs" if image_cond else None) + if cond_signal: + edge = GraphEdge(next_node=DIT_NODE, name=cond_signal) + edge.tensor_info = input_signals[cond_signal] + inputs.append(edge) + + unpersist_tensors = sum([inp.tensor_info for inp in inputs], start=[]) + return ForwardPassArgs( + full_metadata=full_metadata, + inputs=inputs, + unpersist_tensors=unpersist_tensors, + step_metadata=self._step_metadata(full_metadata), + ) + + def get_partition_forward_pass_args( + self, + partition_name: str, + partition_metadata: CurrentForwardConductorMetadata, + persist_signals: dict[str, list[TensorPointerInfo]], + incoming_connections: list[StreamingConnectionState] | None = None, + ) -> ForwardPassArgs: + metadata = partition_metadata + request_done = False + inputs: list[GraphEdge] = [] + + # Forward-dynamics conditions on a clean action chunk and emits the + # predicted video; inverse-dynamics / policy emit the action. + is_fd = metadata.kwargs.get("action_mode") == "forward_dynamics" + is_action = "action" in metadata.output_modalities + is_video = "video" in metadata.output_modalities + joint_action = is_fd or is_action # walks that also thread action latents + if metadata.graph_walk in ( + self.PREFILL_WALK, self.PREFILL_COND_WALK, self.PREFILL_COND_VIDEO_WALK + ): + metadata.is_prefill = False + # Pick the denoise walk: forward-dynamics runs the joint denoise but + # decodes the predicted video; inverse-dynamics / policy emit the + # action; image and video share the loop but differ in what the VAE + # node emits. + if is_fd: + metadata.graph_walk = self.ACTION_VIDEO_GEN_WALK + elif is_action: + metadata.graph_walk = self.ACTION_GEN_WALK + elif is_video: + metadata.graph_walk = self.VIDEO_GEN_WALK + else: + metadata.graph_walk = self.IMAGE_GEN_WALK + # The first denoise iteration's initial noise + step index are + # sampled inside the DiT submodule's preprocess. Action walks also + # thread the action latents through the loop. + inputs = [ + GraphEdge(next_node=DIT_NODE, name="latents"), + GraphEdge(next_node=DIT_NODE, name="time_index"), + ] + if joint_action: + inputs.insert(1, GraphEdge(next_node=DIT_NODE, name="action_latents")) + elif metadata.graph_walk in ( + self.IMAGE_GEN_WALK, self.VIDEO_GEN_WALK, + self.ACTION_GEN_WALK, self.ACTION_VIDEO_GEN_WALK, + ): + request_done = True + + unpersist_tensors = sum([inp.tensor_info for inp in inputs], start=[]) + return ForwardPassArgs( + full_metadata=metadata, + inputs=inputs, + unpersist_tensors=unpersist_tensors, + step_metadata=self._step_metadata(metadata), + request_done=request_done, + ) + + # ------------------------------------------------------------------ + # Model ABC: submodule loading + # ------------------------------------------------------------------ + + def get_submodule( + self, node_name: str, device: str = "cpu", tp_group=None, + autocast_dtype: torch.dtype | None = None, + ) -> torch.nn.Module | None: + # autocast_dtype is accepted for interface parity (the engine manager + # passes it to every model). Cosmos3 already casts the meta module to + # bf16 before to_empty in _build_transformer, so params are allocated + # directly in the checkpoint dtype and the hint is redundant here. + if node_name in self._submodule_cache: + return self._submodule_cache[node_name] + submodule = self._create_submodule(node_name, device, tp_group) + self._submodule_cache[node_name] = submodule + if submodule is not None: + logger.info("Loaded Cosmos3 submodule for %s", node_name) + return submodule + + def _create_submodule(self, node_name: str, device: str, tp_group=None): + if node_name == DIT_NODE: + return Cosmos3DiTSubmodule( + transformer=self._build_transformer(device, tp_group=tp_group), + config=self.config, + scheduler=self._build_scheduler(), + vae=self._build_vae(device), + ) + if node_name == VAE_DECODER_NODE: + return Cosmos3VAEDecoderSubmodule( + vae=self._build_vae(device), config=self.config + ) + return None + + def _build_scheduler(self): + if self.skip_weight_loading: + return None + from diffusers import UniPCMultistepScheduler + + return UniPCMultistepScheduler.from_pretrained(str(self._ensure_repo() / "scheduler")) + + def _build_transformer(self, device: str, tp_group=None): + from mstar.model.cosmos3.components.transformer import Cosmos3OmniTransformer + from mstar.model.cosmos3.loader import load_transformer_weights + + # Build on the meta device (shapes only, no storage), pin the + # checkpoint's bf16 dtype, then materialize uninitialized tensors on the + # target device and overwrite with the checkpoint weights — the same + # path the other model packages use. bf16 matches the published + # checkpoint exactly and halves resident weight memory vs the float32 + # meta default; the engine additionally runs the forward under a bf16 + # autocast (a no-op here). + with torch.device("meta" if not self.skip_weight_loading else "cpu"): + model = Cosmos3OmniTransformer(self.config, comm_group=tp_group) + model = model.to(torch.bfloat16) + if self.skip_weight_loading: + return model.to_empty(device=device) + + model.to_empty(device=device) + load_transformer_weights(model, self._ensure_repo(), device=device) + # Keep the timestep embedder in fp32, like diffusers' + # ``_keep_in_fp32_modules=["time_embedder"]`` (the upcast is lossless from + # the bf16 checkpoint and matches diffusers' numerics). + model.time_embedder.to(torch.float32) + model.eval() + return model + + def _build_vae(self, device: str): + if self.skip_weight_loading: + return None + if self._vae is not None: + return self._vae + from diffusers import AutoencoderKLWan + + vae = AutoencoderKLWan.from_pretrained(str(self._ensure_repo() / "vae")) + self._vae = vae.to(device).eval() + return self._vae diff --git a/mstar/model/cosmos3/loader.py b/mstar/model/cosmos3/loader.py new file mode 100644 index 00000000..a602bded --- /dev/null +++ b/mstar/model/cosmos3/loader.py @@ -0,0 +1,126 @@ +"""Weight loading for the Cosmos3 generator backbone. + +The published checkpoint is the diffusers ``transformer/`` layout: flat +``layers.N.*`` keys with unfused attention projections (``to_q/to_k/to_v`` for +the understanding pathway, ``add_q_proj/add_k_proj/add_v_proj`` for the +generation pathway) and ``_moe_gen``-suffixed GEN MLP/norms. Our backbone +module mirrors that layout one-to-one, so loading needs no key remapping and +no stacked-parameter fusion — only the unused text ``lm_head`` is dropped. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import torch + +# Checkpoint keys deliberately not loaded into the generator backbone. The +# text ``lm_head`` exists in the checkpoint (the understanding tower descends +# from a text LM) but is never used: generation emits flow velocity via +# ``proj_out``, so we do not build or load it. +DROP_KEYS: frozenset[str] = frozenset({"lm_head.weight"}) + + +def cosmos3_name_remapper(name: str) -> str | None: + """Map a checkpoint key to a backbone parameter path, or ``None`` to drop. + + Identity for every key the backbone owns; ``None`` for the intentional + drop-list. Kept explicit so an unexpected checkpoint key surfaces as a + coverage failure rather than being silently ignored. + """ + if name in DROP_KEYS: + return None + return name + + +def read_transformer_weight_keys(checkpoint_dir: str | Path) -> set[str]: + """Return every tensor key declared by the ``transformer/`` shard index.""" + tdir = Path(checkpoint_dir) / "transformer" + index = tdir / "diffusion_pytorch_model.safetensors.index.json" + if index.exists(): + with open(index) as f: + return set(json.load(f)["weight_map"].keys()) + # Single-shard fallback: read tensor names from the safetensors header. + shards = list(tdir.glob("*.safetensors")) + if not shards: + raise FileNotFoundError(f"no transformer weights found under {tdir}") + from safetensors import safe_open + + keys: set[str] = set() + for shard in shards: + with safe_open(shard, framework="pt") as handle: + keys.update(handle.keys()) + return keys + + +def _transformer_shard_names(tdir: Path) -> list[str]: + """Resolve the ``transformer/`` shard filenames. + + The diffusers checkpoint indexes its shards under + ``diffusion_pytorch_model.safetensors.index.json`` (not the + ``model.safetensors`` name the generic shard iterator assumes), so the + shard list is read from that index; a single-file checkpoint is the + fallback. + """ + index = tdir / "diffusion_pytorch_model.safetensors.index.json" + if index.exists(): + with open(index) as f: + return sorted(set(json.load(f)["weight_map"].values())) + shards = sorted(p.name for p in tdir.glob("*.safetensors")) + if not shards: + raise FileNotFoundError(f"no transformer weights found under {tdir}") + return shards + + +def read_transformer_weight_shapes(checkpoint_dir: str | Path) -> dict[str, tuple[int, ...]]: + """Return ``{key: shape}`` for every ``transformer/`` tensor by reading only + the safetensors headers — no tensor data is materialized. Enables CPU-side + shape verification of the meta-built backbone against the checkpoint. + """ + from safetensors import safe_open + + tdir = Path(checkpoint_dir) / "transformer" + shapes: dict[str, tuple[int, ...]] = {} + for shard in _transformer_shard_names(tdir): + with safe_open(tdir / shard, framework="pt") as handle: + for key in handle.keys(): + shapes[key] = tuple(handle.get_slice(key).get_shape()) + return shapes + + +def load_transformer_weights( + model: torch.nn.Module, + checkpoint_dir: str | Path, + device: str = "cpu", +) -> set[str]: + """Stream the ``transformer/`` shards into ``model`` and return loaded keys. + + Mirrors the meta-device + ``load_hf_weights`` path the other model packages + use, but resolves the shard list from the diffusers ``diffusion_pytorch_model`` + index (the generic iterator only knows the ``model.safetensors`` name). No + stacked-parameter rules: the checkpoint's projections are unfused and match + the backbone parameter names directly. Raises if any backbone parameter is + left unfilled — the completeness guarantee bagel's loader also enforces. + """ + from mstar.model.loader import iter_safetensors_file, load_hf_weights + + tdir = Path(checkpoint_dir) / "transformer" + shard_names = _transformer_shard_names(tdir) + + def _weights(): + for shard in shard_names: + yield from iter_safetensors_file(tdir / shard, device=device) + + loaded = load_hf_weights(model, _weights(), name_remapper=cosmos3_name_remapper) + + expected = set(dict(model.named_parameters()).keys()) + missing = expected - loaded + if missing: + sample = sorted(missing)[:10] + more = "…" if len(missing) > 10 else "" + raise KeyError( + f"Cosmos3 transformer load left {len(missing)} parameter(s) unfilled " + f"from {tdir}: {sample}{more}" + ) + return loaded diff --git a/mstar/model/cosmos3/packing.py b/mstar/model/cosmos3/packing.py new file mode 100644 index 00000000..e1bb8d79 --- /dev/null +++ b/mstar/model/cosmos3/packing.py @@ -0,0 +1,464 @@ +"""Joint-sequence packing for Cosmos3 generation (ported from the diffusers +``Cosmos3OmniPipeline``). + +Pure, stateless primitives that turn a prompt + latent shape into the +transformer's per-step inputs: the 3D interleaved mRoPE position ids, the +text/vision segment layouts, and the chat-template tokenization. Shared by the +t2i pipeline and the engine submodule's input preprocessing. Reproduces the +diffusers pipeline's packed t2i inputs byte-for-byte. +""" + +from __future__ import annotations + +import math +from typing import Any + +import torch + +# --------------------------------------------------------------------------- +# 3D interleaved mRoPE position ids (exact ports of the pipeline helpers). +# --------------------------------------------------------------------------- + + +def get_3d_mrope_ids_text_tokens( + num_tokens: int, temporal_offset: int | float, use_float_positions: bool = False +) -> tuple[torch.Tensor, int | float]: + """Text tokens: all three axes share the same increasing ids from ``temporal_offset``.""" + if use_float_positions: + ids = torch.arange(num_tokens, dtype=torch.float32) + temporal_offset + else: + ids = torch.arange(num_tokens, dtype=torch.long) + int(temporal_offset) + mrope_ids = ids.unsqueeze(0).expand(3, -1).contiguous() # [3, num_tokens] + return mrope_ids, temporal_offset + num_tokens + + +def get_3d_mrope_ids_vae_tokens( + grid_t: int, + grid_h: int, + grid_w: int, + temporal_offset: int | float, + reset_spatial_indices: bool = True, + fps: float | None = None, + base_fps: float = 24.0, + temporal_compression_factor: int = 4, + base_temporal_compression_factor: int | None = None, + start_frame_offset: int = 0, +) -> tuple[torch.Tensor, int | float]: + """Vision/sound (VAE) tokens: (t, h, w) grid ids, with optional fps modulation + of the temporal axis (only when ``fps`` is set and ``grid_t > 1``).""" + fps_modulation_enabled = fps is not None and grid_t > 1 + effective_base_tcf = ( + base_temporal_compression_factor + if base_temporal_compression_factor is not None + else temporal_compression_factor + ) + + if fps_modulation_enabled: + tps = fps / temporal_compression_factor + base_tps = base_fps / effective_base_tcf + frame_indices = torch.arange(grid_t, dtype=torch.float32) + scaled_t = (frame_indices + start_frame_offset) / tps * base_tps + temporal_offset + t_index = scaled_t.view(-1, 1).expand(-1, grid_h * grid_w).flatten() + else: + t_index = ( + torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + + int(temporal_offset) + + start_frame_offset + ) + + h_index = torch.arange(grid_h, dtype=torch.long).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten() + w_index = torch.arange(grid_w, dtype=torch.long).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten() + + if not reset_spatial_indices: + spatial_offset = int(temporal_offset) + h_index = h_index + spatial_offset + w_index = w_index + spatial_offset + + if fps_modulation_enabled: + mrope_ids = torch.stack([t_index, h_index.to(torch.float32), w_index.to(torch.float32)], dim=0) + else: + mrope_ids = torch.stack([t_index, h_index, w_index], dim=0) + + next_temporal_offset = math.ceil(mrope_ids.max().item()) + 1 + return mrope_ids, next_temporal_offset + + +def get_3d_mrope_ids_action_tokens( + grid_t: int, + temporal_offset: int | float, + action_fps: float | None, + base_fps: float = 24.0, + base_temporal_compression_factor: int = 4, + start_frame_offset: int = 1, +) -> tuple[torch.Tensor, int | float]: + """Action tokens: a frame-rate ``(T, 1, 1)`` temporal grid sharing the media + offset with the vision band. The action stream is uncompressed in time + (``temporal_compression_factor=1``) but its rate is expressed in the same + base-fps units as the vision latents (``base_temporal_compression_factor``), + so an action chunk lines up temporally with the conditioning video.""" + return get_3d_mrope_ids_vae_tokens( + grid_t=grid_t, + grid_h=1, + grid_w=1, + temporal_offset=temporal_offset, + reset_spatial_indices=True, + fps=action_fps, + base_fps=base_fps, + temporal_compression_factor=1, + base_temporal_compression_factor=base_temporal_compression_factor, + start_frame_offset=start_frame_offset, + ) + + +# --------------------------------------------------------------------------- +# Action conditioning layout (ported from vllm-omni ``action.py``). Each mode +# fixes which latent video frames and which action tokens are clean context vs +# noisy/predicted: +# * forward_dynamics -- action is the condition (all clean); video frame 0 is +# clean, the rest are predicted. +# * inverse_dynamics -- video is the condition (all latent frames clean); +# every action token is predicted. +# * policy -- video frame 0 is clean (the rest predicted) and every +# action token is predicted. +# --------------------------------------------------------------------------- + +ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics" +ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics" +ACTION_MODE_POLICY = "policy" +ACTION_MODES = (ACTION_MODE_FORWARD_DYNAMICS, ACTION_MODE_INVERSE_DYNAMICS, ACTION_MODE_POLICY) + + +def action_condition_frame_indexes(mode: str, action_length: int) -> list[int]: + """Clean (conditioning) action tokens for ``mode``.""" + if mode == ACTION_MODE_FORWARD_DYNAMICS: + return list(range(action_length)) + if mode in (ACTION_MODE_INVERSE_DYNAMICS, ACTION_MODE_POLICY): + return [] + raise ValueError(f"Unknown Cosmos3 action mode: {mode!r}") + + +def vision_condition_frame_indexes(mode: str, latent_frames: int) -> list[int]: + """Clean (conditioning) latent video frames for ``mode``.""" + if mode in (ACTION_MODE_FORWARD_DYNAMICS, ACTION_MODE_POLICY): + return [0] + if mode == ACTION_MODE_INVERSE_DYNAMICS: + return list(range(latent_frames)) + raise ValueError(f"Unknown Cosmos3 action mode: {mode!r}") + + +def action_start_frame_offset(action_length: int, video_length: int) -> int: + """mRoPE start-frame offset for the action band: action chunks of length + ``num_frames - 1`` start one frame in (aligned to predicted frames 1..N); + a full ``num_frames`` chunk starts at 0.""" + if action_length == video_length - 1: + return 1 + if action_length == video_length: + return 0 + raise ValueError( + "Cosmos3 action_chunk_size must equal num_frames - 1 or num_frames; " + f"got action_chunk_size={action_length}, num_frames={video_length}." + ) + + +# --------------------------------------------------------------------------- +# Prompt tokenization — ported from pipeline.tokenize_prompt. Image mode +# (num_frames == 1) and video mode differ only in the system prompt and the +# metadata sentences appended to the prompt (resolution always; duration for +# video). Both append the eos + start-of-generation special tokens. +# --------------------------------------------------------------------------- + +SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." +SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." +IMAGE_RESOLUTION_TEMPLATE = "This image is of {height}x{width} resolution." +INVERSE_IMAGE_RESOLUTION_TEMPLATE = "This image is not of {height}x{width} resolution." +VIDEO_RESOLUTION_TEMPLATE = "This video is of {height}x{width} resolution." +INVERSE_VIDEO_RESOLUTION_TEMPLATE = "This video is not of {height}x{width} resolution." +DURATION_TEMPLATE = "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS." +INVERSE_DURATION_TEMPLATE = "The video is not {duration:.1f} seconds long and is not of {fps:.0f} FPS." + + +def _append(base: str, addition: str) -> str: + base = base.rstrip(".") + return f"{base}. {addition}" if base else addition + + +def tokenize_prompt( + tokenizer, + prompt: str, + negative_prompt: str | None, + num_frames: int, + height: int, + width: int, + fps: float = 24.0, + use_system_prompt: bool = True, + add_resolution_template: bool = True, + add_duration_template: bool = True, +) -> tuple[list[int], list[int]]: + """Return ``(cond_input_ids, uncond_input_ids)`` for image/video generation. + + Mirrors the diffusers pipeline: apply the Qwen2 chat template with the + mode-specific system prompt and metadata sentences (duration for video, then + resolution), using inverse templates for the negative prompt, then append the + eos + start-of-generation (``<|vision_start|>``) special tokens. Image mode is + ``num_frames == 1``. + """ + is_image = num_frames == 1 + if negative_prompt is None: + negative_prompt = "" + eos_id = tokenizer.eos_token_id + sog_id = tokenizer.convert_tokens_to_ids("<|vision_start|>") + + resolution_template = IMAGE_RESOLUTION_TEMPLATE if is_image else VIDEO_RESOLUTION_TEMPLATE + inverse_resolution_template = ( + INVERSE_IMAGE_RESOLUTION_TEMPLATE if is_image else INVERSE_VIDEO_RESOLUTION_TEMPLATE + ) + + def _apply_templates(text: str, is_negative: bool) -> str: + if not is_image and add_duration_template: + tmpl = INVERSE_DURATION_TEMPLATE if is_negative else DURATION_TEMPLATE + text = _append(text, tmpl.format(duration=num_frames / fps, fps=fps)) + if add_resolution_template: + tmpl = inverse_resolution_template if is_negative else resolution_template + text = _append(text, tmpl.format(height=height, width=width)) + return text + + def _tokenize(text: str) -> list[int]: + conversations = [] + if use_system_prompt: + conversations.append( + {"role": "system", "content": SYSTEM_PROMPT_IMAGE if is_image else SYSTEM_PROMPT_VIDEO} + ) + conversations.append({"role": "user", "content": text}) + enc = tokenizer.apply_chat_template( + conversations, tokenize=True, add_generation_prompt=True, add_vision_id=False, return_dict=True + ) + return list(enc["input_ids"]) + [eos_id, sog_id] + + cond = _tokenize(_apply_templates(prompt, is_negative=False)) + uncond = _tokenize(_apply_templates(negative_prompt, is_negative=True)) + return cond, uncond + + +def tokenize_t2i_prompt( + tokenizer, + prompt: str, + negative_prompt: str | None, + height: int, + width: int, + use_system_prompt: bool = True, + add_resolution_template: bool = True, +) -> tuple[list[int], list[int]]: + """Image-mode convenience wrapper around :func:`tokenize_prompt`.""" + return tokenize_prompt( + tokenizer, + prompt, + negative_prompt, + num_frames=1, + height=height, + width=width, + use_system_prompt=use_system_prompt, + add_resolution_template=add_resolution_template, + ) + + +# --------------------------------------------------------------------------- +# Segment builders + full t2i static-input assembly. +# --------------------------------------------------------------------------- + + +def build_text_segment(input_ids: list[int], config, device) -> dict[str, Any]: + und_len = len(input_ids) + text_mrope_ids, next_off = get_3d_mrope_ids_text_tokens( + num_tokens=und_len, temporal_offset=0, use_float_positions=config.enable_fps_modulation + ) + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long, device=device), + "text_indexes": torch.arange(und_len, dtype=torch.long, device=device), + "und_len": und_len, + "text_mrope_ids": text_mrope_ids.to(device), + "vision_start_temporal_offset": next_off + config.unified_3d_mrope_temporal_modality_margin, + } + + +def build_vision_segment( + latent_shape: tuple[int, int, int, int, int], + has_image_condition: bool, + mrope_offset: int | float, + vision_fps: float | None, + curr: int, + config, + vae_scale_factor_temporal: int, + device, + noisy_frames: list[int] | None = None, +) -> dict[str, Any]: + """``latent_shape`` is the vision latent tensor shape ``[B, C, T, H, W]``. + + ``noisy_frames`` lists the latent frames that are noisy (predicted); the rest + are clean conditioning context. When ``None`` it defaults to frame 0 clean + if ``has_image_condition`` else all frames noisy — i.e. the t2i/t2v/i2v + layouts. Action modes pass an explicit list (e.g. ``[]`` for + inverse-dynamics, where the whole video is conditioning).""" + p = config.latent_patch_size + _, _, latent_t, latent_h, latent_w = latent_shape + patch_h = math.ceil(latent_h / p) + patch_w = math.ceil(latent_w / p) + num_vision_tokens = latent_t * patch_h * patch_w + + if noisy_frames is None: + noisy_start = 1 if has_image_condition else 0 + noisy_list = list(range(noisy_start, latent_t)) + else: + noisy_list = sorted(noisy_frames) + noisy_frame_indexes = torch.tensor(noisy_list, device=device, dtype=torch.long) + + frame_token_stride = patch_h * patch_w + mse_loss_indexes: list[int] = [] + for frame_idx in noisy_list: + frame_start = curr + frame_idx * frame_token_stride + mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) + + effective_fps = vision_fps if config.enable_fps_modulation else None + vision_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=latent_t, + grid_h=patch_h, + grid_w=patch_w, + temporal_offset=mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=vae_scale_factor_temporal, + ) + + return { + "vision_token_shapes": [(latent_t, patch_h, patch_w)], + "vision_sequence_indexes": torch.arange(curr, curr + num_vision_tokens, dtype=torch.long, device=device), + "vision_mse_loss_indexes": torch.tensor(mse_loss_indexes, dtype=torch.long, device=device), + "vision_noisy_frame_indexes": [noisy_frame_indexes], + "vision_mrope_ids": vision_mrope_ids.to(device), + "num_vision_tokens": num_vision_tokens, + "num_noisy_vision_tokens": len(noisy_list) * frame_token_stride, + } + + +def build_static_inputs( + input_ids: list[int], + latent_shape: tuple[int, int, int, int, int], + config, + vae_scale_factor_temporal: int, + fps: float, + device, + has_image_condition: bool = False, +) -> dict[str, Any]: + """Assemble the per-prompt static transformer inputs for image/video + generation. ``latent_shape`` is ``[B, C, T, H, W]`` (``T == 1`` for images; + ``T == 1 + (num_frames - 1) // temporal_factor`` for video). When + ``has_image_condition`` is set, latent frame 0 is a clean conditioning anchor + (image-to-video): it stays in the sequence but is excluded from the noisy / + predicted frames. Step-varying fields (``vision_tokens``, + ``vision_timesteps``) are spliced in per denoising step by the caller.""" + text = build_text_segment(input_ids, config, device) + vision = build_vision_segment( + latent_shape=latent_shape, + has_image_condition=has_image_condition, + mrope_offset=text["vision_start_temporal_offset"], + vision_fps=fps, + curr=text["und_len"], + config=config, + vae_scale_factor_temporal=vae_scale_factor_temporal, + device=device, + ) + position_ids = torch.cat([text["text_mrope_ids"], vision["vision_mrope_ids"]], dim=1) + return { + **text, + **vision, + "position_ids": position_ids, + "sequence_length": text["und_len"] + vision["num_vision_tokens"], + } + + +def build_t2i_static_inputs( + input_ids: list[int], + latent_shape: tuple[int, int, int, int, int], + config, + vae_scale_factor_temporal: int, + fps: float, + device, +) -> dict[str, Any]: + """Image-mode convenience wrapper around :func:`build_static_inputs`.""" + return build_static_inputs( + input_ids, latent_shape, config, vae_scale_factor_temporal, fps, device, + has_image_condition=False, + ) + + +def build_action_static_inputs( + input_ids: list[int], + video_latent_shape: tuple[int, int, int, int, int], + action_chunk_size: int, + mode: str, + config, + vae_scale_factor_temporal: int, + fps: float, + action_fps: float, + action_start_offset: int, + device, +) -> dict[str, Any]: + """Assemble the static transformer inputs for joint video+action generation. + + The generation sequence is ``[video tokens | action tokens]`` after the text + prefix. Both media bands share the post-text temporal offset (the 15000 + margin), with the action band on its own frame-rate grid. Conditioning per + ``mode`` decides which video frames and action tokens are clean context vs + noisy/predicted (see :func:`vision_condition_frame_indexes` / + :func:`action_condition_frame_indexes`).""" + text = build_text_segment(input_ids, config, device) + media_offset = text["vision_start_temporal_offset"] + _, _, latent_t, _, _ = video_latent_shape + + vision_clean = set(vision_condition_frame_indexes(mode, latent_t)) + vision_noisy = [f for f in range(latent_t) if f not in vision_clean] + vision = build_vision_segment( + latent_shape=video_latent_shape, + has_image_condition=False, + mrope_offset=media_offset, + vision_fps=fps, + curr=text["und_len"], + config=config, + vae_scale_factor_temporal=vae_scale_factor_temporal, + device=device, + noisy_frames=vision_noisy, + ) + + curr = text["und_len"] + vision["num_vision_tokens"] + action_clean = set(action_condition_frame_indexes(mode, action_chunk_size)) + action_noisy = [a for a in range(action_chunk_size) if a not in action_clean] + effective_action_fps = action_fps if config.enable_fps_modulation else None + action_mrope_ids, _ = get_3d_mrope_ids_action_tokens( + grid_t=action_chunk_size, + temporal_offset=media_offset, + action_fps=effective_action_fps, + base_fps=float(config.base_fps), + base_temporal_compression_factor=vae_scale_factor_temporal, + start_frame_offset=action_start_offset, + ) + + parts = [text["text_mrope_ids"], vision["vision_mrope_ids"], action_mrope_ids.to(device)] + pos_dtype = torch.float32 if any(p.is_floating_point() for p in parts) else torch.long + position_ids = torch.cat([p.to(pos_dtype) for p in parts], dim=1) + + return { + **text, + **vision, + "action_token_shapes": [(action_chunk_size, 1, 1)], + "action_sequence_indexes": torch.arange(curr, curr + action_chunk_size, dtype=torch.long, device=device), + "action_noisy_frame_indexes": [torch.tensor(action_noisy, dtype=torch.long, device=device)], + "action_mse_loss_indexes": torch.tensor( + [curr + a for a in action_noisy], dtype=torch.long, device=device + ), + "action_mrope_ids": action_mrope_ids.to(device), + "num_action_tokens": action_chunk_size, + "num_noisy_action_tokens": len(action_noisy), + "action_mode": mode, + "position_ids": position_ids, + "sequence_length": curr + action_chunk_size, + } diff --git a/mstar/model/cosmos3/pipeline.py b/mstar/model/cosmos3/pipeline.py new file mode 100644 index 00000000..4f869b45 --- /dev/null +++ b/mstar/model/cosmos3/pipeline.py @@ -0,0 +1,362 @@ +"""Fused generation pipeline for Cosmos3-Nano (text/image-to-image/video). + +Runs the generator in one fused forward per denoising step (text + vision +together), using mstar's DiT forward + packing and the imported diffusers UniPC +scheduler + Wan VAE. Intentionally simple (batch 1, sequential CFG); not the +served path. Produces the same image/video as the diffusers +``Cosmos3OmniPipeline`` on a fixed seed/prompt. + +``num_frames == 1`` is text-to-image; ``num_frames > 1`` is text-to-video, and +passing ``image`` anchors frame 0 to a conditioning frame (image-to-video). +""" + +from __future__ import annotations + +import torch + +from mstar.model.cosmos3.packing import ( + action_start_frame_offset, + build_action_static_inputs, + build_static_inputs, + tokenize_prompt, + vision_condition_frame_indexes, +) + +# Transformer.forward static kwargs produced by build_static_inputs. +_TF_STATIC_FIELDS = ( + "input_ids", + "text_indexes", + "position_ids", + "und_len", + "sequence_length", + "vision_token_shapes", + "vision_sequence_indexes", + "vision_mse_loss_indexes", + "vision_noisy_frame_indexes", +) + +# Additional Transformer.forward static kwargs for joint video+action generation. +_TF_ACTION_STATIC_FIELDS = ( + "action_token_shapes", + "action_sequence_indexes", + "action_mse_loss_indexes", + "action_noisy_frame_indexes", +) + + +class Cosmos3Pipeline: + """Fused t2i / t2v / i2v pipeline for Cosmos3-Nano.""" + + def __init__(self, transformer, vae, scheduler, tokenizer, config, device, dtype=torch.bfloat16): + self.transformer = transformer + self.vae = vae + self.scheduler = scheduler + self.tokenizer = tokenizer + self.config = config + self.device = device + self.dtype = dtype + + self.vae_scale_spatial = int(vae.config.scale_factor_spatial) + self.vae_scale_temporal = int(vae.config.scale_factor_temporal) + self._latents_mean = torch.tensor(vae.config.latents_mean, dtype=vae.dtype, device=device) + self._latents_inv_std = 1.0 / torch.tensor(vae.config.latents_std, dtype=vae.dtype, device=device) + + # Conditioning-frame preprocessor (PIL / numpy / tensor -> [1,3,H,W] in + # [-1,1], resized) — the same one the diffusers pipeline uses, for parity. + from diffusers.video_processor import VideoProcessor + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_spatial, resample="bilinear") + + @classmethod + def from_model(cls, model, device, dtype=torch.bfloat16): + """Build from a loaded ``Cosmos3Model`` (DiT + Wan VAE) + imported UniPC.""" + from diffusers import UniPCMultistepScheduler + + transformer = model.get_submodule("dit", device=device).transformer + vae = model._build_vae(device) + scheduler = UniPCMultistepScheduler.from_pretrained(str(model._ensure_repo() / "scheduler")) + return cls(transformer, vae, scheduler, model.tokenizer, model.config, device, dtype) + + def _encode_video(self, x: torch.Tensor) -> torch.Tensor: + """[1,3,T,H,W] in [-1,1] -> normalized latents [1,C,T_lat,H/16,W/16]. + + Takes the distribution mode (``sample_mode="argmax"``) and applies the + pipeline-side latent normalization, matching the diffusers oracle. + """ + in_dtype = x.dtype + dtype = self.vae.dtype + mean = self._latents_mean.to(device=x.device, dtype=dtype).view(1, -1, 1, 1, 1) + inv_std = self._latents_inv_std.to(device=x.device, dtype=dtype).view(1, -1, 1, 1, 1) + raw_mu = self.vae.encode(x.to(dtype)).latent_dist.mode() + return ((raw_mu - mean) * inv_std).to(in_dtype) + + def _decode(self, latents: torch.Tensor) -> torch.Tensor: + """Latents [1,C,T,H,W] -> pixels [1,3,T,H,W] in [0,1] (un-normalize + Wan VAE).""" + mean = self._latents_mean.view(1, -1, 1, 1, 1) + inv_std = self._latents_inv_std.view(1, -1, 1, 1, 1) + z = latents.to(self.vae.dtype) / inv_std + mean + decoded = self.vae.decode(z).sample # [1,3,T,H,W] in [-1,1] + return (decoded / 2 + 0.5).clamp(0, 1).to(torch.float32) + + def _prepare_latents(self, image, num_frames, height, width, generator, latents, device, dtype): + """Build the initial vision latents + whether frame 0 is a clean anchor. + + For image-to-video the conditioning frame anchors latent frame 0 (clean, + VAE-encoded) and the remaining frames start from pure noise; otherwise the + whole tensor is noise. Mirrors the diffusers ``prepare_latents`` vision path. + """ + from diffusers.utils.torch_utils import randn_tensor + + is_image = num_frames == 1 + has_image_condition = image is not None and not is_image + + conditioning_frame_2d = None + if image is not None: + conditioning_frame_2d = self.video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=dtype + ) + + if is_image: + vision_tensor = ( + conditioning_frame_2d.unsqueeze(2) + if conditioning_frame_2d is not None + else torch.zeros(1, 3, 1, height, width, dtype=dtype, device=device) + ) + else: + vision_tensor = torch.zeros(1, 3, num_frames, height, width, dtype=dtype, device=device) + if conditioning_frame_2d is not None: + vision_tensor[:, :, 0] = conditioning_frame_2d + if num_frames > 1: + vision_tensor[:, :, 1:] = conditioning_frame_2d.unsqueeze(2).expand( + -1, -1, num_frames - 1, -1, -1 + ) + + x0 = self._encode_video(vision_tensor).contiguous().float() + vision_shape = tuple(x0.shape) + + vision_condition_mask = torch.zeros((x0.shape[2], 1, 1), device=device, dtype=dtype) + if has_image_condition: + vision_condition_mask[0, 0, 0] = 1.0 + + if latents is None: + pure_noise = randn_tensor(vision_shape, generator=generator, device=device, dtype=dtype) + latents = ( + vision_condition_mask * x0.to(device=device, dtype=dtype) + + (1.0 - vision_condition_mask) * pure_noise + ) + else: + latents = latents.to(device=device, dtype=dtype) + return latents, has_image_condition + + @torch.no_grad() + def __call__( + self, + prompt: str, + negative_prompt: str = "", + image=None, + num_frames: int = 1, + height: int = 256, + width: int = 256, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + fps: float = 24.0, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + decode: bool = True, + ): + device, dtype = self.device, self.dtype + cond_ids, uncond_ids = tokenize_prompt( + self.tokenizer, prompt, negative_prompt, num_frames=num_frames, height=height, width=width, fps=fps + ) + + latents, has_image_condition = self._prepare_latents( + image, num_frames, height, width, generator, latents, device, dtype + ) + latent_shape = tuple(latents.shape) + + cond = build_static_inputs( + cond_ids, latent_shape, self.config, self.vae_scale_temporal, fps, device, + has_image_condition=has_image_condition, + ) + uncond = build_static_inputs( + uncond_ids, latent_shape, self.config, self.vae_scale_temporal, fps, device, + has_image_condition=has_image_condition, + ) + cond_static = {k: cond[k] for k in _TF_STATIC_FIELDS} + uncond_static = {k: uncond[k] for k in _TF_STATIC_FIELDS} + num_noisy = cond["num_noisy_vision_tokens"] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.scheduler.timesteps: + vision_tokens = [latents.to(dtype)] + vision_timesteps = torch.full((num_noisy,), t.item(), device=device) + cond_v = self.transformer( + vision_tokens=vision_tokens, vision_timesteps=vision_timesteps, **cond_static + )[0][0] + if guidance_scale != 1.0: + uncond_v = self.transformer( + vision_tokens=vision_tokens, vision_timesteps=vision_timesteps, **uncond_static + )[0][0] + velocity = uncond_v + guidance_scale * (cond_v - uncond_v) + else: + velocity = cond_v + latents = self.scheduler.step( + velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + + if not decode: + return latents + return self._decode(latents) + + @torch.no_grad() + def generate_action( + self, + *, + prompt: str, + mode: str, + domain_id: int, + action_chunk_size: int, + raw_action_dim: int, + video: torch.Tensor | None = None, + video_latents: torch.Tensor | None = None, + action: torch.Tensor | None = None, + num_frames: int | None = None, + height: int = 256, + width: int = 256, + fps: float = 24.0, + action_fps: float | None = None, + num_inference_steps: int = 30, + guidance_scale: float = 1.0, + flow_shift: float | None = None, + negative_prompt: str = "", + generator: torch.Generator | None = None, + cond_ids: list[int] | None = None, + uncond_ids: list[int] | None = None, + return_video: bool = False, + ): + """Joint video+action generation (forward_dynamics / inverse_dynamics / policy). + + The conditioning video is VAE-encoded to clean anchor frames per ``mode`` + (all frames for inverse-dynamics; frame 0 for forward-dynamics / policy). + Action tokens are clean conditioning for forward-dynamics, else noisy and + predicted. Returns the predicted action ``[1, action_chunk_size, + raw_action_dim]`` (and the decoded video when ``return_video``). + """ + from diffusers import UniPCMultistepScheduler + from diffusers.utils.torch_utils import randn_tensor + + device, dtype = self.device, self.dtype + action_dim = self.transformer.action_dim + if num_frames is None: + num_frames = action_chunk_size + 1 + if action_fps is None: + action_fps = fps + action_offset = action_start_frame_offset(action_chunk_size, num_frames) + + if flow_shift is not None: + scheduler = UniPCMultistepScheduler.from_config(self.scheduler.config, flow_shift=flow_shift) + else: + scheduler = UniPCMultistepScheduler.from_config(self.scheduler.config) + scheduler.set_timesteps(num_inference_steps, device=device) + + if cond_ids is None or uncond_ids is None: + cond_ids, uncond_ids = tokenize_prompt( + self.tokenizer, prompt, negative_prompt, num_frames=num_frames, + height=height, width=width, fps=fps, + ) + + # --- action latents (noise drawn before the video noise, matching the + # reference ordering so a shared seed reproduces the same sample). --- + if mode == "forward_dynamics": + if action is None: + raise ValueError("Cosmos3 forward_dynamics requires `action`.") + act = action.to(device=device, dtype=torch.float32) + if act.ndim == 3: + act = act.squeeze(0) + if act.shape[0] < action_chunk_size: + act = torch.cat([act, act[-1:].repeat(action_chunk_size - act.shape[0], 1)], dim=0) + elif act.shape[0] > action_chunk_size: + act = act[:action_chunk_size] + clean_action = torch.zeros((action_chunk_size, action_dim), dtype=torch.float32) + clean_action[:, :raw_action_dim] = act[:, :raw_action_dim] + clean_action = clean_action.to(device=device, dtype=dtype).unsqueeze(0) + action_clean_mask = torch.ones((1, action_chunk_size, 1), device=device, dtype=dtype) + else: + clean_action = torch.zeros((1, action_chunk_size, action_dim), device=device, dtype=dtype) + action_clean_mask = torch.zeros((1, action_chunk_size, 1), device=device, dtype=dtype) + a_noise = randn_tensor((1, action_chunk_size, action_dim), generator=generator, device=device, dtype=dtype) + a_noise[..., raw_action_dim:] = 0 + clean_action[..., raw_action_dim:] = 0 + action_latents = action_clean_mask * clean_action + (1.0 - action_clean_mask) * a_noise + action_velocity_mask = 1.0 - action_clean_mask + + # --- conditioning video latents (clean per mode) --- + if video_latents is None: + if video is None: + raise ValueError("Cosmos3 action generation requires `video` or `video_latents`.") + video_latents = self._encode_video(video.to(device=device, dtype=dtype)) + cond_latent = video_latents.to(device=device, dtype=dtype) + latent_shape = tuple(cond_latent.shape) + t_lat = latent_shape[2] + + vis_clean = set(vision_condition_frame_indexes(mode, t_lat)) + vmask = torch.zeros((1, 1, t_lat, 1, 1), device=device, dtype=dtype) + for f in vis_clean: + vmask[:, :, f] = 1.0 + v_noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + latents = vmask * cond_latent + (1.0 - vmask) * v_noise + velocity_mask = 1.0 - vmask # 1 where the video is predicted + + # --- static packing --- + cond = build_action_static_inputs( + cond_ids, latent_shape, action_chunk_size, mode, self.config, + self.vae_scale_temporal, fps, action_fps, action_offset, device, + ) + do_cfg = guidance_scale != 1.0 + keys = _TF_STATIC_FIELDS + _TF_ACTION_STATIC_FIELDS + cond_static = {k: cond[k] for k in keys} + uncond_static = None + if do_cfg: + uncond = build_action_static_inputs( + uncond_ids, latent_shape, action_chunk_size, mode, self.config, + self.vae_scale_temporal, fps, action_fps, action_offset, device, + ) + uncond_static = {k: uncond[k] for k in keys} + num_noisy_v = cond["num_noisy_vision_tokens"] + num_noisy_a = cond["num_noisy_action_tokens"] + domain_t = torch.tensor([domain_id], dtype=torch.long, device=device) + + for t in scheduler.timesteps: + vts = torch.full((num_noisy_v,), t.item(), device=device) + ats = torch.full((num_noisy_a,), t.item(), device=device) + step_kwargs = dict( + vision_tokens=[latents.to(dtype)], vision_timesteps=vts, + action_tokens=action_latents.to(dtype), action_timesteps=ats, action_domain_id=domain_t, + ) + v_c, a_c, _ = self.transformer(**cond_static, **step_kwargs) + if do_cfg: + v_u, a_u, _ = self.transformer(**uncond_static, **step_kwargs) + video_v = v_u[0] + guidance_scale * (v_c[0] - v_u[0]) + action_v = a_u + guidance_scale * (a_c - a_u) + else: + video_v, action_v = v_c[0], a_c + + video_v = video_v * velocity_mask + action_v = action_v * action_velocity_mask + action_v[..., raw_action_dim:] = 0 + + nv = video_v.numel() + packed = torch.cat([video_v.reshape(1, -1), action_v.reshape(1, -1)], dim=1) + packed_lat = torch.cat([latents.reshape(1, -1), action_latents.reshape(1, -1)], dim=1) + packed_next = scheduler.step(packed, t, packed_lat, return_dict=False)[0] + latents = packed_next[:, :nv].reshape(latent_shape) + action_latents = packed_next[:, nv:].reshape(1, action_chunk_size, action_dim) + + latents = velocity_mask * latents + (1.0 - velocity_mask) * cond_latent + action_latents = action_velocity_mask * action_latents + (1.0 - action_velocity_mask) * clean_action + action_latents[..., raw_action_dim:] = 0 + + action_out = action_latents[:, :, :raw_action_dim] + if return_video: + return action_out, self._decode(latents) + return action_out diff --git a/mstar/model/cosmos3/submodules.py b/mstar/model/cosmos3/submodules.py new file mode 100644 index 00000000..bbda60ec --- /dev/null +++ b/mstar/model/cosmos3/submodules.py @@ -0,0 +1,1269 @@ +"""NodeSubmodule wrappers for the Cosmos3 generator nodes. + +Two nodes: + Cosmos3DiTSubmodule -- dual-pathway DiT (KV_CACHE). Dispatches by + graph_walk between ``prefill`` (the + understanding tower runs once over the text + prompt and writes its per-layer K/V) and + ``image_gen`` (one denoising step of the + generation tower per loop iteration, attending + to the frozen understanding K/V plus the + current generation tokens, then one scheduler + step). Classifier-free guidance keeps the + conditional and unconditional prompts in two + cache labels and combines their velocities. + Cosmos3VAEDecoderSubmodule -- Wan VAE decode (STATELESS): final latents to + pixels. + +Because the text tokens never receive a timestep embedding, the understanding +K/V is denoise-step independent, so writing it once and re-reading it every step +matches running the whole transformer each step. +""" + +from __future__ import annotations + +import logging +import os + +import torch + +from mstar.conductor.request_info import CurrentForwardPassInfo +from mstar.engine.cuda_graph_config import BasicBatchedCudaGraphConfig +from mstar.model.cosmos3.packing import ( + action_start_frame_offset, + build_action_static_inputs, + build_static_inputs, + vision_condition_frame_indexes, +) +from mstar.model.submodule_base import ( + ARNodeInputs, + ARNodeSubmodule, + ModelInputsFromEngine, + NodeInputs, + NodeSubmodule, +) + +logger = logging.getLogger(__name__) + +PREFILL_WALK = "prefill" +# Image/video-conditioned generation prefills the same understanding tower, but +# also VAE-encodes the conditioning frame into a clean anchor latent (see +# Cosmos3DiTSubmodule._encode_conditioning). It is a separate walk from the +# text-only prefill because the graph node only fires once all of its declared +# inputs arrive, so the conditioning image has to be one of them. +PREFILL_COND_WALK = "prefill_cond" +# Action inverse-dynamics conditions on a full video rather than a single frame, +# so it gets its own conditioned prefill that takes the video among its inputs. +PREFILL_COND_VIDEO_WALK = "prefill_cond_video" +IMAGE_GEN_WALK = "image_gen" +VIDEO_GEN_WALK = "video_gen" +ACTION_GEN_WALK = "action_gen" +# Forward-dynamics runs the same joint video+action denoise but emits the +# predicted video (VAE-decoded) instead of the action, so it has its own walk. +ACTION_VIDEO_GEN_WALK = "action_video_gen" + +# image_gen and video_gen run the identical denoise step (the DiT loop is +# shape-general over the frame count); they differ only in the emitted output +# modality (a single image frame vs an encoded video), which the graph fixes per +# walk, so the submodule treats them the same. +GEN_WALKS = (IMAGE_GEN_WALK, VIDEO_GEN_WALK) + +# All prefill variants run the same understanding-tower prefill; the conditioned +# ones additionally VAE-encode an image (prefill_cond) or video +# (prefill_cond_video) into anchor latents. +PREFILL_WALKS = (PREFILL_WALK, PREFILL_COND_WALK, PREFILL_COND_VIDEO_WALK) + +# Names of the denoise loops in the graph walks. The loops are built with a fixed +# upper-bound iteration count and each request stops its loop early at its own +# denoise-step count (see ``check_stop``), so one graph serves any per-request +# step count. +IMAGE_GEN_LOOP = "image_gen_loop" +VIDEO_GEN_LOOP = "video_gen_loop" +ACTION_GEN_LOOP = "action_gen_loop" +ACTION_VIDEO_GEN_LOOP = "action_video_gen_loop" + +# Both action walks run the joint video+action denoise loop body; they differ +# only in what they emit (the predicted action vs the predicted video). +ACTION_WALKS = (ACTION_GEN_WALK, ACTION_VIDEO_GEN_WALK) + +# Conditional prompt K/V lives under the primary label; the unconditional +# (negative) prompt's K/V lives under a second label for classifier-free +# guidance. Both are written once at prefill and read every denoise step. +COND_LABEL = "main" +UNCOND_LABEL = "uncond" + +# Combined label for the single FlashInfer plan that runs both guidance branches +# in one forward (see cache_manager.plan_attention_batched_cfg). +CFG_BATCHED_LABEL = "_cfg_batched" + + +class Cosmos3DiTSubmodule(ARNodeSubmodule): + """Dual-pathway DiT node (understanding tower + generation denoiser).""" + + # The denoise loop is data-dependent (per-step timestep .item(), scheduler + # step, classifier-free guidance combine), so torch.compile graph-breaks and + # buys little; CUDA-graph capture of the fixed-shape step is the accelerator. + disable_torch_compile = True + + # Run the two classifier-free-guidance branches as a single batched forward + # per denoise step instead of two sequential forwards. The math is the same; + # set False to fall back to the sequential path. + batched_cfg = True + + # Cap on how many requests share one batched denoise step. Concurrent + # requests at the image-generation walk run their step in a single forward. + max_gen_batch_size = 8 + + # Image resolutions (height, width) to capture a denoise-step CUDA graph for. + # Requests at other resolutions fall back to the eager path. num_frames is + # fixed at 1 (text-to-image). The graph accelerates the single-request + # (batch size 1) denoise step, where the forward is launch-bound: the win is + # large at low resolution (~2.5x at 320x192) and shrinks as the step becomes + # compute-bound at higher resolution. Concurrent requests batch via the eager + # path regardless. The default covers the three standard generation tiers; + # override with COSMOS3_GEN_CAPTURE_RES. The served graph output is identical + # to the eager path (compare with COSMOS3_DISABLE_CUDA_GRAPH=1). + gen_capture_resolutions: tuple[tuple[int, int], ...] = ( + (192, 320), (480, 832), (720, 1280), + ) + # Batch sizes to capture per resolution. + gen_capture_batch_sizes: tuple[int, ...] = (1,) + + def __init__(self, transformer, config, scheduler=None, vae=None): + super().__init__() + self.transformer = transformer + self.config = config + # Template scheduler; a fresh instance (with its own multistep state) is + # built per request from this one's config. + self._scheduler_template = scheduler + # Wan VAE (shared with the decoder node) — used to encode the + # conditioning frame for image-to-video / action conditioning. None for + # text-only generation. + self.vae = vae + self._video_processor = None + # Per-request denoising state: packed static inputs (cond/uncond), + # scheduler, guidance scale, latent shape. + self._req: dict[str, dict] = {} + # torch.compile the pure denoise compute (the generation-layer stack + + # norms + projections). fullgraph=False leaves the FlashInfer attention an + # opaque graph break, so compile fuses the bandwidth-bound pointwise ops + # around it; the compiled kernels then bake into the per-resolution image + # CUDA graphs (capture's warmup forwards trace them before the graph + # records). disable_torch_compile stays True so the engine does not also + # compile the data-dependent submodule wrapper. On by default — frees + # ~1.2-1.3x per denoise step at the generation tiers with no change in + # image/golden quality vs the fused reference (the first request at each + # uncaptured shape pays a one-time trace). Set + # COSMOS3_DISABLE_COMPILE_DENOISE=1 for the eager step (A/B / debugging). + if not os.environ.get("COSMOS3_DISABLE_COMPILE_DENOISE"): + self.transformer.denoise_step = torch.compile( + self.transformer.denoise_step, fullgraph=False, dynamic=False, + ) + self.transformer.denoise_step_batched_cfg = torch.compile( + self.transformer.denoise_step_batched_cfg, fullgraph=False, dynamic=False, + ) + logger.info("Cosmos3 denoise compute torch.compile enabled") + + def to(self, *args, **kwargs): + # The engine casts this submodule to bf16 (worker.engine_manager), which + # also casts the timestep embedder. Diffusers keeps that module in fp32 + # (_keep_in_fp32_modules) and the reference pipeline computes the timestep + # embedding in fp32; the multi-step video denoise is sensitive to its + # precision (running it in bf16 perturbs the velocity enough to scramble + # the latents). Re-assert fp32 after any cast — paired with the + # autocast-disabled forward below so it actually runs in fp32. The upcast + # is lossless (the checkpoint weights are bf16). + super().to(*args, **kwargs) + te = getattr(self.transformer, "time_embedder", None) + if te is not None: + te.float() + return self + + def get_needed_cache_labels( + self, graph_walk: str, per_request_info: dict[str, CurrentForwardPassInfo], + ) -> list[str] | None: + return [COND_LABEL, UNCOND_LABEL] + + # ------------------------------------------------------------------ + # Static packing + scheduler helpers + # ------------------------------------------------------------------ + + def _latent_shape( + self, height: int, width: int, num_frames: int = 1 + ) -> tuple[int, int, int, int, int]: + s = self.config.vae.scale_factor_spatial + t = 1 if num_frames == 1 else 1 + (num_frames - 1) // self.config.vae.scale_factor_temporal + return (1, self.config.latent_channel, t, height // s, width // s) + + def _build_static( + self, ids: list[int], height: int, width: int, num_frames: int, + fps: float, has_image_condition: bool, device, + ) -> dict: + static = build_static_inputs( + list(ids), self._latent_shape(height, width, num_frames), self.config, + self.config.vae.scale_factor_temporal, fps, device, + has_image_condition=has_image_condition, + ) + # proj_out runs on the generation token block, so shift the joint-sequence + # mse indexes to be relative to the vision tokens. + static["mse_gen_indexes"] = static["vision_mse_loss_indexes"] - static["und_len"] + return static + + def _new_scheduler(self, num_inference_steps: int, device, flow_shift=None): + from diffusers import UniPCMultistepScheduler + + if flow_shift is not None: + scheduler = UniPCMultistepScheduler.from_config(self._scheduler_template.config, flow_shift=flow_shift) + else: + scheduler = UniPCMultistepScheduler.from_config(self._scheduler_template.config) + scheduler.set_timesteps(num_inference_steps, device=device) + return scheduler + + def _build_action_static( + self, ids: list[int], height: int, width: int, num_frames: int, action_chunk: int, + mode: str, fps: float, action_fps: float, action_offset: int, device, + ) -> dict: + static = build_action_static_inputs( + list(ids), self._latent_shape(height, width, num_frames), action_chunk, mode, + self.config, self.config.vae.scale_factor_temporal, fps, action_fps, action_offset, device, + ) + # proj_out runs on the generation token block; shift the joint-sequence + # mse indexes to be relative to the [vision | action] generation tokens. + static["mse_gen_indexes"] = static["vision_mse_loss_indexes"] - static["und_len"] + static["action_mse_gen_indexes"] = static["action_mse_loss_indexes"] - static["und_len"] + return static + + # ------------------------------------------------------------------ + # prepare_inputs + # ------------------------------------------------------------------ + + def prepare_inputs( + self, graph_walk, fwd_info, inputs, seen_token_mask=None, pos_info={}, + ) -> ARNodeInputs: + device = self.get_device() + if graph_walk in PREFILL_WALKS: + return self._prepare_prefill(fwd_info, inputs, device) + if graph_walk in GEN_WALKS: + return self._prepare_image_gen(fwd_info, inputs, device) + if graph_walk in ACTION_WALKS: + return self._prepare_action_gen(fwd_info, inputs, device) + raise ValueError(f"Unknown Cosmos3 DiT graph walk: {graph_walk!r}") + + def _prepare_prefill(self, fwd_info, inputs, device) -> ARNodeInputs: + md = fwd_info.step_metadata + height, width = int(md.get("height", 256)), int(md.get("width", 256)) + fps = float(md.get("fps", 24.0)) + gs = float(md.get("guidance_scale", 6.0)) + steps = int(md.get("num_inference_steps", self.config.num_inference_steps)) + cond_ids = inputs["text_inputs"][0].tolist() + uncond_ids = inputs["text_inputs"][1].tolist() if gs != 1.0 else None + + action_mode = md.get("action_mode") + if action_mode: + return self._prepare_action_prefill( + fwd_info, md, inputs, cond_ids, uncond_ids, height, width, fps, gs, steps, device + ) + + num_frames = int(md.get("num_frames", 1)) + # Image-to-video: latent frame 0 is a clean conditioning anchor supplied + # in the first denoise step's ``latents``; it stays in the sequence but is + # not denoised. (Text-to-image / text-to-video have no clean anchor.) + has_image_condition = bool(md.get("has_image_condition", False)) + + cond = self._build_static(cond_ids, height, width, num_frames, fps, has_image_condition, device) + uncond = None + if uncond_ids is not None: + uncond = self._build_static(uncond_ids, height, width, num_frames, fps, has_image_condition, device) + + self._req[fwd_info.request_id] = { + "cond": cond, + "uncond": uncond, + "gs": gs, + "guidance_interval": md.get("guidance_interval"), + "scheduler": self._new_scheduler(steps, device, flow_shift=md.get("flow_shift")), + "num_noisy": cond["num_noisy_vision_tokens"], + "num_vision": cond["num_vision_tokens"], + "latent_shape": self._latent_shape(height, width, num_frames), + } + # Image-to-video: encode the conditioning frame now (the understanding + # tower and the VAE encode are both prefill-time, per-request work) and + # stash its clean anchor latents for the denoise loop to inject. + if has_image_condition: + image = (inputs or {}).get("image_inputs") + if image: + self._req[fwd_info.request_id]["cond_latents"] = self._encode_conditioning( + image[0], height, width, num_frames, device, anchor_only=True + ) + return ARNodeInputs(input_seq_len=cond["und_len"]) + + def _encode_conditioning(self, image, height, width, num_frames, device, anchor_only=False): + """VAE-encode a conditioning frame into clean anchor latents. + + Mirrors the fused pipeline's image-to-video latent prep: the frame is + resized and normalized to [-1, 1], repeat-padded across the clip, and + Wan-VAE encoded with the pipeline-side latent normalization. Latent + frame 0 is the clean anchor the denoise loop keeps fixed. + + Image-to-video only consumes latent frame 0, and the Wan VAE encodes + frame 0 as a standalone causal anchor, so ``anchor_only`` skips the + repeat-pad and encodes the single frame (a bit-identical frame 0) + instead of the whole clip — at video lengths the full encode is the + bulk of the conditioning cost. The encode runs in fp32 outside autocast: + the VAE's 3D convs are far faster in fp32 (TF32) than bf16 on this cuDNN + and the reference pipeline encodes in fp32 (matching the decoder).""" + from diffusers.video_processor import VideoProcessor + + vae = self.vae + if next(vae.parameters()).dtype != torch.float32: + vae.float() + dtype = self.transformer.proj_in.weight.dtype + if self._video_processor is None: + self._video_processor = VideoProcessor( + vae_scale_factor=self.config.vae.scale_factor_spatial, resample="bilinear" + ) + # load_image gives [C, H, W] in [0, 1]; preprocess -> [1, 3, H, W] in [-1, 1]. + frame = self._video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=torch.float32 + ) + vision = frame.unsqueeze(2) + if num_frames > 1 and not anchor_only: + vision = vision.expand(-1, -1, num_frames, -1, -1) + mean = torch.tensor(vae.config.latents_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1, 1) + inv_std = (1.0 / torch.tensor(vae.config.latents_std, dtype=torch.float32, device=device)).view( + 1, -1, 1, 1, 1 + ) + with torch.autocast(device_type=vision.device.type, enabled=False): + raw_mu = vae.encode(vision).latent_dist.mode() + return ((raw_mu - mean) * inv_std).to(dtype) + + def _prepare_action_prefill( + self, fwd_info, md, inputs, cond_ids, uncond_ids, height, width, fps, gs, steps, device, + ) -> ARNodeInputs: + mode = md["action_mode"] + action_chunk = int(md["action_chunk_size"]) + num_frames = int(md.get("num_frames") or action_chunk + 1) + raw_action_dim = int(md["raw_action_dim"]) + domain_id = int(md.get("domain_id", 0)) + action_fps = float(md.get("action_fps", fps)) + action_offset = action_start_frame_offset(action_chunk, num_frames) + + cond = self._build_action_static( + cond_ids, height, width, num_frames, action_chunk, mode, fps, action_fps, action_offset, device + ) + uncond = None + if uncond_ids is not None: + uncond = self._build_action_static( + uncond_ids, height, width, num_frames, action_chunk, mode, fps, action_fps, action_offset, device + ) + + latent_shape = self._latent_shape(height, width, num_frames) + t_lat = latent_shape[2] + dtype = self.transformer.proj_in.weight.dtype + action_dim = self.transformer.action_dim + vmask = torch.zeros((1, 1, t_lat, 1, 1), device=device, dtype=dtype) + for f in vision_condition_frame_indexes(mode, t_lat): + vmask[:, :, f] = 1.0 + action_clean = torch.zeros((1, action_chunk, 1), device=device, dtype=dtype) + if mode == "forward_dynamics": + action_clean[:] = 1.0 + + # Encode the visual conditioning to clean anchor latents: inverse-dynamics + # conditions on the whole video (all frames), forward-dynamics / policy on + # a single frame (frame 0). The per-mode vmask above selects which latent + # frames are kept clean from these. + cond_video = (inputs or {}).get("video_inputs") + cond_image = (inputs or {}).get("image_inputs") + if cond_video: + cond_latents = self._encode_conditioning_video(cond_video[0], height, width, num_frames, device) + elif cond_image: + cond_latents = self._encode_conditioning(cond_image[0], height, width, num_frames, device) + else: + cond_latents = torch.zeros(latent_shape, device=device, dtype=dtype) + + # Forward-dynamics conditions on a clean action chunk supplied with the + # request; the other modes predict the action (clean values are zero). + clean_action = torch.zeros((1, action_chunk, action_dim), device=device, dtype=dtype) + raw_act = md.get("action") + if mode == "forward_dynamics" and raw_act is not None: + act = torch.as_tensor(raw_act, device=device, dtype=dtype) + if act.ndim == 3: + act = act[0] + if act.shape[0] < action_chunk: + act = torch.cat([act, act[-1:].repeat(action_chunk - act.shape[0], 1)], dim=0) + elif act.shape[0] > action_chunk: + act = act[:action_chunk] + clean_action[:, :, :raw_action_dim] = act[:, :raw_action_dim] + + self._req[fwd_info.request_id] = { + "cond": cond, + "uncond": uncond, + "gs": gs, + "scheduler": self._new_scheduler(steps, device, flow_shift=md.get("flow_shift")), + "num_noisy": cond["num_noisy_vision_tokens"], + "num_noisy_action": cond["num_noisy_action_tokens"], + "num_vision": cond["num_vision_tokens"], + "num_action": cond["num_action_tokens"], + "latent_shape": latent_shape, + "action_mode": mode, + "action_chunk": action_chunk, + "action_dim": action_dim, + "raw_action_dim": raw_action_dim, + "domain_t": torch.tensor([domain_id], dtype=torch.long, device=device), + "vmask": vmask, + "velocity_mask": 1.0 - vmask, + "action_clean_mask": action_clean, + "action_velocity_mask": 1.0 - action_clean, + "cond_video_latents": cond_latents, + "clean_action": clean_action, + } + return ARNodeInputs(input_seq_len=cond["und_len"]) + + def _encode_conditioning_video(self, video, height, width, num_frames, device): + """VAE-encode a conditioning video clip into clean anchor latents. + + Used by action inverse-dynamics, which conditions on the whole observed + clip. load_video gives [T, C, H, W] in [0, 1]; each frame is resized and + normalized to [-1, 1] (matching the fused pipeline) and the clip is + Wan-VAE encoded with the pipeline-side latent normalization.""" + from diffusers.video_processor import VideoProcessor + + vae = self.vae + if next(vae.parameters()).dtype != torch.float32: + vae.float() + dtype = self.transformer.proj_in.weight.dtype + if self._video_processor is None: + self._video_processor = VideoProcessor( + vae_scale_factor=self.config.vae.scale_factor_spatial, resample="bilinear" + ) + clip = video[:num_frames] + frames = [ + self._video_processor.preprocess(clip[i], height=height, width=width).squeeze(0) + for i in range(clip.shape[0]) + ] + # fp32 outside autocast: the VAE 3D convs are much faster in fp32 (TF32) + # than bf16 on this cuDNN, and the reference pipeline encodes in fp32. + vision = torch.stack(frames, dim=1).unsqueeze(0).to(device=device, dtype=torch.float32) # [1,3,T,H,W] + mean = torch.tensor(vae.config.latents_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1, 1) + inv_std = (1.0 / torch.tensor(vae.config.latents_std, dtype=torch.float32, device=device)).view( + 1, -1, 1, 1, 1 + ) + with torch.autocast(device_type=vision.device.type, enabled=False): + raw_mu = vae.encode(vision).latent_dist.mode() + return ((raw_mu - mean) * inv_std).to(dtype) + + def _prepare_image_gen(self, fwd_info, inputs, device) -> ARNodeInputs: + st = self._req[fwd_info.request_id] + if "latents" not in inputs or len(inputs["latents"]) == 0: + gen = torch.Generator(device=device).manual_seed(fwd_info.random_seed) + latents = torch.randn( + st["latent_shape"], generator=gen, device=device, dtype=self.transformer.proj_in.weight.dtype + ) + cond_latents = st.get("cond_latents") + if cond_latents is not None: + # Image-to-video: latent frame 0 is the clean conditioning anchor; + # the rest is noise. It stays clean through the loop because the + # predicted velocity is zero on conditioning frames (unpatchify + # only fills the noisy frames), matching the fused pipeline. + latents[:, :, 0] = cond_latents[:, :, 0].to(latents.dtype) + time_index = torch.zeros(1, dtype=torch.long, device=device) + else: + latents = inputs["latents"][0] + time_index = inputs["time_index"][0] + tensors = {"latents": latents, "time_index": time_index} + # The CUDA-graph capture reads the timestep and rotary positions as static + # buffers (it can't reach the per-request scheduler at replay), so + # materialize them here. The eager path ignores these and recomputes from + # per-request state. Only built in the two-branch guidance regime — the + # one the graph captures. + if st["uncond"] is not None: + # The denoise loop may dispatch one extra (discarded) step past this + # request's step count; clamp so materializing the static timestep + # buffer can't index past the schedule. + n_steps = len(st["scheduler"].timesteps) + idx = time_index.reshape(-1).clamp(max=n_steps - 1) + t = st["scheduler"].timesteps[idx].to(torch.float32) + tensors["vision_timesteps"] = t.expand(st["num_noisy"]).contiguous() + tensors["position_ids_cond"] = st["cond"]["vision_mrope_ids"] + tensors["position_ids_uncond"] = st["uncond"]["vision_mrope_ids"] + return ARNodeInputs( + input_seq_len=st["num_vision"], + tensor_inputs=tensors, + ) + + def _prepare_action_gen(self, fwd_info, inputs, device) -> ARNodeInputs: + st = self._req[fwd_info.request_id] + if "latents" not in inputs or len(inputs["latents"]) == 0: + # First iteration: build the joint [video | action] latents. Per the + # mode masks, conditioning frames/action are clean and the predicted + # ones start from noise; the clean anchors are then carried in the + # looped latents (re-injected each step). Action noise is drawn before + # the video noise to match the fused pipeline's RNG order. + from diffusers.utils.torch_utils import randn_tensor + + dtype = self.transformer.proj_in.weight.dtype + gen = torch.Generator(device=device).manual_seed(fwd_info.random_seed) + chunk, adim, raw = st["action_chunk"], st["action_dim"], st["raw_action_dim"] + a_noise = randn_tensor((1, chunk, adim), generator=gen, device=device, dtype=dtype) + a_noise[..., raw:] = 0 + action_latents = ( + st["action_clean_mask"] * st["clean_action"] + st["action_velocity_mask"] * a_noise + ) + action_latents[..., raw:] = 0 + v_noise = randn_tensor(st["latent_shape"], generator=gen, device=device, dtype=dtype) + latents = st["vmask"] * st["cond_video_latents"] + st["velocity_mask"] * v_noise + time_index = torch.zeros(1, dtype=torch.long, device=device) + else: + latents = inputs["latents"][0] + action_latents = inputs["action_latents"][0] + time_index = inputs["time_index"][0] + return ARNodeInputs( + input_seq_len=st["num_vision"] + st["num_action"], + tensor_inputs={"latents": latents, "action_latents": action_latents, "time_index": time_index}, + ) + + # ------------------------------------------------------------------ + # preprocess: plan paged attention for the labels this walk touches. + # ------------------------------------------------------------------ + + def _plan_gen(self, cm, st, num_gen: int, cfg_active: bool = True) -> None: + """Plan a denoise step's non-causal attention: one batched plan covering + both guidance branches when they run together, else a plan per label. + ``cfg_active`` False (a guidance_interval out-of-interval step, or + gs==1) plans the conditional branch alone — matching the cond-only + forward — so an interval step costs no wasted uncond/batched plan.""" + if st["uncond"] is None or not cfg_active: + cm.plan_attention(seq_lens=[num_gen], is_causal=False, label=COND_LABEL, write_store=False) + elif self.batched_cfg: + cm.plan_attention_batched_cfg( + labels=[COND_LABEL, UNCOND_LABEL], seq_lens=[num_gen], + is_causal=False, write_store=False, + ) + else: + cm.plan_attention(seq_lens=[num_gen], is_causal=False, label=COND_LABEL, write_store=False) + cm.plan_attention(seq_lens=[num_gen], is_causal=False, label=UNCOND_LABEL, write_store=False) + + def _preprocess_image_gen_captured(self, cm, inputs) -> dict: + """Plan a denoise step for the CUDA-graph path. + + Runs with synthetic request ids (no per-request state), so it derives the + token count from ``input_seq_len``. Both guidance branches are planned as + one combined attention (``plan_attention_batched_cfg``) so the captured + forward runs a single transformer pass over both — one weight load instead + of two. The static-input tensors (latents, timestep, rotary positions) are + stacked on a leading batch dim, so one captured graph spans a whole + concurrent batch (a batch of one for the single-request latency path); the + replay side copies each request's tensors into these fixed buffers. + """ + seq_lens = [inp.input_seq_len for inp in inputs] + cm.plan_attention_batched_cfg( + labels=[COND_LABEL, UNCOND_LABEL], seq_lens=seq_lens, + is_causal=False, write_store=False, + ) + return { + "latents": torch.stack([inp.tensor_inputs["latents"] for inp in inputs]), + "vision_timesteps": torch.stack([inp.tensor_inputs["vision_timesteps"] for inp in inputs]), + "position_ids_cond": torch.stack([inp.tensor_inputs["position_ids_cond"] for inp in inputs]), + "position_ids_uncond": torch.stack([inp.tensor_inputs["position_ids_uncond"] for inp in inputs]), + } + + def preprocess(self, graph_walk, engine_inputs: ModelInputsFromEngine, inputs) -> dict: + cm = engine_inputs.cache_manager + + if graph_walk == IMAGE_GEN_WALK and getattr(cm, "_cuda_graph_mode", False): + return self._preprocess_image_gen_captured(cm, inputs) + + st = self._req[engine_inputs.request_ids[0]] + + if graph_walk in PREFILL_WALKS: + cm.plan_attention(seq_lens=[st["cond"]["und_len"]], is_causal=True, label=COND_LABEL, write_store=False) + if st["uncond"] is not None: + cm.plan_attention( + seq_lens=[st["uncond"]["und_len"]], is_causal=True, label=UNCOND_LABEL, write_store=False + ) + return {} + + if graph_walk in GEN_WALKS: + rids = engine_inputs.request_ids + if len(rids) > 1: + # Cross-request batch: one batched plan over every request's two + # guidance branches, each with its own page set and token count. + cm.plan_attention_batched_cfg( + labels=[COND_LABEL, UNCOND_LABEL], + seq_lens=[self._req[r]["num_vision"] for r in rids], + is_causal=False, write_store=False, + ) + return { + "latents": {r: inp.tensor_inputs["latents"] for r, inp in zip(rids, inputs, strict=True)}, + "time_index": {r: inp.tensor_inputs["time_index"] for r, inp in zip(rids, inputs, strict=True)}, + } + ti = inputs[0].tensor_inputs["time_index"] + step_index = int(ti.reshape(-1)[0].item()) + self._plan_gen( + cm, st, st["num_vision"], cfg_active=self._cfg_active(st, step_index) + ) + return { + "latents": inputs[0].tensor_inputs["latents"], + "time_index": ti, + } + + if graph_walk in ACTION_WALKS: + rids = engine_inputs.request_ids + if len(rids) > 1: + # Cross-request batch: one batched plan over every request's joint + # [video | action] block, each with its own page set and token + # count. A single label when guidance is off (the common + # guidance-scale-1 case), both labels with classifier-free + # guidance. + sts = [self._req[r] for r in rids] + labels = ( + [COND_LABEL, UNCOND_LABEL] if sts[0]["uncond"] is not None else [COND_LABEL] + ) + cm.plan_attention_batched_cfg( + labels=labels, + seq_lens=[s["num_vision"] + s["num_action"] for s in sts], + is_causal=False, write_store=False, + ) + return { + "latents": {r: inp.tensor_inputs["latents"] for r, inp in zip(rids, inputs, strict=True)}, + "action_latents": { + r: inp.tensor_inputs["action_latents"] for r, inp in zip(rids, inputs, strict=True) + }, + "time_index": {r: inp.tensor_inputs["time_index"] for r, inp in zip(rids, inputs, strict=True)}, + } + self._plan_gen(cm, st, st["num_vision"] + st["num_action"]) + return { + "latents": inputs[0].tensor_inputs["latents"], + "action_latents": inputs[0].tensor_inputs["action_latents"], + "time_index": inputs[0].tensor_inputs["time_index"], + } + raise ValueError(f"Unknown Cosmos3 DiT graph walk: {graph_walk!r}") + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + + # Run the prefill/denoise in the model's native bf16, NOT under the engine's + # autocast. The fused reference pipeline runs the transformer in pure bf16; + # autocast keeps normalization in fp32, which perturbs the predicted velocity + # by ~1 ULP per step. A single image step stays well within tolerance, but the + # multi-step video denoise amplifies that perturbation geometrically into a + # scrambled latent. The cache-once engine path must reproduce the reference, + # so this submodule opts out of autocast (the VAE decoder does the same). + @torch.autocast(device_type="cuda", enabled=False) + def forward(self, graph_walk, engine_inputs: ModelInputsFromEngine, **kwargs): + cm = engine_inputs.cache_manager + rid = engine_inputs.request_ids[0] + if graph_walk in PREFILL_WALKS: + return self._forward_prefill(cm, self._req[rid]) + if graph_walk in GEN_WALKS: + return self._forward_image_gen(cm, self._req[rid], **kwargs) + if graph_walk in ACTION_WALKS: + return self._forward_action_gen(cm, self._req[rid], **kwargs) + raise ValueError(f"Unknown Cosmos3 DiT graph walk: {graph_walk!r}") + + def _forward_prefill(self, cm, st) -> dict: + _prof = os.environ.get("COSMOS3_PROFILE") + if _prof: + _e0 = torch.cuda.Event(enable_timing=True); _e1 = torch.cuda.Event(enable_timing=True) + _e0.record() + cond = st["cond"] + cm.set_active_label(COND_LABEL) + self.transformer.prefill_und(cond["input_ids"], cond["text_mrope_ids"], cm) + if st["uncond"] is not None: + uncond = st["uncond"] + cm.set_active_label(UNCOND_LABEL) + self.transformer.prefill_und(uncond["input_ids"], uncond["text_mrope_ids"], cm) + if _prof: + _e1.record(); torch.cuda.synchronize() + logger.info("COSMOS3_PROFILE prefill %.1f ms", _e0.elapsed_time(_e1)) + return {} + + def _denoise(self, cm, static, latents, vision_timesteps): + return self.transformer.denoise_step( + latents, + vision_timesteps, + static["vision_mrope_ids"], + static["vision_token_shapes"], + static["vision_noisy_frame_indexes"], + static["mse_gen_indexes"], + cm, + ) + + def _cfg_active(self, st, step_index: int) -> bool: + """Whether this denoise step runs classifier-free guidance (both + branches combined). False ⇒ the conditional branch runs alone — the + guidance_scale==1 case and, for the t2i recipe, steps whose timestep + falls outside the guidance_interval [lo, hi]. ``preprocess`` and + ``_forward_image_gen`` both call this for the same step so the planned + attention (batched vs cond-only) matches the forward that runs.""" + if st["uncond"] is None: + return False + gi = st.get("guidance_interval") + if gi is None: + return True + sched = st["scheduler"] + if step_index >= len(sched.timesteps): + return False + t = float(sched.timesteps[step_index].item()) + return gi[0] <= t <= gi[1] + + def _forward_image_gen(self, cm, st, latents, time_index, **kwargs) -> dict: + scheduler = st["scheduler"] + step_index = int(time_index.reshape(-1)[0].item()) + if step_index >= len(scheduler.timesteps): + # One extra step past this request's denoise count: the loop has + # already been told to stop and this output is discarded. Pass the + # finished latents through without touching the (stateful) scheduler. + return {"latents": [latents], "time_index": [time_index]} + t = scheduler.timesteps[step_index] + vision_timesteps = torch.full((st["num_noisy"],), t.item(), device=latents.device) + + # Classifier-free guidance is applied only when an uncond branch exists + # (guidance_scale != 1) and, for the text-to-image recipe, only on the + # configured timestep interval. Outside the interval the step runs the + # conditional branch alone (cond-only velocity), matching the recipe. + cfg_active = self._cfg_active(st, step_index) + + if not cfg_active: + cm.set_active_label(COND_LABEL) + velocity = self._denoise(cm, st["cond"], latents, vision_timesteps) + elif self.batched_cfg: + cm.set_active_label(CFG_BATCHED_LABEL) + cond_v, uncond_v = self.transformer.denoise_step_batched_cfg( + latents, + vision_timesteps, + st["cond"]["vision_mrope_ids"], + st["uncond"]["vision_mrope_ids"], + st["cond"]["vision_token_shapes"], + st["cond"]["vision_noisy_frame_indexes"], + st["cond"]["mse_gen_indexes"], + cm, + ) + velocity = uncond_v + st["gs"] * (cond_v - uncond_v) + else: + cm.set_active_label(COND_LABEL) + cond_v = self._denoise(cm, st["cond"], latents, vision_timesteps) + cm.set_active_label(UNCOND_LABEL) + uncond_v = self._denoise(cm, st["uncond"], latents, vision_timesteps) + velocity = uncond_v + st["gs"] * (cond_v - uncond_v) + + new_latents = scheduler.step( + velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + return {"latents": [new_latents], "time_index": [time_index + 1]} + + def _denoise_action(self, cm, static, latents, action_latents, vts, ats, domain): + und_len = static["und_len"] + return self.transformer.denoise_step( + latents, + vts, + static["position_ids"][:, und_len:], + static["vision_token_shapes"], + static["vision_noisy_frame_indexes"], + static["mse_gen_indexes"], + cm, + action_latents=action_latents, + action_token_shapes=static["action_token_shapes"], + action_noisy_frame_indexes=static["action_noisy_frame_indexes"], + action_mse_gen_indexes=static["action_mse_gen_indexes"], + action_timesteps=ats, + action_domain_id=domain, + ) + + def _action_scheduler_step(self, st, latents, action_latents, video_v, action_v, t): + """One joint [video | action] scheduler step for an action request: mask + the predicted velocities to their noisy bands, step the request's own + scheduler over the packed [video | action] state, then re-inject the clean + conditioning anchors (conditioning frames / action stay clean each step, + their masked-in values invariant). Shared by the single-request and + cross-request batched action forwards.""" + raw, chunk, adim = st["raw_action_dim"], st["action_chunk"], st["action_dim"] + video_v = video_v * st["velocity_mask"] + action_v = action_v * st["action_velocity_mask"] + action_v[..., raw:] = 0 + nv = video_v.numel() + packed = torch.cat([video_v.reshape(1, -1), action_v.reshape(1, -1)], dim=1) + packed_lat = torch.cat([latents.reshape(1, -1), action_latents.reshape(1, -1)], dim=1) + packed_next = st["scheduler"].step(packed, t, packed_lat, return_dict=False)[0] + new_latents = packed_next[:, :nv].reshape(latents.shape) + new_action = packed_next[:, nv:].reshape(1, chunk, adim) + new_latents = st["velocity_mask"] * new_latents + st["vmask"] * latents + new_action = st["action_velocity_mask"] * new_action + st["action_clean_mask"] * action_latents + new_action[..., raw:] = 0 + return new_latents, new_action + + def _forward_action_gen(self, cm, st, latents, action_latents, time_index, **kwargs) -> dict: + scheduler = st["scheduler"] + step_index = int(time_index.reshape(-1)[0].item()) + if step_index >= len(scheduler.timesteps): + # One extra step past this request's denoise count (discarded output). + return { + "latents": [latents], + "action_latents": [action_latents], + "time_index": [time_index], + } + t = scheduler.timesteps[step_index] + device = latents.device + vts = torch.full((st["num_noisy"],), t.item(), device=device) + ats = torch.full((st["num_noisy_action"],), t.item(), device=device) + domain = st["domain_t"] + + if st["uncond"] is None: + cm.set_active_label(COND_LABEL) + video_v, action_v = self._denoise_action(cm, st["cond"], latents, action_latents, vts, ats, domain) + elif self.batched_cfg: + cm.set_active_label(CFG_BATCHED_LABEL) + (video_v, action_v), (v_u, a_u) = self.transformer.denoise_step_batched_cfg( + latents, + vts, + st["cond"]["position_ids"][:, st["cond"]["und_len"]:], + st["uncond"]["position_ids"][:, st["uncond"]["und_len"]:], + st["cond"]["vision_token_shapes"], + st["cond"]["vision_noisy_frame_indexes"], + st["cond"]["mse_gen_indexes"], + cm, + action_latents=action_latents, + action_token_shapes=st["cond"]["action_token_shapes"], + action_noisy_frame_indexes=st["cond"]["action_noisy_frame_indexes"], + action_mse_gen_indexes=st["cond"]["action_mse_gen_indexes"], + action_timesteps=ats, + action_domain_id=domain, + ) + video_v = v_u + st["gs"] * (video_v - v_u) + action_v = a_u + st["gs"] * (action_v - a_u) + else: + cm.set_active_label(COND_LABEL) + video_v, action_v = self._denoise_action(cm, st["cond"], latents, action_latents, vts, ats, domain) + cm.set_active_label(UNCOND_LABEL) + v_u, a_u = self._denoise_action(cm, st["uncond"], latents, action_latents, vts, ats, domain) + video_v = v_u + st["gs"] * (video_v - v_u) + action_v = a_u + st["gs"] * (action_v - a_u) + + new_latents, new_action = self._action_scheduler_step( + st, latents, action_latents, video_v, action_v, t + ) + return { + "latents": [new_latents], + "action_latents": [new_action], + "time_index": [time_index + 1], + } + + # ------------------------------------------------------------------ + # Cross-request batching: run several requests' denoise step together. + # ------------------------------------------------------------------ + + def can_batch(self, batch, model_inputs) -> bool: + # The denoise step batches across concurrent requests at the same walk. + # The batched forward packs each request's own token shapes, so requests + # at different resolutions / frame counts (and, for action, different + # modes / embodiment domains) can share the batch. One request stays on + # the simpler single-request path. + if not self.batched_cfg or len(batch.request_ids) < 2: + return False + sts = [self._req.get(rid) for rid in batch.request_ids] + if any(st is None for st in sts): + return False + if batch.graph_walk in GEN_WALKS: + # Image/video batch only in the two-branch guidance regime, so one + # batched-CFG plan covers them. + return all(st["uncond"] is not None for st in sts) + if batch.graph_walk in ACTION_WALKS: + # Action batches when all requests share the guidance regime (all + # single-branch -- guidance-scale-1 inverse/forward-dynamics and base + # policy -- or all two-branch), so one plan covers the batch. Modes + # and embodiment domains may differ: each request's masks, scheduler + # and domain-aware action projection are applied per request. + return len({st["uncond"] is not None for st in sts}) == 1 + return False + + def max_batch_size(self, graph_walk: str): + if graph_walk in GEN_WALKS or graph_walk in ACTION_WALKS: + return self.max_gen_batch_size + return None + + # Native bf16, not the engine autocast — see the note on forward(). The + # cross-request batched denoise must match the per-request path exactly. + @torch.autocast(device_type="cuda", enabled=False) + def forward_batched( + self, graph_walk, engine_inputs: ModelInputsFromEngine, + latents, time_index, action_latents=None, **kwargs, + ): + if graph_walk in ACTION_WALKS: + return self._forward_batched_action(engine_inputs, latents, action_latents, time_index) + if graph_walk not in GEN_WALKS: + raise ValueError(f"Cosmos3 batched forward only supports generation walks, got {graph_walk!r}") + cm = engine_inputs.cache_manager + cm.set_active_label(CFG_BATCHED_LABEL) + reqs, meta = [], [] + for rid in engine_inputs.request_ids: + st = self._req[rid] + lat, ti = latents[rid], time_index[rid] + step_index = int(ti.reshape(-1)[0].item()) + n_steps = len(st["scheduler"].timesteps) + # A request may be one step past its denoise count (a discarded extra + # step) while others in the batch are still running; clamp its + # timestep so the shared forward can't index past the schedule, and + # skip its scheduler step below. + past_end = step_index >= n_steps + t = st["scheduler"].timesteps[min(step_index, n_steps - 1)] + reqs.append({ + "latents": lat, + "vision_timesteps": torch.full((st["num_noisy"],), t.item(), device=lat.device), + "position_ids_cond": st["cond"]["vision_mrope_ids"], + "position_ids_uncond": st["uncond"]["vision_mrope_ids"], + "vision_token_shapes": st["cond"]["vision_token_shapes"], + "vision_noisy_frame_indexes": st["cond"]["vision_noisy_frame_indexes"], + "vision_mse_loss_indexes": st["cond"]["mse_gen_indexes"], + }) + meta.append((rid, st, lat, ti, t, past_end)) + + results = self.transformer.denoise_step_batched(reqs, cm) + + out = {} + for (rid, st, lat, ti, t, past_end), (cond_v, uncond_v) in zip(meta, results, strict=True): + if past_end: + out[rid] = {"latents": [lat], "time_index": [ti]} + continue + velocity = uncond_v + st["gs"] * (cond_v - uncond_v) + new_latents = st["scheduler"].step( + velocity.unsqueeze(0), t, lat.unsqueeze(0), return_dict=False + )[0].squeeze(0) + out[rid] = {"latents": [new_latents], "time_index": [ti + 1]} + return out + + def _forward_batched_action(self, engine_inputs, latents, action_latents, time_index): + """Run several action requests' joint [video | action] denoise step in one + forward. Mirrors the image batched path: build each request's static gen + inputs (clamping a request that has run one step past its denoise count), + run one batched transformer pass, then per request combine the guidance + branches (when present) and apply its own joint scheduler step.""" + cm = engine_inputs.cache_manager + cm.set_active_label(CFG_BATCHED_LABEL) + rids = engine_inputs.request_ids + with_cfg = self._req[rids[0]]["uncond"] is not None + reqs, meta = [], [] + for rid in rids: + st = self._req[rid] + lat, act, ti = latents[rid], action_latents[rid], time_index[rid] + step_index = int(ti.reshape(-1)[0].item()) + n_steps = len(st["scheduler"].timesteps) + # A request may be one (discarded) step past its denoise count while + # others in the batch are still running; clamp its timestep so the + # shared forward can't index past the schedule, and skip its scheduler + # step below. + past_end = step_index >= n_steps + t = st["scheduler"].timesteps[min(step_index, n_steps - 1)] + cond = st["cond"] + und = cond["und_len"] + req = { + "latents": lat, + "action_latents": act, + "vision_timesteps": torch.full((st["num_noisy"],), t.item(), device=lat.device), + "action_timesteps": torch.full((st["num_noisy_action"],), t.item(), device=lat.device), + "position_ids_cond": cond["position_ids"][:, und:], + "vision_token_shapes": cond["vision_token_shapes"], + "vision_noisy_frame_indexes": cond["vision_noisy_frame_indexes"], + "vision_mse_loss_indexes": cond["mse_gen_indexes"], + "action_token_shapes": cond["action_token_shapes"], + "action_noisy_frame_indexes": cond["action_noisy_frame_indexes"], + "action_mse_gen_indexes": cond["action_mse_gen_indexes"], + "action_domain_id": st["domain_t"], + } + if with_cfg: + unc = st["uncond"] + req["position_ids_uncond"] = unc["position_ids"][:, unc["und_len"]:] + reqs.append(req) + meta.append((rid, st, lat, act, ti, t, past_end)) + + results = self.transformer.denoise_step_action_batched(reqs, cm, with_cfg) + + out = {} + for (rid, st, lat, act, ti, t, past_end), branches in zip(meta, results, strict=True): + if past_end: + out[rid] = {"latents": [lat], "action_latents": [act], "time_index": [ti]} + continue + if with_cfg: + (cond_video, cond_action), (uncond_video, uncond_action) = branches + video_v = uncond_video + st["gs"] * (cond_video - uncond_video) + action_v = uncond_action + st["gs"] * (cond_action - uncond_action) + else: + (video_v, action_v), = branches + new_latents, new_action = self._action_scheduler_step(st, lat, act, video_v, action_v, t) + out[rid] = { + "latents": [new_latents], + "action_latents": [new_action], + "time_index": [ti + 1], + } + return out + + # ------------------------------------------------------------------ + # CUDA-graph capture of the denoise step. Only the transformer velocity + # computation is captured; the guidance combine and the (Python, multistep) + # scheduler step run eagerly afterwards. + # ------------------------------------------------------------------ + + def get_cuda_graph_configs(self, device, tp_world_size: int = 1): + """Declare one fixed-shape capture of the image denoise step per + resolution. Requests at other resolutions, or without guidance, fall back + to the eager path. The per-resolution token layout is prompt-independent, + so bake it once here and key it by latent shape; the per-prompt rotary + positions, the latents and the timestep flow in as static-buffer inputs. + + Set ``COSMOS3_DISABLE_CUDA_GRAPH=1`` to skip capture and run the denoise + loop eagerly (escape hatch for a misbehaving driver, and an A/B switch). + Set ``COSMOS3_GEN_CAPTURE_RES`` (e.g. ``"192x320,480x832"``, height x + width) to override which resolutions are captured, and + ``COSMOS3_GEN_CAPTURE_BS`` (e.g. ``"1,4,8"``) to also capture batched + denoise steps so concurrent requests replay a padded graph instead of + falling back to the eager path.""" + if self.transformer is None or os.environ.get("COSMOS3_DISABLE_CUDA_GRAPH"): + return [] + res_env = os.environ.get("COSMOS3_GEN_CAPTURE_RES") + if res_env: + resolutions = tuple( + tuple(int(x) for x in pair.split("x")) for pair in res_env.split(",") + ) + else: + resolutions = self.gen_capture_resolutions + bs_env = os.environ.get("COSMOS3_GEN_CAPTURE_BS") + if bs_env: + capture_batch_sizes = [int(x) for x in bs_env.split(",")] + else: + capture_batch_sizes = list(self.gen_capture_batch_sizes) + dtype = self.transformer.proj_in.weight.dtype + self._capture_layout: dict[tuple, dict] = {} + configs = [] + for height, width in resolutions: + latent_shape = self._latent_shape(height, width, num_frames=1) + # patchify-2 pads an odd latent height/width (e.g. 720p: 720 // 16 = + # 45 -> pad to 46), and the captured/replayed padded layout produces + # degraded output (clean on the left, scrambled on the right). Skip + # capture for such resolutions; they fall back to the eager path, + # which is clean and ~as fast at these compute-bound tiers. + if latent_shape[3] % 2 or latent_shape[4] % 2: + logger.info( + "Cosmos3: skipping CUDA-graph capture for %dx%d " + "(odd latent dim %s -> patchify pad -> eager fallback)", + height, width, tuple(latent_shape[3:]), + ) + continue + static = self._build_static( + [0] * 8, height, width, num_frames=1, fps=24.0, + has_image_condition=False, device=device, + ) + num_vision = static["num_vision_tokens"] + num_noisy = static["num_noisy_vision_tokens"] + self._capture_layout[tuple(latent_shape)] = { + "vision_token_shapes": static["vision_token_shapes"], + "vision_noisy_frame_indexes": static["vision_noisy_frame_indexes"], + "mse_gen_indexes": static["mse_gen_indexes"], + } + single = ARNodeInputs( + input_seq_len=num_vision, + tensor_inputs={ + "latents": torch.zeros(latent_shape, device=device, dtype=dtype), + "vision_timesteps": torch.zeros(num_noisy, device=device, dtype=torch.float32), + "position_ids_cond": static["vision_mrope_ids"].clone(), + "position_ids_uncond": static["vision_mrope_ids"].clone(), + }, + ) + configs.append(BasicBatchedCudaGraphConfig( + capture_graph_walk=IMAGE_GEN_WALK, + single_request_inputs=single, + requires_cfg=False, + labels=[COND_LABEL, UNCOND_LABEL], + capture_forward_method="forward_captured", + advance_seq_lens=False, + compile=False, + capture_batch_sizes=capture_batch_sizes, + # The captured sizes (default just bs=1, for single-request + # latency; COSMOS3_GEN_CAPTURE_BS adds batched sizes) are an + # acceleration subset, not a batch ceiling: a concurrent batch at + # an uncaptured size or mixed resolution still runs the eager + # batched denoise (forward_batched), so don't let this capture cap + # max_batch_size to the captured sizes. + caps_eager_batch_size=False, + )) + return configs + + def can_use_cuda_graphs(self, batch, model_inputs) -> bool: + # Only the image denoise step is captured, only with two-branch guidance, + # and only at a resolution we captured a graph for. A batched capture is a + # single fixed resolution, so a concurrent batch must be uniform-resolution + # to share one captured (batch size, token count) bucket; mixed-resolution + # batches fall back to the eager cross-request denoise. + if batch.graph_walk != IMAGE_GEN_WALK: + return False + layout = getattr(self, "_capture_layout", None) + if not layout: + return False + shapes = set() + for rid in batch.request_ids: + st = self._req.get(rid) + if st is None or st["uncond"] is None: + return False + shape = tuple(st["latent_shape"]) + if shape not in layout: + return False + shapes.add(shape) + return len(shapes) == 1 + + def forward_captured( + self, graph_walk, engine_inputs: ModelInputsFromEngine, + latents, vision_timesteps, position_ids_cond, position_ids_uncond, **kwargs, + ) -> dict: + """Velocity-only denoise forward captured into a CUDA graph: both guidance + branches in one pass (the combined plan), no scheduler step. The token + layout is baked per resolution; the latents, timestep and rotary positions + are static-buffer inputs stacked on a leading batch dim. A single request + keeps the two-branch path; a concurrent batch runs the per-request denoise + (the same compute as the eager cross-request forward), one transformer pass + over the whole batch.""" + cm = engine_inputs.cache_manager + cm.set_active_label(CFG_BATCHED_LABEL) + layout = self._capture_layout[tuple(latents.shape[1:])] + rids = engine_inputs.request_ids + if latents.shape[0] == 1: + cond_v, uncond_v = self.transformer.denoise_step_batched_cfg( + latents[0], vision_timesteps[0], position_ids_cond[0], position_ids_uncond[0], + layout["vision_token_shapes"], layout["vision_noisy_frame_indexes"], + layout["mse_gen_indexes"], cm, + ) + return {rids[0]: {"cond_v": [cond_v], "uncond_v": [uncond_v]}} + reqs = [ + { + "latents": latents[i], + "vision_timesteps": vision_timesteps[i], + "position_ids_cond": position_ids_cond[i], + "position_ids_uncond": position_ids_uncond[i], + "vision_token_shapes": layout["vision_token_shapes"], + "vision_noisy_frame_indexes": layout["vision_noisy_frame_indexes"], + "vision_mse_loss_indexes": layout["mse_gen_indexes"], + } + for i in range(latents.shape[0]) + ] + results = self.transformer.denoise_step_batched(reqs, cm) + return { + rid: {"cond_v": [cond_v], "uncond_v": [uncond_v]} + for rid, (cond_v, uncond_v) in zip(rids, results, strict=True) + } + + def postprocess_captured(self, request_ids, inputs, per_request_info, outputs) -> dict: + """Eager tail run after graph replay: the classifier-free-guidance combine + and the (Python, multistep) scheduler step the graph can't hold. Mirrors + the tail of ``_forward_image_gen``.""" + for rid, inp in zip(request_ids, inputs, strict=True): + st = self._req[rid] + cond_v = outputs[rid]["cond_v"][0] + uncond_v = outputs[rid]["uncond_v"][0] + velocity = uncond_v + st["gs"] * (cond_v - uncond_v) + latents = inp.tensor_inputs["latents"] + time_index = inp.tensor_inputs["time_index"] + step_index = int(time_index.reshape(-1)[0].item()) + if step_index >= len(st["scheduler"].timesteps): + # Discarded extra step past this request's denoise count. + outputs[rid] = {"latents": [latents], "time_index": [time_index]} + continue + t = st["scheduler"].timesteps[step_index] + new_latents = st["scheduler"].step( + velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + outputs[rid] = {"latents": [new_latents], "time_index": [time_index + 1]} + return outputs + + def check_stop(self, request_id, request_info, outputs) -> set[str]: + """Stop this request's denoise loop once it has run its own step count. + + The loop is built with a fixed upper-bound iteration count + (``config.max_inference_steps``); each request runs only as many steps as + its scheduler holds (e.g. image 50, video 35, action 30, distilled policy + ~4), which can differ between concurrent requests. Runs on the worker's + slow-postprocess path, so reading the per-request step count is fine. The + one extra step the loop dispatches before this stop takes effect is a + no-op (see the ``step_index >=`` guards in the forward methods).""" + st = self._req.get(request_id) + if st is None: + return set() + loop = { + ACTION_GEN_WALK: ACTION_GEN_LOOP, + ACTION_VIDEO_GEN_WALK: ACTION_VIDEO_GEN_LOOP, + VIDEO_GEN_WALK: VIDEO_GEN_LOOP, + }.get(request_info.graph_walk, IMAGE_GEN_LOOP) + iter_idx = request_info.dynamic_loop_iter_counts.get(loop, 0) + if iter_idx + 1 >= len(st["scheduler"].timesteps): + return {loop} + return set() + + def cleanup_request(self, request_id: str): + self._req.pop(request_id, None) + + +class Cosmos3VAEDecoderSubmodule(NodeSubmodule): + """Wan VAE decode node: final denoised latents -> pixel frames. + + Applies the pipeline-side latent normalization (the VAE itself returns raw + latents) before decoding, matching the fused t2i pipeline's decode. + """ + + # One-shot decode per request; CUDA-graph capture (not torch.compile) is the + # speedup path. + disable_torch_compile = True + + def __init__(self, vae, config): + super().__init__() + self.vae = vae + self.config = config + # The Wan VAE decode is 3D-conv bound and is not captured into a CUDA + # graph (it runs once per request at request-specific frame/resolution + # shapes). torch.compile fuses the pointwise epilogues around those convs; + # fullgraph=False lets dynamo break around the VAE's Python-level + # causal-conv feature cache, and dynamic=False gives the best per-shape + # kernels at the cost of a one-time trace per new (frames, height, width) + # — fine for the few fixed generation tiers (the first request at each + # shape pays the trace). Off by default; set COSMOS3_COMPILE_VAE=1 to + # enable (A/B against the eager decode, which is identical bar fp + # rounding). The compile wraps the same fp32, autocast-off decode below. + self._decode = vae.decode if vae is not None else None + if vae is not None and os.environ.get("COSMOS3_COMPILE_VAE"): + self._decode = torch.compile(vae.decode, fullgraph=False, dynamic=False) + logger.info("Cosmos3 VAE decode torch.compile enabled") + + def prepare_inputs(self, graph_walk, fwd_info, inputs, **kwargs) -> NodeInputs: + return NodeInputs(tensor_inputs={"latents": inputs["latents"][0]}) + + def forward(self, graph_walk, engine_inputs: ModelInputsFromEngine, latents, **kwargs): + vae = self.vae + # The Wan VAE's 3D convolutions run several times faster in fp32 (TF32 + # tensor cores) than in bf16 on this cuDNN, and the reference pipeline + # decodes in fp32. The engine casts this submodule to bf16, so restore the + # vae to fp32 once and decode outside autocast to keep the fast path. + if next(vae.parameters()).dtype != torch.float32: + vae.float() + mean = torch.tensor(vae.config.latents_mean, dtype=torch.float32, device=latents.device).view(1, -1, 1, 1, 1) + inv_std = (1.0 / torch.tensor(vae.config.latents_std, dtype=torch.float32, device=latents.device)).view( + 1, -1, 1, 1, 1 + ) + z = latents.float() / inv_std + mean + _prof = os.environ.get("COSMOS3_PROFILE") + if _prof: + _e0 = torch.cuda.Event(enable_timing=True); _e1 = torch.cuda.Event(enable_timing=True) + _e0.record() + with torch.autocast(device_type=z.device.type, enabled=False): + decoded = self._decode(z).sample # [1, 3, T, H, W] in [-1, 1] + if _prof: + _e1.record(); torch.cuda.synchronize() + logger.info("COSMOS3_PROFILE vae_decode %.1f ms out=%s", _e0.elapsed_time(_e1), tuple(decoded.shape)) + # Quantize to 8-bit here (the output is an 8-bit image/mp4 either way) so + # only the uint8 frames cross the SHM edge to the data worker, not a 4x + # larger fp32 tensor — the decoded video transfer dominates the fixed cost + # at higher resolutions. + image = (decoded / 2 + 0.5).clamp(0, 1).mul(255).to(torch.uint8) + # Route the decoded tensor to the active walk's emit edge: image_gen + # emits "image_output" (one frame); video_gen and forward-dynamics + # (action_video_gen) emit "video_output". + out_name = ( + "video_output" + if graph_walk in (VIDEO_GEN_WALK, ACTION_VIDEO_GEN_WALK) + else "image_output" + ) + return {out_name: [image]} diff --git a/mstar/model/cosmos3/tests/__init__.py b/mstar/model/cosmos3/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mstar/model/cosmos3/tests/test_action.py b/mstar/model/cosmos3/tests/test_action.py new file mode 100644 index 00000000..8a5e50b9 --- /dev/null +++ b/mstar/model/cosmos3/tests/test_action.py @@ -0,0 +1,694 @@ +"""Tests for the Cosmos3 action path (forward / inverse dynamics + policy). + +CPU-safe unit tests (tiny random config, no weights) cover: + * the action mRoPE band matches vllm-omni's ``compute_mrope_position_ids_action``; + * the per-mode conditioning layout (which video frames / action tokens are + clean context vs noisy) matches vllm-omni's ``action.py``; + * ``build_action_static_inputs`` produces the right joint ``[text|video|action]`` + sequence length, action mse indexes, and position-id width; + * the transformer ``forward`` returns ``(video, action, sound)`` with the right + shapes and the right zeros (inverse-dynamics predicts no video velocity; + forward-dynamics treats the action as clean condition); + * the engine ``denoise_step`` (generation tower over ``[video|action]`` against + the frozen understanding K/V) reproduces the fused ``forward`` bit-for-bit + with an in-process sdpa cache — the cache-once restructuring for action. + +Run: python3 test_action.py +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + +from mstar.model.cosmos3.components.transformer import Cosmos3OmniTransformer +from mstar.model.cosmos3.config import Cosmos3Config +from mstar.model.cosmos3.packing import ( + action_condition_frame_indexes, + build_action_static_inputs, + get_3d_mrope_ids_action_tokens, + vision_condition_frame_indexes, +) + + +# --- verbatim vllm-omni references (transformer_cosmos3.py / action.py) ------ +def _ref_mrope(grid_t, grid_h, grid_w, temporal_offset, fps, base_fps, tcf, base_tcf, start): + fps_mod = fps is not None + if fps_mod: + tps = fps / tcf + base_tps = base_fps / (base_tcf if base_tcf is not None else tcf) + fi = torch.arange(grid_t, dtype=torch.float32) + t_index = ((fi + start) / tps * base_tps + temporal_offset).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + else: + t_index = ( + torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + + int(temporal_offset) + start + ) + h_index = torch.arange(grid_h, dtype=torch.long).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten() + w_index = torch.arange(grid_w, dtype=torch.long).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten() + if fps_mod: + return torch.stack([t_index, h_index.to(torch.float32), w_index.to(torch.float32)], dim=0) + return torch.stack([t_index, h_index, w_index], dim=0) + + +def _ref_action_condition_indexes(mode, action_length): + if mode == "forward_dynamics": + return list(range(action_length)) + return [] # inverse_dynamics / policy + + +def _ref_vision_condition_indexes(mode, latent_frames): + if mode in ("policy", "forward_dynamics"): + return [0] + return list(range(latent_frames)) # inverse_dynamics + + +def _cfg() -> Cosmos3Config: + return Cosmos3Config( + hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, + head_dim=16, intermediate_size=128, vocab_size=100, rope_axes_dim=(4, 2, 2), + latent_channel=8, latent_patch_size=2, patch_latent_dim=32, + sound_gen=False, action_gen=True, max_action_dim=12, num_embodiment_domains=6, + ) + + +class _SdpaCache: + """In-process cache-once handle (stored K/V + sdpa), the BatchedCacheManager + surface the DiT uses. Prefill stashes the understanding K/V; the denoise step + re-reads it. Also models the batched-CFG plan: under the combined label the + packed sequence is split into one block per batched label, each routed to its + own committed prefix (so a single-label batch of one request equals the plain + single-request path).""" + + def __init__(self): + self.active, self.layer = "main", 0 + self.committed, self.pending, self.is_causal = {}, {}, {} + self.batched_labels = None + + def set_active_label(self, label): + self.active = label + + def set_layer_idx(self, i): + self.layer = i + + def plan(self, is_causal): + self.is_causal[self.active] = is_causal + + # Engine-facing surface (used when the DiT submodule drives the cache). + def plan_attention(self, seq_lens=None, dtype=None, is_causal=True, write_store=True, label=None): + self.is_causal[label or self.active] = is_causal + + def plan_attention_batched_cfg(self, labels, seq_lens, is_causal=False, write_store=False, **kwargs): + self.batched_labels = list(labels) + self.is_causal["_cfg_batched"] = is_causal + + def plan_rope(self, *a, **k): + pass + + @staticmethod + def _sdpa(q, k, v, c): + o = F.scaled_dot_product_attention( + q.unsqueeze(0).transpose(1, 2), k.unsqueeze(0).transpose(1, 2), + v.unsqueeze(0).transpose(1, 2), is_causal=c, enable_gqa=True) + return o.transpose(1, 2).squeeze(0) + + def _attend_label(self, label, layer, q, k, v, causal): + key = (label, layer) + if key in self.committed: + pk, pv = self.committed[key] + return self._sdpa(q, torch.cat([pk, k], 0), torch.cat([pv, v], 0), causal) + self.pending[key] = (k, v) + return self._sdpa(q, k, v, causal) + + def run_attention(self, q, k, v, layer_idx=None): + layer = self.layer if layer_idx is None else layer_idx + if self.active == "_cfg_batched": + causal = self.is_causal["_cfg_batched"] + n = q.shape[0] // len(self.batched_labels) + outs = [] + for bi, label in enumerate(self.batched_labels): + sl = slice(bi * n, (bi + 1) * n) + outs.append(self._attend_label(label, layer, q[sl], k[sl], v[sl], causal)) + return torch.cat(outs, 0) + return self._attend_label(self.active, layer, q, k, v, self.is_causal[self.active]) + + def advance_seq_lens(self, pos_id_ns=None): + self.committed.update(self.pending) + self.pending = {} + + +_MODES = ("inverse_dynamics", "forward_dynamics", "policy") + + +def test_action_mrope_matches_reference() -> None: + for fps in (10.0, 24.0, None): + ours, _ = get_3d_mrope_ids_action_tokens( + grid_t=12, temporal_offset=100, action_fps=fps, base_fps=24.0, + base_temporal_compression_factor=4, start_frame_offset=1, + ) + ref = _ref_mrope(12, 1, 1, 100, fps, 24.0, 1, 4, 1) + assert torch.allclose(ours.float(), ref.float(), atol=0), (fps, ours[0, :4], ref[0, :4]) + + +def test_condition_indexes_match_reference() -> None: + for mode in _MODES: + assert action_condition_frame_indexes(mode, 16) == _ref_action_condition_indexes(mode, 16) + assert vision_condition_frame_indexes(mode, 5) == _ref_vision_condition_indexes(mode, 5) + + +def test_action_static_layout() -> None: + cfg = _cfg() + action_chunk, num_frames = 8, 9 + latent_t = 1 + (num_frames - 1) // cfg.vae.scale_factor_temporal # 3 + latent_shape = (1, cfg.latent_channel, latent_t, 4, 4) + ids = [1, 2, 3, 4] + tok_per_frame = (4 // cfg.latent_patch_size) ** 2 # 4 + for mode in _MODES: + s = build_action_static_inputs( + ids, latent_shape, action_chunk, mode, cfg, cfg.vae.scale_factor_temporal, + fps=10.0, action_fps=10.0, action_start_offset=1, device="cpu", + ) + assert s["sequence_length"] == len(ids) + latent_t * tok_per_frame + action_chunk + assert s["position_ids"].shape[1] == s["sequence_length"] + exp_vis_noisy = len(_ref_vision_condition_indexes(mode, latent_t)) + exp_vis_noisy = latent_t - exp_vis_noisy + assert s["num_noisy_vision_tokens"] == exp_vis_noisy * tok_per_frame + exp_act_noisy = action_chunk - len(_ref_action_condition_indexes(mode, action_chunk)) + assert s["num_noisy_action_tokens"] == exp_act_noisy + assert s["action_mse_loss_indexes"].numel() == exp_act_noisy + + +def _run_mode(model, cfg, mode, latent_shape, action_chunk, ids): + s = build_action_static_inputs( + ids, latent_shape, action_chunk, mode, cfg, cfg.vae.scale_factor_temporal, + fps=10.0, action_fps=10.0, action_start_offset=1, device="cpu", + ) + keys = ("input_ids", "text_indexes", "position_ids", "und_len", "sequence_length", + "vision_token_shapes", "vision_sequence_indexes", "vision_mse_loss_indexes", + "vision_noisy_frame_indexes", "action_token_shapes", "action_sequence_indexes", + "action_mse_loss_indexes", "action_noisy_frame_indexes") + sk = {k: s[k] for k in keys} + domain = torch.tensor([2], dtype=torch.long) + latents = torch.randn(latent_shape) + action_lat = torch.randn(1, action_chunk, cfg.max_action_dim) + vts = torch.full((s["num_noisy_vision_tokens"],), 500.0) + ats = torch.full((s["num_noisy_action_tokens"],), 500.0) + with torch.no_grad(): + pv, pa, ps = model( + vision_tokens=[latents], vision_timesteps=vts, + action_tokens=action_lat, action_timesteps=ats, action_domain_id=domain, **sk, + ) + return s, sk, latents, action_lat, domain, vts, ats, pv, pa, ps + + +def test_action_forward_shapes_and_masks() -> None: + cfg = _cfg() + torch.manual_seed(0) + model = Cosmos3OmniTransformer(cfg).eval() + action_chunk = 8 + latent_t = 1 + (9 - 1) // cfg.vae.scale_factor_temporal + latent_shape = (1, cfg.latent_channel, latent_t, 4, 4) + for mode in _MODES: + _, _, _, _, _, _, _, pv, pa, ps = _run_mode(model, cfg, mode, latent_shape, action_chunk, [1, 2, 3, 4]) + assert ps is None + assert pv[0].shape == latent_shape + assert pa.shape == (1, action_chunk, cfg.max_action_dim) + if mode == "inverse_dynamics": + assert torch.count_nonzero(pv[0]) == 0 + if mode == "forward_dynamics": + assert torch.count_nonzero(pa) == 0 + + +def test_action_denoise_step_matches_fused() -> None: + """The engine generation tower over [video|action] against the frozen + understanding K/V reproduces the fused forward bit-for-bit (sdpa cache).""" + cfg = _cfg() + torch.manual_seed(0) + model = Cosmos3OmniTransformer(cfg).eval() + action_chunk = 8 + latent_t = 1 + (9 - 1) // cfg.vae.scale_factor_temporal + latent_shape = (1, cfg.latent_channel, latent_t, 4, 4) + for mode in _MODES: + s, _, latents, action_lat, domain, vts, ats, pv, pa, _ = _run_mode( + model, cfg, mode, latent_shape, action_chunk, [1, 2, 3, 4] + ) + cache = _SdpaCache() + und_len = s["und_len"] + cache.set_active_label("main") + cache.plan(is_causal=True) + model.prefill_und(s["input_ids"], s["text_mrope_ids"], cache) + cache.plan(is_causal=False) + with torch.no_grad(): + dv, da = model.denoise_step( + latents, vts, s["position_ids"][:, und_len:], + s["vision_token_shapes"], s["vision_noisy_frame_indexes"], + s["vision_mse_loss_indexes"] - und_len, cache, + action_latents=action_lat, action_token_shapes=s["action_token_shapes"], + action_noisy_frame_indexes=s["action_noisy_frame_indexes"], + action_mse_gen_indexes=s["action_mse_loss_indexes"] - und_len, + action_timesteps=ats, action_domain_id=domain, + ) + assert (pv[0] - dv).abs().max().item() < 1e-4, mode + assert (pa - da).abs().max().item() < 1e-4, mode + + +def test_action_batched_one_matches_single() -> None: + """The cross-request action batched forward, run with a single request and no + guidance, reproduces the single-request ``denoise_step`` bit-for-bit. Checks + the batched packing / per-request decode plumbing; multi-request isolation is + the GPU-gated cross-request parity test.""" + cfg = _cfg() + torch.manual_seed(0) + model = Cosmos3OmniTransformer(cfg).eval() + action_chunk = 8 + latent_t = 1 + (9 - 1) // cfg.vae.scale_factor_temporal + latent_shape = (1, cfg.latent_channel, latent_t, 4, 4) + for mode in _MODES: + s, _, latents, action_lat, domain, vts, ats, _, _, _ = _run_mode( + model, cfg, mode, latent_shape, action_chunk, [1, 2, 3, 4] + ) + und_len = s["und_len"] + # Reference: the single-request joint denoise step. + cache = _SdpaCache() + cache.set_active_label("main") + cache.plan(is_causal=True) + model.prefill_und(s["input_ids"], s["text_mrope_ids"], cache) + cache.plan(is_causal=False) + with torch.no_grad(): + dv, da = model.denoise_step( + latents, vts, s["position_ids"][:, und_len:], + s["vision_token_shapes"], s["vision_noisy_frame_indexes"], + s["vision_mse_loss_indexes"] - und_len, cache, + action_latents=action_lat, action_token_shapes=s["action_token_shapes"], + action_noisy_frame_indexes=s["action_noisy_frame_indexes"], + action_mse_gen_indexes=s["action_mse_loss_indexes"] - und_len, + action_timesteps=ats, action_domain_id=domain, + ) + # Batched path with one request and no guidance (single-label batch). + cache2 = _SdpaCache() + cache2.set_active_label("main") + cache2.plan(is_causal=True) + model.prefill_und(s["input_ids"], s["text_mrope_ids"], cache2) + cache2.plan_attention_batched_cfg( + labels=["main"], + seq_lens=[s["num_vision_tokens"] + s["num_action_tokens"]], + is_causal=False, + ) + cache2.set_active_label("_cfg_batched") + req = { + "latents": latents, "action_latents": action_lat, + "vision_timesteps": vts, "action_timesteps": ats, + "position_ids_cond": s["position_ids"][:, und_len:], + "vision_token_shapes": s["vision_token_shapes"], + "vision_noisy_frame_indexes": s["vision_noisy_frame_indexes"], + "vision_mse_loss_indexes": s["vision_mse_loss_indexes"] - und_len, + "action_token_shapes": s["action_token_shapes"], + "action_noisy_frame_indexes": s["action_noisy_frame_indexes"], + "action_mse_gen_indexes": s["action_mse_loss_indexes"] - und_len, + "action_domain_id": domain, + } + with torch.no_grad(): + ((bv, ba),), = model.denoise_step_action_batched([req], cache2, with_cfg=False) + assert (dv - bv).abs().max().item() < 1e-5, mode + assert (da - ba).abs().max().item() < 1e-5, mode + + +# --- GPU-gated parity (needs COSMOS3_NANO_DIR + CUDA + diffusers) ------------ +import math # noqa: E402 +import os # noqa: E402 + +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") +# The engine-vs-fused check below is a bit-exact mechanism test, so run the eager +# denoise step; the served default torch.compiles it, which perturbs the latents +# past the 1e-3 bound without moving the action golden gates (id/fd pass with +# compile on). Set COSMOS3_DISABLE_COMPILE_DENOISE= (empty) to test compiled. +os.environ.setdefault("COSMOS3_DISABLE_COMPILE_DENOISE", "1") + +_GPU: dict = {} + + +def _gpu_base(): + if "base" in _GPU: + return _GPU["base"] + snap = os.environ.get("COSMOS3_NANO_DIR") + if not snap or not torch.cuda.is_available(): + _GPU["base"] = None + return None + torch.use_deterministic_algorithms(True, warn_only=True) + from mstar.model.cosmos3.cosmos3_model import Cosmos3Model + from mstar.model.cosmos3.pipeline import Cosmos3Pipeline + + device, dtype = "cuda:0", torch.bfloat16 + model = Cosmos3Model(model_path_hf=snap) + mpipe = Cosmos3Pipeline.from_model(model, device=device, dtype=dtype) + dit = model.get_submodule("dit", device=device) + _GPU["base"] = dict(model=model, mpipe=mpipe, dit=dit, device=device, dtype=dtype, snap=snap) + return _GPU["base"] + + +def _flashinfer_action_shared(model, rids, device, dtype): + """A KV cache + paged allocator shared by several action requests, each with + only the conditional label (guidance-scale-1 action has no unconditional + branch). Mirrors the engine's persistent per-node cache.""" + from mstar.communication.tensors import LocalTransferEngine + from mstar.engine.cache_manager import WorkspaceBufferManager + from mstar.engine.kv_store import PagedAllocationManager, TransferEngineInfo + from mstar.model.cosmos3.submodules import COND_LABEL + + cfg = model.get_kv_cache_config()[0] + cfg.max_num_pages = 128 + cfg.shard(1) + kv_cache = torch.zeros( + cfg.num_layers, cfg.max_num_pages, 2, cfg.page_size, cfg.num_kv_heads, cfg.head_dim, + dtype=dtype, device=device, + ) + alloc = PagedAllocationManager(cfg, kv_cache, TransferEngineInfo("h", "h", LocalTransferEngine("h"))) + for rid in rids: + alloc.add_request(rid, [COND_LABEL]) + buf = WorkspaceBufferManager(256 * 1024 * 1024, device) + return {"kv_cache": kv_cache, "alloc": alloc, "buf": buf, "cfg": cfg, "device": device} + + +def _mk_action_cm(shared, rids): + from mstar.engine.cache_manager import BatchedCacheManager + from mstar.model.cosmos3.submodules import COND_LABEL + + return BatchedCacheManager( + request_ids=rids, active_labels_per_request={r: COND_LABEL for r in rids}, + kv_cache=shared["kv_cache"], alloc_manager=shared["alloc"], buffer_manager=shared["buf"], + kv_cache_config=shared["cfg"], device=shared["device"], auto_write_store=False, + ) + + +def test_action_engine_matches_fused() -> None: + """The cache-once engine action path reproduces the fused pipeline bit-for-bit + (sdpa), on real Nano weights — the action analogue of the video engine test.""" + base = _gpu_base() + if base is None: + print(" (skipped action engine parity: needs COSMOS3_NANO_DIR + CUDA)") + return + from diffusers.utils.torch_utils import randn_tensor + + from mstar.conductor.request_info import CurrentForwardPassInfo + from mstar.model.submodule_base import ModelInputsFromEngine + + device, dtype, mpipe, dit, model = ( + base["device"], base["dtype"], base["mpipe"], base["dit"], base["model"]) + prompt, chunk, raw, dom, fps, steps, fshift, h, w = ( + "You are an autonomous vehicle planning system.", 12, 9, 1, 10.0, 8, 10.0, 128, 128) + nf = chunk + 1 + cond_latent = torch.randn( + (1, model.config.latent_channel, 1 + (nf - 1) // 4, h // 16, w // 16), device=device, dtype=dtype) + + gen = torch.Generator(device=device).manual_seed(0) + act_fused = mpipe.generate_action( + prompt=prompt, mode="inverse_dynamics", domain_id=dom, action_chunk_size=chunk, raw_action_dim=raw, + video_latents=cond_latent, num_frames=nf, height=h, width=w, fps=fps, action_fps=fps, + num_inference_steps=steps, guidance_scale=1.0, flow_shift=fshift, generator=gen) + + gen2 = torch.Generator(device=device).manual_seed(0) + a_noise = randn_tensor((1, chunk, dit.transformer.action_dim), generator=gen2, device=device, dtype=dtype) + a_noise[..., raw:] = 0 + + from mstar.model.cosmos3.packing import tokenize_prompt + cond_ids, _ = tokenize_prompt(model.tokenizer, prompt, "", num_frames=nf, height=h, width=w, fps=fps) + rid = "ra" + md = {"height": h, "width": w, "num_frames": nf, "fps": fps, "action_fps": fps, "guidance_scale": 1.0, + "num_inference_steps": steps, "action_mode": "inverse_dynamics", "action_chunk_size": chunk, + "raw_action_dim": raw, "domain_id": dom, "flow_shift": fshift} + fwd = CurrentForwardPassInfo(request_id=rid, graph_walk="prefill", requires_cfg=False, fwd_index=0, + random_seed=0, max_tokens=0, sampling_config={}, step_metadata=md) + cm = _SdpaCache() + ei = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm) + ni = dit.prepare_inputs("prefill", fwd, {"text_inputs": [torch.tensor(cond_ids, dtype=torch.long, device=device)]}) + dit.forward("prefill", ei, **dit.preprocess("prefill", ei, [ni])) + fwd.graph_walk = "action_gen" + latents, action_latents = cond_latent.clone(), a_noise.clone() + time_index = torch.zeros(1, dtype=torch.long, device=device) + for _ in range(steps): + ni = dit.prepare_inputs("action_gen", fwd, { + "latents": [latents], "action_latents": [action_latents], "time_index": [time_index]}) + out = dit.forward("action_gen", ei, **dit.preprocess("action_gen", ei, [ni])) + latents, action_latents, time_index = out["latents"][0], out["action_latents"][0], out["time_index"][0] + dit.cleanup_request(rid) + # The loop emits the full action latents (self-edge); trim to the raw action + # width to compare with the fused pipeline's trimmed output. + pred_action = out["action_latents"][0][:, :, :raw] + diff = (act_fused.float() - pred_action.float()).abs().max().item() + assert diff <= 1e-3, f"engine action differs from fused by {diff:.3e}" + print(f" action engine cache-once (sdpa) abs-max diff = {diff:.3e}") + + +def test_action_id_golden_gate() -> None: + """Inverse-dynamics on av_0 reproduces NVIDIA's reference action output + ([60, 9]) within MSE <= 0.05 / PSNR >= 14 (NVIDIA's own thresholds).""" + base = _gpu_base() + if base is None: + print(" (skipped action id golden gate: needs COSMOS3_NANO_DIR + CUDA)") + return + import json + + import torchvision + from PIL import Image + + from mstar.model.cosmos3.packing import tokenize_prompt + + device, dtype, mpipe, model, snap = ( + base["device"], base["dtype"], base["mpipe"], base["model"], base["snap"]) + assets = os.path.join(snap, "assets") + inp = os.path.join(assets, "example_action_id_av_0_input.mp4") + if not os.path.exists(inp): + print(" (skipped action id golden gate: av_0 input video missing)") + return + prompt, chunk, raw, dom, fps = "You are an autonomous vehicle planning system.", 60, 9, 1, 10.0 + nf = chunk + 1 + frames, _, _ = torchvision.io.read_video(inp, pts_unit="sec") + frames = frames[:nf] + h, w = int(frames.shape[1]), int(frames.shape[2]) + procs = [mpipe.video_processor.preprocess(Image.fromarray(frames[i].numpy()), height=h, width=w).squeeze(0) + for i in range(frames.shape[0])] + video = torch.stack(procs, dim=1).unsqueeze(0).to(device=device, dtype=dtype) + + cond_ids, _ = tokenize_prompt(model.tokenizer, prompt, "", num_frames=nf, height=h, width=w, fps=fps, + use_system_prompt=False, add_resolution_template=False, + add_duration_template=False) + gen = torch.Generator(device=device).manual_seed(0) + action = mpipe.generate_action( + prompt=prompt, mode="inverse_dynamics", domain_id=dom, action_chunk_size=chunk, raw_action_dim=raw, + video=video, num_frames=nf, height=h, width=w, fps=fps, action_fps=fps, + num_inference_steps=30, guidance_scale=1.0, flow_shift=10.0, generator=gen, + cond_ids=cond_ids, uncond_ids=cond_ids) + pred = action[0].float().cpu() + gold = torch.tensor(json.load(open(os.path.join(assets, "example_action_id_av_0_output.json")))["data"], + dtype=torch.float32) + mse = (pred - gold).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert mse <= 0.05 and psnr >= 14.0, f"action id MSE {mse:.5f} / PSNR {psnr:.2f} outside gate" + print(f" action id av_0: MSE = {mse:.5f}, PSNR = {psnr:.2f} dB") + + +def test_action_fd_agibotworld_golden_gate() -> None: + """Autoregressive forward-dynamics on the AgiBotWorld 4-chunk example + reproduces NVIDIA's golden video (PSNR >= 14). Each chunk takes a [16, 29] + action chunk as the clean condition; chunk 0 conditions on the first frame, + chunks 1-3 on the previous chunk's final generated frame.""" + base = _gpu_base() + if base is None: + print(" (skipped fd agibotworld golden gate: needs COSMOS3_NANO_DIR + CUDA)") + return + import json + + import torchvision + from PIL import Image + + from mstar.model.cosmos3.packing import tokenize_prompt + + device, dtype, mpipe, model, snap = ( + base["device"], base["dtype"], base["mpipe"], base["model"], base["snap"]) + assets = os.path.join(snap, "assets") + first_png = os.path.join(assets, "example_action_fd_agibotworld_first_frame.png") + chunks_json = os.path.join(assets, "example_action_fd_agibotworld_action_chunks.json") + gold_mp4 = os.path.join(assets, "example_action_fd_agibotworld_4chunk_output.mp4") + if not (os.path.exists(first_png) and os.path.exists(chunks_json) and os.path.exists(gold_mp4)): + print(" (skipped fd agibotworld golden gate: assets missing)") + return + prompt, dom, raw, chunk = "Pickup items in the supermarket", 15, 29, 16 + nf, fps = chunk + 1, 10.0 + im = Image.open(first_png).convert("RGB") + w, h = im.size + cond_frame = mpipe.video_processor.preprocess(im, height=h, width=w).to(device=device, dtype=dtype)[0] + chunks = torch.tensor(json.load(open(chunks_json))["action_chunks"], dtype=torch.float32) + cond_ids, _ = tokenize_prompt(model.tokenizer, prompt, "", num_frames=nf, height=h, width=w, fps=fps, + use_system_prompt=False, add_resolution_template=False, + add_duration_template=False) + out = [] + for k in range(chunks.shape[0]): + cond_video = cond_frame.unsqueeze(0).unsqueeze(2).expand(-1, -1, nf, -1, -1).contiguous() + gen = torch.Generator(device=device).manual_seed(k) + _, video = mpipe.generate_action( + prompt=prompt, mode="forward_dynamics", domain_id=dom, action_chunk_size=chunk, raw_action_dim=raw, + action=chunks[k], video=cond_video, num_frames=nf, height=h, width=w, fps=fps, action_fps=fps, + num_inference_steps=30, guidance_scale=1.0, flow_shift=10.0, generator=gen, + cond_ids=cond_ids, uncond_ids=cond_ids, return_video=True) + pred = video[0, :, 1:, :, :].float() + out.append(pred.cpu()) + cond_frame = (pred[:, -1].clamp(0, 1) * 2 - 1).to(device=device, dtype=dtype) + pred_video = torch.cat(out, dim=1) + g, _, _ = torchvision.io.read_video(gold_mp4, pts_unit="sec") + gold = (g.permute(3, 0, 1, 2).float() / 255.0) + n = min(pred_video.shape[1], gold.shape[1]) + mse = (pred_video[:, :n] - gold[:, :n]).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 14.0, f"fd agibotworld PSNR {psnr:.2f} < 14 (MSE {mse:.5f})" + print(f" fd agibotworld: {n} frames, PSNR = {psnr:.2f} dB") + + +@torch.no_grad() +def test_action_cross_request_batch_matches_individual() -> None: + """Several action requests denoised together in one batch reproduce each + request run alone (guidance-scale-1 inverse-dynamics, real FlashInfer cache). + Each batched action must (a) stay isolated — closer to its own bs=1 action + than to any other request's — and (b) not drift from bs=1 beyond bf16 batch- + variance. The action analogue of the image cross-request batch parity test.""" + base = _gpu_base() + if base is None: + print(" (skipped action cross-request batch parity: needs COSMOS3_NANO_DIR + CUDA)") + return + from mstar.conductor.request_info import CurrentForwardPassInfo + from mstar.model.cosmos3.packing import tokenize_prompt + from mstar.model.submodule_base import ModelInputsFromEngine + + device, dtype, dit, model = base["device"], base["dtype"], base["dit"], base["model"] + chunk, raw, dom, fps, steps, h, w = 12, 9, 1, 10.0, 6, 128, 128 + nf = chunk + 1 + prompts = [ + "You are an autonomous vehicle planning system.", + "Drive forward and keep to the center of the lane.", + "Slow down and prepare to stop at the intersection.", + ] + rids = [f"ab{i}" for i in range(len(prompts))] + seeds = [10, 20, 30] + lat_t = 1 + (nf - 1) // 4 + # A distinct conditioning clip per request so their predicted actions clearly + # differ (sharper isolation signal); the same clip is used in both runs. + cond_videos = { + rid: torch.rand((nf, 3, h, w), generator=torch.Generator().manual_seed(s), dtype=torch.float32) + for rid, s in zip(rids, seeds, strict=True) + } + + def _md(): + return {"height": h, "width": w, "num_frames": nf, "fps": fps, "action_fps": fps, + "guidance_scale": 1.0, "num_inference_steps": steps, "action_mode": "inverse_dynamics", + "action_chunk_size": chunk, "raw_action_dim": raw, "domain_id": dom, "flow_shift": 10.0} + + conds = [tokenize_prompt(model.tokenizer, p, "", num_frames=nf, height=h, width=w, fps=fps, + use_system_prompt=False, add_resolution_template=False, + add_duration_template=False)[0] for p in prompts] + + def _prefill(rid, idx, cm): + fwd = CurrentForwardPassInfo( + request_id=rid, graph_walk="prefill", requires_cfg=False, fwd_index=0, + random_seed=seeds[idx], max_tokens=0, sampling_config={}, step_metadata=_md()) + ei = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm) + ni = dit.prepare_inputs("prefill", fwd, { + "text_inputs": [torch.tensor(conds[idx], dtype=torch.long, device=device)], + "video_inputs": [cond_videos[rid].to(device)], + }) + dit.forward("prefill", ei, **dit.preprocess("prefill", ei, [ni])) + fwd.graph_walk = "action_gen" + return fwd + + def _run_one(rid, idx): + shared = _flashinfer_action_shared(model, [rid], device, dtype) + cm = _mk_action_cm(shared, [rid]) + fwd = _prefill(rid, idx, cm) + ei = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm) + lat = act = ti = None + for _ in range(steps): + inp = {} if lat is None else {"latents": [lat], "action_latents": [act], "time_index": [ti]} + ni = dit.prepare_inputs("action_gen", fwd, inp) + out = dit.forward("action_gen", ei, **dit.preprocess("action_gen", ei, [ni])) + lat, act, ti = out["latents"][0], out["action_latents"][0], out["time_index"][0] + dit.cleanup_request(rid) + return act[:, :, :raw].float().cpu() + + def _run_batched(): + shared = _flashinfer_action_shared(model, rids, device, dtype) + fwds = {} + for i, rid in enumerate(rids): + fwds[rid] = _prefill(rid, i, _mk_action_cm(shared, [rid])) + cmN = _mk_action_cm(shared, rids) + eiN = ModelInputsFromEngine(request_ids=rids, per_request_info=fwds, cache_manager=cmN) + lat = {r: None for r in rids} + act = {r: None for r in rids} + ti = {r: None for r in rids} + for _ in range(steps): + inputs = [] + for rid in rids: + inp = {} if lat[rid] is None else { + "latents": [lat[rid]], "action_latents": [act[rid]], "time_index": [ti[rid]]} + inputs.append(dit.prepare_inputs("action_gen", fwds[rid], inp)) + out = dit.forward_batched("action_gen", eiN, **dit.preprocess("action_gen", eiN, inputs)) + for rid in rids: + o = out[rid] + lat[rid], act[rid], ti[rid] = o["latents"][0], o["action_latents"][0], o["time_index"][0] + res = {rid: act[rid][:, :, :raw].float().cpu() for rid in rids} + for rid in rids: + dit.cleanup_request(rid) + return res + + try: + bs1 = {rid: _run_one(rid, i) for i, rid in enumerate(rids)} + bat = _run_batched() + except Exception as exc: # noqa: BLE001 + print(f" (skipped action cross-request batch parity: FlashInfer unavailable: {exc})") + return + + def _mse(a, b): + return (a - b).pow(2).mean().item() + + n = len(rids) + selfs, crosses = [], [] + for i, rid in enumerate(rids): + self_mse = _mse(bat[rid], bs1[rid]) + cross_mse = min(_mse(bat[rid], bs1[rids[j]]) for j in range(n) if j != i) + selfs.append(self_mse) + crosses.append(cross_mse) + assert self_mse < cross_mse, ( + f"request {i} not isolated: self {self_mse:.4e} vs nearest other {cross_mse:.4e}") + assert self_mse < 5e-3, f"request {i} batched action MSE {self_mse:.4e} drifts from bs=1" + print(" action cross-request batch (bs=%d): self MSE = %s | nearest-other = %s" % ( + n, ", ".join(f"{v:.2e}" for v in selfs), ", ".join(f"{v:.2e}" for v in crosses))) + import gc + gc.collect() + torch.cuda.empty_cache() + + +def _main() -> None: + fns = [ + ("action_mrope_matches_reference", test_action_mrope_matches_reference), + ("condition_indexes_match_reference", test_condition_indexes_match_reference), + ("action_static_layout", test_action_static_layout), + ("action_forward_shapes_and_masks", test_action_forward_shapes_and_masks), + ("action_denoise_step_matches_fused", test_action_denoise_step_matches_fused), + ("action_batched_one_matches_single", test_action_batched_one_matches_single), + ("action_engine_matches_fused", test_action_engine_matches_fused), + ("action_cross_request_batch_matches_individual", test_action_cross_request_batch_matches_individual), + ("action_id_golden_gate", test_action_id_golden_gate), + ("action_fd_agibotworld_golden_gate", test_action_fd_agibotworld_golden_gate), + ] + failures = [] + for name, fn in fns: + try: + fn() + print(f"PASS {name}") + except Exception as exc: # noqa: BLE001 + failures.append((name, exc)) + print(f"FAIL {name}: {exc!r}") + if failures: + raise SystemExit(1) + print("\nAll Cosmos3 action unit checks passed.") + + +if __name__ == "__main__": + _main() diff --git a/mstar/model/cosmos3/tests/test_engine_cache.py b/mstar/model/cosmos3/tests/test_engine_cache.py new file mode 100644 index 00000000..801d5ba6 --- /dev/null +++ b/mstar/model/cosmos3/tests/test_engine_cache.py @@ -0,0 +1,702 @@ +"""GPU parity for the cache-once engine path of the Cosmos3 generator. + +The understanding tower runs once and writes its per-layer K/V; the generation +tower then runs each denoise step re-reading that frozen K/V (the text tokens get +no timestep embedding, so their K/V is denoise-step independent — caching it once +is exact). This checks the ``Cosmos3DiTSubmodule`` prefill + denoise loop against +the fused ``Cosmos3Pipeline`` that runs the whole transformer every step, for both +image (single frame) and video (multi-frame, fps-modulated mRoPE) generation. + +Two GPU-gated checks per mode (need ``COSMOS3_NANO_DIR`` + CUDA; skipped otherwise): + * with an in-process sdpa cache (same attention kernel as the fused pipeline), + the cache-once output is bit-for-bit identical; + * with the engine's FlashInfer paged cache (the served path), the decoded output + matches the fused pipeline within PSNR >= 30 (FlashInfer-vs-sdpa precision). + +Run: COSMOS3_NANO_DIR= python3 test_engine_cache.py +""" + +from __future__ import annotations + +import math +import os + +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") +# These checks validate the eager cache-once mechanism's numerical exactness, so +# run the eager denoise step. The served default torch.compiles it, which fuses +# pointwise ops and perturbs the latents past the tight bit-exact bounds below +# without changing image quality (the FlashInfer PSNR checks still pass with +# compile on — validated over HTTP). Set COSMOS3_DISABLE_COMPILE_DENOISE= (empty) +# to exercise the compiled path here instead. +os.environ.setdefault("COSMOS3_DISABLE_COMPILE_DENOISE", "1") + +import torch +import torch.nn.functional as F + +PROMPT = "A red cube resting on a polished wooden table, soft daylight." +# Parity checks here are resolution-independent; 256x256 keeps them quick. The +# CUDA-graph check below captures at whatever (H, W) it sets. NOTE: the in-process +# graph-vs-fused PSNR is a coarse smoke check — it carries a cache-setup artifact +# of this harness. The authoritative bit-exactness gate for the served graph is +# the HTTP A/B (graph-on vs COSMOS3_DISABLE_CUDA_GRAPH=1), which is byte-identical +# at every resolution. +H = W = 256 +STEPS = 12 +GS = 6.0 +SEED = 42 +VIDEO_FRAMES = 17 # latent T = 1 + (17 - 1) // 4 = 5 + + +class _SdpaCacheHandle: + """In-process reference cache with the ``BatchedCacheManager`` surface the + DiT uses, backed by stored tensors + sdpa (same kernel as the fused pipeline). + Prefill stashes each layer's understanding K/V; every denoise step re-reads it. + + Also models the batched classifier-free-guidance plan: when both guidance + branches run in one forward, ``run_attention`` receives the two branches + concatenated and routes each half to its own label's cached prefix, so the + batched result equals running the branches sequentially. + """ + + def __init__(self): + self.active = "main" + self.layer = 0 + self.committed: dict[tuple[str, int], tuple[torch.Tensor, torch.Tensor]] = {} + self.pending: dict[tuple[str, int], tuple[torch.Tensor, torch.Tensor]] = {} + self.is_causal: dict[str, bool] = {} + self.batched_labels: list[str] | None = None + + def set_active_label(self, label): + self.active = label + + def set_layer_idx(self, i): + self.layer = i + + def plan_attention(self, seq_lens=None, dtype=None, is_causal=True, write_store=True, label=None): + self.is_causal[label or self.active] = is_causal + + def plan_attention_batched_cfg(self, labels, seq_lens, is_causal=False, write_store=False, **kwargs): + self.batched_labels = list(labels) + self.is_causal["_cfg_batched"] = is_causal + + def plan_rope(self, *args, **kwargs): + pass + + @staticmethod + def _sdpa(q, k, v, is_causal): + out = F.scaled_dot_product_attention( + q.unsqueeze(0).transpose(1, 2), k.unsqueeze(0).transpose(1, 2), + v.unsqueeze(0).transpose(1, 2), is_causal=is_causal, enable_gqa=True, + ) + return out.transpose(1, 2).squeeze(0) + + def _attend_label(self, label, layer, q, k, v, causal): + key = (label, layer) + if key in self.committed: + pk, pv = self.committed[key] + return self._sdpa(q, torch.cat([pk, k], 0), torch.cat([pv, v], 0), causal) + self.pending[key] = (k, v) + return self._sdpa(q, k, v, causal) + + def run_attention(self, q, k, v, layer_idx=None): + layer = self.layer if layer_idx is None else layer_idx + if self.active == "_cfg_batched": + causal = self.is_causal["_cfg_batched"] + n = q.shape[0] // len(self.batched_labels) + outs = [] + for bi, label in enumerate(self.batched_labels): + sl = slice(bi * n, (bi + 1) * n) + outs.append(self._attend_label(label, layer, q[sl], k[sl], v[sl], causal)) + return torch.cat(outs, 0) + return self._attend_label(self.active, layer, q, k, v, self.is_causal[self.active]) + + def advance_seq_lens(self, pos_id_ns=None): + self.committed.update(self.pending) + self.pending = {} + + +def _flashinfer_cache(model, rid, device, dtype): + from mstar.communication.tensors import LocalTransferEngine + from mstar.engine.cache_manager import BatchedCacheManager, WorkspaceBufferManager + from mstar.engine.kv_store import PagedAllocationManager, TransferEngineInfo + from mstar.model.cosmos3.submodules import COND_LABEL, UNCOND_LABEL + + cfg = model.get_kv_cache_config()[0] + cfg.max_num_pages = 64 + cfg.shard(1) + kv_cache = torch.zeros( + cfg.num_layers, cfg.max_num_pages, 2, cfg.page_size, cfg.num_kv_heads, cfg.head_dim, + dtype=dtype, device=device, + ) + alloc = PagedAllocationManager(cfg, kv_cache, TransferEngineInfo("h", "h", LocalTransferEngine("h"))) + alloc.add_request(rid, [COND_LABEL, UNCOND_LABEL]) + return BatchedCacheManager( + request_ids=[rid], active_labels_per_request={rid: COND_LABEL}, kv_cache=kv_cache, + alloc_manager=alloc, buffer_manager=WorkspaceBufferManager(256 * 1024 * 1024, device), + kv_cache_config=cfg, device=device, auto_write_store=False, + ) + + +@torch.no_grad() +def _run_cache_once(model, dit, cm, init, cond_ids, uncond_ids, device, num_frames): + from mstar.conductor.request_info import CurrentForwardPassInfo + from mstar.model.submodule_base import ModelInputsFromEngine + + rid = "r0" + md = {"height": H, "width": W, "num_frames": num_frames, "fps": 24.0, + "guidance_scale": GS, "num_inference_steps": STEPS} + fwd = CurrentForwardPassInfo( + request_id=rid, graph_walk="prefill", requires_cfg=(GS != 1.0), + fwd_index=0, random_seed=SEED, max_tokens=0, sampling_config={}, step_metadata=md, + ) + ei = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm) + text_inputs = [ + torch.tensor(cond_ids, dtype=torch.long, device=device), + torch.tensor(uncond_ids, dtype=torch.long, device=device), + ] + ni = dit.prepare_inputs("prefill", fwd, {"text_inputs": text_inputs}) + dit.forward("prefill", ei, **dit.preprocess("prefill", ei, [ni])) + + latents = init.clone() + time_index = torch.zeros(1, dtype=torch.long, device=device) + fwd.graph_walk = "image_gen" + for _ in range(STEPS): + ni = dit.prepare_inputs("image_gen", fwd, {"latents": [latents], "time_index": [time_index]}) + out = dit.forward("image_gen", ei, **dit.preprocess("image_gen", ei, [ni])) + latents, time_index = out["latents"][0], out["time_index"][0] + dit.cleanup_request(rid) + return latents + + +def _flashinfer_shared(model, rids, device, dtype): + """A KV cache + paged allocator shared by several requests, each with both + guidance labels (mirrors the engine's persistent per-node cache).""" + from mstar.communication.tensors import LocalTransferEngine + from mstar.engine.cache_manager import WorkspaceBufferManager + from mstar.engine.kv_store import PagedAllocationManager, TransferEngineInfo + from mstar.model.cosmos3.submodules import COND_LABEL, UNCOND_LABEL + + cfg = model.get_kv_cache_config()[0] + cfg.max_num_pages = 256 + cfg.shard(1) + kv_cache = torch.zeros( + cfg.num_layers, cfg.max_num_pages, 2, cfg.page_size, cfg.num_kv_heads, cfg.head_dim, + dtype=dtype, device=device, + ) + alloc = PagedAllocationManager(cfg, kv_cache, TransferEngineInfo("h", "h", LocalTransferEngine("h"))) + for rid in rids: + alloc.add_request(rid, [COND_LABEL, UNCOND_LABEL]) + buf = WorkspaceBufferManager(256 * 1024 * 1024, device) + return {"kv_cache": kv_cache, "alloc": alloc, "buf": buf, "cfg": cfg, "device": device} + + +def _mk_cm(shared, rids): + from mstar.engine.cache_manager import BatchedCacheManager + from mstar.model.cosmos3.submodules import COND_LABEL + + return BatchedCacheManager( + request_ids=rids, active_labels_per_request={r: COND_LABEL for r in rids}, + kv_cache=shared["kv_cache"], alloc_manager=shared["alloc"], buffer_manager=shared["buf"], + kv_cache_config=shared["cfg"], device=shared["device"], auto_write_store=False, + ) + + +@torch.no_grad() +def _run_batched(model, dit, shared, init, conds, unconds, device, rids): + """Prefill each request (sequential, like the engine), then run the whole + denoise loop as one batched step per iteration. Returns final latents per rid.""" + from mstar.conductor.request_info import CurrentForwardPassInfo + from mstar.model.submodule_base import ModelInputsFromEngine + + md = {"height": H, "width": W, "num_frames": 1, "fps": 24.0, + "guidance_scale": GS, "num_inference_steps": STEPS} + fwds = {} + for i, rid in enumerate(rids): + fwd = CurrentForwardPassInfo( + request_id=rid, graph_walk="prefill", requires_cfg=True, fwd_index=0, + random_seed=SEED, max_tokens=0, sampling_config={}, step_metadata=md, + ) + fwds[rid] = fwd + cm1 = _mk_cm(shared, [rid]) + ei1 = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm1) + ti = [torch.tensor(conds[i], dtype=torch.long, device=device), + torch.tensor(unconds[i], dtype=torch.long, device=device)] + ni = dit.prepare_inputs("prefill", fwd, {"text_inputs": ti}) + dit.forward("prefill", ei1, **dit.preprocess("prefill", ei1, [ni])) + + cmN = _mk_cm(shared, rids) + eiN = ModelInputsFromEngine(request_ids=rids, per_request_info=fwds, cache_manager=cmN) + for rid in rids: + fwds[rid].graph_walk = "image_gen" + latents = {rid: init.clone() for rid in rids} + time_index = {rid: torch.zeros(1, dtype=torch.long, device=device) for rid in rids} + for _ in range(STEPS): + inputs = [ + dit.prepare_inputs("image_gen", fwds[rid], + {"latents": [latents[rid]], "time_index": [time_index[rid]]}) + for rid in rids + ] + out = dit.forward_batched("image_gen", eiN, **dit.preprocess("image_gen", eiN, inputs)) + for rid in rids: + latents[rid], time_index[rid] = out[rid]["latents"][0], out[rid]["time_index"][0] + for rid in rids: + dit.cleanup_request(rid) + return latents + + +_SETUP_CACHE: dict = {} + + +def _load(): + """Load the model / DiT / fused pipeline once (mode-independent).""" + if "base" in _SETUP_CACHE: + return _SETUP_CACHE["base"] + snap = os.environ.get("COSMOS3_NANO_DIR") + if not snap or not torch.cuda.is_available(): + _SETUP_CACHE["base"] = None + return None + torch.use_deterministic_algorithms(True, warn_only=True) + from mstar.model.cosmos3.cosmos3_model import Cosmos3Model + from mstar.model.cosmos3.pipeline import Cosmos3Pipeline + + device, dtype = "cuda:0", torch.bfloat16 + model = Cosmos3Model(model_path_hf=snap) + mpipe = Cosmos3Pipeline.from_model(model, device=device, dtype=dtype) + dit = model.get_submodule("dit", device=device) # shares mpipe's transformer + _SETUP_CACHE["base"] = dict(model=model, mpipe=mpipe, dit=dit, device=device, dtype=dtype) + return _SETUP_CACHE["base"] + + +def _scenario(num_frames): + """Per-mode context: video-aware token ids, shared initial latents, and the + fused-pipeline latents the cache-once path must reproduce.""" + key = f"frames{num_frames}" + if key in _SETUP_CACHE: + return _SETUP_CACHE[key] + base = _load() + if base is None: + _SETUP_CACHE[key] = None + return None + from mstar.model.cosmos3.packing import tokenize_prompt + + device, dtype, mpipe = base["device"], base["dtype"], base["mpipe"] + cond_ids, uncond_ids = tokenize_prompt( + base["model"].tokenizer, PROMPT, "", num_frames=num_frames, height=H, width=W + ) + lat_t = 1 if num_frames == 1 else 1 + (num_frames - 1) // mpipe.vae_scale_temporal + gen = torch.Generator(device=device).manual_seed(SEED) + init = torch.randn((1, 48, lat_t, H // 16, W // 16), generator=gen, device=device, dtype=dtype) + lat_fused = mpipe( + prompt=PROMPT, negative_prompt="", num_frames=num_frames, height=H, width=W, + num_inference_steps=STEPS, guidance_scale=GS, latents=init.clone(), decode=False, + ) + ctx = dict(cond=cond_ids, uncond=uncond_ids, init=init, lat_fused=lat_fused, num_frames=num_frames, **base) + _SETUP_CACHE[key] = ctx + return ctx + + +def _check_cache_once_exact(num_frames, tag): + ctx = _scenario(num_frames) + if ctx is None: + print(f" (skipped {tag} cache-once parity: needs COSMOS3_NANO_DIR + CUDA)") + return + dit = ctx["dit"] + prev = dit.batched_cfg + # The sequential guidance path matches the fused pipeline bit-for-bit; the + # batched path differs only in bf16 GEMM rounding (covered by the PSNR checks). + dit.batched_cfg = False + try: + lat = _run_cache_once( + ctx["model"], dit, _SdpaCacheHandle(), ctx["init"], ctx["cond"], ctx["uncond"], + ctx["device"], num_frames, + ) + finally: + dit.batched_cfg = prev + diff = (ctx["lat_fused"].float() - lat.reshape(ctx["lat_fused"].shape).float()).abs().max().item() + assert diff <= 1e-3, f"{tag} cache-once latents differ from fused by {diff:.3e} (> 1e-3)" + print(f" {tag} cache-once (sdpa) latent abs-max diff = {diff:.3e}") + + +def _check_engine_psnr(num_frames, tag): + ctx = _scenario(num_frames) + if ctx is None: + print(f" (skipped {tag} engine cache parity: needs COSMOS3_NANO_DIR + CUDA)") + return + try: + cm = _flashinfer_cache(ctx["model"], "r0", ctx["device"], ctx["dtype"]) + except Exception as exc: # noqa: BLE001 + print(f" (skipped {tag} engine cache parity: FlashInfer unavailable: {exc})") + return + lat = _run_cache_once( + ctx["model"], ctx["dit"], cm, ctx["init"], ctx["cond"], ctx["uncond"], ctx["device"], num_frames, + ) + img_fused = ctx["mpipe"]._decode(ctx["lat_fused"]).squeeze().float().cpu() + img_engine = ctx["mpipe"]._decode(lat.reshape(ctx["lat_fused"].shape)).squeeze().float().cpu() + mse = (img_fused - img_engine).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 30, f"{tag} engine-path PSNR {psnr:.2f} < 30 (MSE {mse:.3e})" + print(f" {tag} engine cache path (flashinfer) PSNR = {psnr:.2f} dB") + + +@torch.no_grad() +def _check_dense_fa3(num_frames, tag): + """Dense FlashAttention-3 generation attention vs the paged FlashInfer path. + Both attend each guidance branch's generation tokens over its frozen text + prefix; they differ only in the attention kernel (FA3 over a gathered + contiguous [prefix | gen] vs FlashInfer paged) and its bf16 rounding. So the + decoded images must match closely, and the dense path must clear the same + fused-reference bar the paged path meets.""" + ctx = _scenario(num_frames) + if ctx is None: + print(f" (skipped {tag} dense-FA3 parity: needs COSMOS3_NANO_DIR + CUDA)") + return + had = os.environ.pop("COSMOS3_DENSE_FA3", None) + try: + cm = _flashinfer_cache(ctx["model"], "r0", ctx["device"], ctx["dtype"]) + lat_paged = _run_cache_once( + ctx["model"], ctx["dit"], cm, ctx["init"], ctx["cond"], ctx["uncond"], + ctx["device"], num_frames, + ) + os.environ["COSMOS3_DENSE_FA3"] = "1" + cm2 = _flashinfer_cache(ctx["model"], "r0", ctx["device"], ctx["dtype"]) + lat_dense = _run_cache_once( + ctx["model"], ctx["dit"], cm2, ctx["init"], ctx["cond"], ctx["uncond"], + ctx["device"], num_frames, + ) + except Exception as exc: # noqa: BLE001 + print(f" (skipped {tag} dense-FA3 parity: FA3/FlashInfer unavailable: {exc})") + return + finally: + if had is None: + os.environ.pop("COSMOS3_DENSE_FA3", None) + else: + os.environ["COSMOS3_DENSE_FA3"] = had + shape = ctx["lat_fused"].shape + img_fused = ctx["mpipe"]._decode(ctx["lat_fused"]).squeeze().float().cpu() + img_paged = ctx["mpipe"]._decode(lat_paged.reshape(shape)).squeeze().float().cpu() + img_dense = ctx["mpipe"]._decode(lat_dense.reshape(shape)).squeeze().float().cpu() + + def _psnr(a, b): + mse = (a - b).pow(2).mean().item() + return float("inf") if mse == 0 else -10 * math.log10(mse) + + vs_paged = _psnr(img_dense, img_paged) + vs_fused = _psnr(img_dense, img_fused) + # The dense path must match the fused reference as well as the paged engine + # path does (>= 30, the same bar), and the two engine kernels must agree to + # within their bf16 rounding (a real ordering/gather bug tanks this < 15). + assert vs_fused >= 30, f"{tag} dense-FA3 vs fused PSNR {vs_fused:.2f} < 30" + assert vs_paged >= 30, f"{tag} dense-FA3 vs paged PSNR {vs_paged:.2f} < 30" + print(f" {tag} dense-FA3 PSNR vs paged = {vs_paged:.2f} dB, vs fused = {vs_fused:.2f} dB") + + +def test_dense_fa3_image_psnr() -> None: + _check_dense_fa3(1, "t2i") + + +def test_dense_fa3_video_psnr() -> None: + _check_dense_fa3(VIDEO_FRAMES, "t2v") + + +@torch.no_grad() +def test_anchor_encode_matches_full() -> None: + """Image-to-video only consumes latent frame 0, and the Wan VAE encodes it as + a standalone causal anchor, so encoding the single conditioning frame + (anchor_only=True) must give a bit-identical frame 0 to encoding the whole + repeat-padded clip — at a fraction of the cost.""" + base = _load() + if base is None: + print(" (skipped anchor-encode parity: needs COSMOS3_NANO_DIR + CUDA)") + return + dit, device = base["dit"], base["device"] + img = torch.rand(3, H, W, device=device) # [C, H, W] in [0, 1], like load_image + anchor = dit._encode_conditioning(img, H, W, VIDEO_FRAMES, device, anchor_only=True) + full = dit._encode_conditioning(img, H, W, VIDEO_FRAMES, device, anchor_only=False) + assert anchor.shape[2] == 1, f"anchor_only must encode one latent frame, got T={anchor.shape[2]}" + diff = (anchor[:, :, 0].float() - full[:, :, 0].float()).abs().max().item() + assert diff < 1e-4, f"anchor frame-0 differs from full-clip frame-0 by {diff:.3e} (> 1e-4)" + print(f" anchor-encode 1-frame vs full-clip frame-0 abs-max diff = {diff:.3e}") + + +@torch.no_grad() +def _check_compile_vae(num_frames, tag): + """torch.compile of the Wan VAE decode (COSMOS3_COMPILE_VAE) must reproduce + the eager decode. Compile fuses the pointwise epilogues around the (fp32) 3D + convolutions without changing their math, so the decoded uint8 frames match + the eager path to fp rounding; a real fusion/ordering bug shows up as visible + banding that tanks the PSNR. Checked for both a single image frame and a + multi-frame video clip (video is the lever's main beneficiary).""" + ctx = _scenario(num_frames) + if ctx is None: + print(f" (skipped {tag} compile-VAE parity: needs COSMOS3_NANO_DIR + CUDA)") + return + from mstar.model.cosmos3.submodules import Cosmos3VAEDecoderSubmodule + + model, lat = ctx["model"], ctx["lat_fused"] + vae, config = model._build_vae(ctx["device"]), model.config + walk = "video_gen" if num_frames > 1 else "image_gen" + out_key = "video_output" if num_frames > 1 else "image_output" + had = os.environ.pop("COSMOS3_COMPILE_VAE", None) + try: + eager = Cosmos3VAEDecoderSubmodule(vae=vae, config=config) + img_eager = eager.forward(walk, None, latents=lat.clone())[out_key][0] + os.environ["COSMOS3_COMPILE_VAE"] = "1" + compiled = Cosmos3VAEDecoderSubmodule(vae=vae, config=config) + img_comp = compiled.forward(walk, None, latents=lat.clone())[out_key][0] + except Exception as exc: # noqa: BLE001 + print(f" (skipped {tag} compile-VAE parity: VAE/compile unavailable: {exc})") + return + finally: + if had is None: + os.environ.pop("COSMOS3_COMPILE_VAE", None) + else: + os.environ["COSMOS3_COMPILE_VAE"] = had + a = img_eager.float().cpu() / 255.0 + b = img_comp.float().cpu() / 255.0 + maxdiff = (a - b).abs().max().item() * 255.0 + mse = (a - b).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 40, f"{tag} compile-VAE vs eager PSNR {psnr:.2f} < 40 (max uint8 diff {maxdiff:.0f})" + print(f" {tag} compile-VAE vs eager decoded PSNR = {psnr:.2f} dB (max uint8 diff {maxdiff:.0f})") + + +def test_compile_vae_matches_eager() -> None: + _check_compile_vae(1, "t2i") + + +def test_compile_vae_matches_eager_t2v() -> None: + _check_compile_vae(VIDEO_FRAMES, "t2v") + + +@torch.no_grad() +def test_batched_cfg_matches_sequential() -> None: + """Running both guidance branches in one batched forward must match running + them sequentially. The two paths differ only in bf16 GEMM rounding (a batched + matmul tiles differently), so compare the decoded images by PSNR.""" + ctx = _scenario(1) + if ctx is None: + print(" (skipped batched-CFG vs sequential: needs COSMOS3_NANO_DIR + CUDA)") + return + dit, prev, decoded = ctx["dit"], ctx["dit"].batched_cfg, {} + try: + for flag in (False, True): + dit.batched_cfg = flag + try: + cm = _flashinfer_cache(ctx["model"], "r0", ctx["device"], ctx["dtype"]) + except Exception as exc: # noqa: BLE001 + print(f" (skipped batched-CFG vs sequential: FlashInfer unavailable: {exc})") + return + lat = _run_cache_once( + ctx["model"], dit, cm, ctx["init"], ctx["cond"], ctx["uncond"], ctx["device"], 1 + ) + decoded[flag] = ctx["mpipe"]._decode(lat.reshape(ctx["lat_fused"].shape)).squeeze().float().cpu() + finally: + dit.batched_cfg = prev + mse = (decoded[False] - decoded[True]).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 35, f"batched vs sequential PSNR {psnr:.2f} < 35 (MSE {mse:.3e})" + print(f" batched-CFG vs sequential decoded PSNR = {psnr:.2f} dB") + + +def test_cache_once_matches_fused_exact() -> None: + _check_cache_once_exact(1, "t2i") + + +def test_engine_cache_path_image_psnr() -> None: + _check_engine_psnr(1, "t2i") + + +def test_cache_once_matches_fused_exact_t2v() -> None: + _check_cache_once_exact(VIDEO_FRAMES, "t2v") + + +def test_engine_cache_path_video_psnr() -> None: + _check_engine_psnr(VIDEO_FRAMES, "t2v") + + +@torch.no_grad() +def test_cross_request_batch_matches_individual() -> None: + """Several requests denoised together in one batch must reproduce each + request run alone. Distinct prompts are decoded and compared to the fused + pipeline: batching must (a) keep each request isolated — its own image far + closer than any other request's — and (b) not lose quality versus the bs=1 + path (per-prompt fidelity varies with the FlashInfer kernel, so the bar is + relative to bs=1, not an absolute PSNR).""" + base = _load() + if base is None: + print(" (skipped cross-request batch parity: needs COSMOS3_NANO_DIR + CUDA)") + return + from mstar.model.cosmos3.packing import tokenize_prompt + + model, dit, mpipe = base["model"], base["dit"], base["mpipe"] + device, dtype = base["device"], base["dtype"] + prompts = [ + "A red cube resting on a polished wooden table, soft daylight.", + "A blue ceramic vase of yellow tulips beside a sunny window.", + "A small wooden sailboat on a calm turquoise sea at dawn.", + "A snowy mountain peak under a clear starry night sky.", + ] + rids = [f"r{i}" for i in range(len(prompts))] + conds, unconds = [], [] + for p in prompts: + c, u = tokenize_prompt(model.tokenizer, p, "", num_frames=1, height=H, width=W) + conds.append(c) + unconds.append(u) + gen = torch.Generator(device=device).manual_seed(SEED) + init = torch.randn((1, 48, 1, H // 16, W // 16), generator=gen, device=device, dtype=dtype) + shape = (1, 48, 1, H // 16, W // 16) + + def _dec(lat): + return mpipe._decode(lat.reshape(shape)).squeeze().float().cpu() + + def _psnr(a, b): + mse = (a - b).pow(2).mean().item() + return float("inf") if mse == 0 else -10 * math.log10(mse) + + try: + fused = [ + _dec(mpipe(prompt=p, negative_prompt="", num_frames=1, height=H, width=W, + num_inference_steps=STEPS, guidance_scale=GS, latents=init.clone(), decode=False)) + for p in prompts + ] + bs1 = [] + for i, rid in enumerate(rids): + cm = _flashinfer_cache(model, "r0", device, dtype) + bs1.append(_dec(_run_cache_once(model, dit, cm, init, conds[i], unconds[i], device, 1))) + except Exception as exc: # noqa: BLE001 + print(f" (skipped cross-request batch parity: FlashInfer unavailable: {exc})") + return + + shared = _flashinfer_shared(model, rids, device, dtype) + bat = _run_batched(model, dit, shared, init, conds, unconds, device, rids) + batched = [_dec(bat[rid]) for rid in rids] + + n = len(prompts) + for i in range(n): + match = _psnr(batched[i], fused[i]) + cross = max(_psnr(batched[i], fused[j]) for j in range(n) if j != i) + ref = _psnr(bs1[i], fused[i]) + assert match > cross + 8, f"request {i} not isolated: self {match:.2f} vs other {cross:.2f}" + assert match >= ref - 3.0, f"request {i} batched {match:.2f} degrades vs bs=1 {ref:.2f}" + print(f" cross-request batch (bs={n}) vs fused PSNR = " + + ", ".join(f"{_psnr(batched[i], fused[i]):.1f}" for i in range(n)) + + " dB (bs=1: " + ", ".join(f"{_psnr(bs1[i], fused[i]):.1f}" for i in range(n)) + ")") + # This test holds several requests' caches at once; release them so later + # GPU checks in the same process aren't starved. + del fused, bs1, batched, bat, shared + import gc + gc.collect() + torch.cuda.empty_cache() + + +@torch.no_grad() +def _run_cuda_graph_denoise(ctx): + """Capture the image denoise step and run the whole loop through the real + CudaGraphRunner (one captured forward per step covering both guidance + branches), returning the final latents.""" + from mstar.conductor.request_info import CurrentForwardPassInfo + from mstar.distributed.communication import TPCommGroup + from mstar.engine.cuda_graph_runner import CudaGraphRunner + from mstar.model.submodule_base import ModelInputsFromEngine + from mstar.utils.sampling import Sampler, SamplingConfig + + model, dit = ctx["model"], ctx["dit"] + device, dtype = ctx["device"], ctx["dtype"] + dev = torch.device(device) + # Capture at this test's (H, W) regardless of the production default. + dit.gen_capture_resolutions = ((H, W),) + rid = "cgr0" + shared = _flashinfer_shared(model, [rid], device, dtype) + md = {"height": H, "width": W, "num_frames": 1, "fps": 24.0, + "guidance_scale": GS, "num_inference_steps": STEPS} + fwd = CurrentForwardPassInfo( + request_id=rid, graph_walk="prefill", requires_cfg=False, fwd_index=0, + random_seed=SEED, max_tokens=0, sampling_config={}, step_metadata=md, + ) + cm = _mk_cm(shared, [rid]) + ei = ModelInputsFromEngine(request_ids=[rid], per_request_info={rid: fwd}, cache_manager=cm) + ti = [torch.tensor(ctx["cond"], dtype=torch.long, device=device), + torch.tensor(ctx["uncond"], dtype=torch.long, device=device)] + ni = dit.prepare_inputs("prefill", fwd, {"text_inputs": ti}) + dit.forward("prefill", ei, **dit.preprocess("prefill", ei, [ni])) + + runner = CudaGraphRunner( + submodule_name="dit", submodule=dit, kv_cache_config=shared["cfg"], + alloc_manager=shared["alloc"], sampler=Sampler(device=dev, tp_group=TPCommGroup.trivial()), + buffer_manager=shared["buf"], device=dev, autocast_dtype=dtype, + default_sampling_config=SamplingConfig(), tp_group=TPCommGroup.trivial(), + ) + runner.warmup_and_capture() + assert runner.graphs, "no CUDA graph captured for cosmos3 image_gen" + runner.register_request(rid) + + fwd.graph_walk = "image_gen" + latents = ctx["init"].clone() + time_index = torch.zeros(1, dtype=torch.long, device=device) + for _ in range(STEPS): + ni = dit.prepare_inputs("image_gen", fwd, {"latents": [latents], "time_index": [time_index]}) + out = runner.run( + graph_walk="image_gen", requires_cfg=False, request_ids=[rid], + inputs=[ni], per_request_info={rid: fwd}, submodule=dit, + ) + latents, time_index = out[rid]["latents"][0], out[rid]["time_index"][0] + dit.cleanup_request(rid) + return latents + + +@torch.no_grad() +def test_cuda_graph_matches_eager() -> None: + """The captured-graph denoise step is the served path's accelerator: both + guidance branches run in one captured forward (~2x faster than the eager + step). Each captured forward matches eager to within bf16 (the first step + differs by ~one ULP); the multistep solver amplifies that into a small latent + spread, but the decoded image is unchanged — so gate the decoded image against + the fused pipeline, the same bar the eager engine path meets.""" + ctx = _scenario(1) + if ctx is None: + print(" (skipped cuda-graph parity: needs COSMOS3_NANO_DIR + CUDA)") + return + try: + lat_graph = _run_cuda_graph_denoise(ctx) + except Exception as exc: # noqa: BLE001 + print(f" (skipped cuda-graph parity: FlashInfer/capture unavailable: {exc})") + return + img_fused = ctx["mpipe"]._decode(ctx["lat_fused"]).squeeze().float().cpu() + img_graph = ctx["mpipe"]._decode(lat_graph.reshape(ctx["lat_fused"].shape)).squeeze().float().cpu() + mse = (img_fused - img_graph).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 25, f"cuda-graph denoise PSNR {psnr:.2f} < 25 (MSE {mse:.3e})" + print(f" cuda-graph denoise vs fused PSNR = {psnr:.2f} dB") + + +def _main() -> None: + failures = [] + for name, fn in [ + ("batched_cfg_matches_sequential", test_batched_cfg_matches_sequential), + ("cache_once_matches_fused_exact", test_cache_once_matches_fused_exact), + ("engine_cache_path_image_psnr", test_engine_cache_path_image_psnr), + ("cache_once_matches_fused_exact_t2v", test_cache_once_matches_fused_exact_t2v), + ("engine_cache_path_video_psnr", test_engine_cache_path_video_psnr), + ("dense_fa3_image_psnr", test_dense_fa3_image_psnr), + ("dense_fa3_video_psnr", test_dense_fa3_video_psnr), + ("anchor_encode_matches_full", test_anchor_encode_matches_full), + ("compile_vae_matches_eager", test_compile_vae_matches_eager), + ("compile_vae_matches_eager_t2v", test_compile_vae_matches_eager_t2v), + ("cuda_graph_matches_eager", test_cuda_graph_matches_eager), + ("cross_request_batch_matches_individual", test_cross_request_batch_matches_individual), + ]: + try: + fn() + print(f"PASS {name}") + except Exception as exc: # noqa: BLE001 + failures.append((name, exc)) + print(f"FAIL {name}: {exc!r}") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if failures: + raise SystemExit(1) + print("\nAll Cosmos3 engine-cache checks passed.") + + +if __name__ == "__main__": + _main() diff --git a/mstar/model/cosmos3/tests/test_loader.py b/mstar/model/cosmos3/tests/test_loader.py new file mode 100644 index 00000000..c990e2b0 --- /dev/null +++ b/mstar/model/cosmos3/tests/test_loader.py @@ -0,0 +1,170 @@ +"""CPU-only structural checks for the Cosmos3 model package. + +No GPU and no model weights are required: the config is parsed from the +checkpoint's JSON files, the backbone is built on the ``meta`` device (shapes +only, zero storage), and weight-key coverage is checked against the shard +index. Run directly (``python3 test_loader.py``) or via pytest. + +Point ``COSMOS3_NANO_DIR`` at a Cosmos3-Nano checkpoint directory (config + +tokenizer + shard index; the safetensors tensor data itself is not read). +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import torch + +from mstar.model.cosmos3.components.transformer import Cosmos3OmniTransformer +from mstar.model.cosmos3.config import Cosmos3Config +from mstar.model.cosmos3.loader import ( + DROP_KEYS, + cosmos3_name_remapper, + read_transformer_weight_keys, + read_transformer_weight_shapes, +) + +NANO_DIR = Path( + os.environ.get( + "COSMOS3_NANO_DIR", + "/Users/atindrajha/Downloads/disaggregation_research/Cosmos3-Nano-hf", + ) +) + + +def test_config_roundtrip() -> None: + cfg = Cosmos3Config.from_pretrained(NANO_DIR) + + # Transformer dimensions (Nano). + assert cfg.num_hidden_layers == 36 + assert cfg.hidden_size == 4096 + assert cfg.num_attention_heads == 32 + assert cfg.num_key_value_heads == 8 + assert cfg.head_dim == 128 + assert cfg.intermediate_size == 12288 + assert cfg.vocab_size == 151936 + assert cfg.rms_norm_eps == 1e-6 + + # 3D interleaved mRoPE. + assert tuple(cfg.rope_axes_dim) == (24, 20, 20) + assert cfg.mrope_interleaved is True + assert cfg.rope_theta == 5_000_000.0 + assert cfg.unified_3d_mrope_temporal_modality_margin == 15000 + assert cfg.unified_3d_mrope_reset_spatial_ids is True + assert cfg.base_fps == 24 and cfg.enable_fps_modulation is True + + # Latent geometry / attention style. + assert cfg.latent_channel == 48 + assert cfg.latent_patch_size == 2 + assert cfg.patch_latent_dim == 192 + assert cfg.timestep_scale == 0.001 + assert cfg.joint_attn_implementation == "two_way" + assert cfg.use_moe is True + assert cfg.qk_norm_for_diffusion is True and cfg.qk_norm_for_text is True + + # Capability flags / modality heads. + assert cfg.action_gen is True and cfg.max_action_dim == 64 + assert cfg.num_embodiment_domains == 32 + assert cfg.sound_gen is True and cfg.sound_dim == 64 + + # VAE (AutoencoderKLWan) geometry + normalization stats. + assert cfg.vae.z_dim == 48 + assert cfg.vae.scale_factor_spatial == 16 + assert cfg.vae.scale_factor_temporal == 4 + assert len(cfg.vae.latents_mean) == 48 + assert len(cfg.vae.latents_std) == 48 + + # UniPC flow scheduler. + assert cfg.scheduler.scheduler_type == "unipc" + assert cfg.scheduler.prediction_type == "flow_prediction" + assert cfg.scheduler.predict_x0 is True + assert cfg.scheduler.solver_order == 2 + assert cfg.scheduler.solver_type == "bh2" + assert cfg.scheduler.use_flow_sigmas is True + assert cfg.scheduler.use_karras_sigmas is True + + +def test_loader_key_coverage() -> None: + cfg = Cosmos3Config.from_pretrained(NANO_DIR) + with torch.device("meta"): + model = Cosmos3OmniTransformer(cfg) + + model_keys = set(model.state_dict().keys()) + index_keys = read_transformer_weight_keys(NANO_DIR) + + # The only intentionally-dropped key is the unused text lm_head. + dropped = {k for k in index_keys if cosmos3_name_remapper(k) is None} + assert dropped == set(DROP_KEYS), dropped + + mapped = {cosmos3_name_remapper(k) for k in index_keys} + mapped.discard(None) + + missing = model_keys - mapped # backbone params with no checkpoint key + unexpected = mapped - model_keys # checkpoint keys with no backbone param + assert not missing, f"backbone params not covered by checkpoint: {sorted(missing)[:20]}" + assert not unexpected, f"checkpoint keys with no backbone param: {sorted(unexpected)[:20]}" + + # Sanity on the exact counts: 36 layers * 22 + 22 non-layer == 814; drop lm_head -> 813. + assert len(index_keys) == 814, len(index_keys) + assert len(model_keys) == 813, len(model_keys) + + +def test_loader_shape_coverage() -> None: + """Every backbone param's *shape* matches the checkpoint tensor it loads + from. Reads only safetensors headers (no tensor data, CPU-safe). Returns + early if the shards are LFS pointers (asset-only clone) rather than real + weights. Complements the name-only coverage check — it is what would have + caught a wrong per-domain action-projection shape before a GPU load. + """ + cfg = Cosmos3Config.from_pretrained(NANO_DIR) + with torch.device("meta"): + model = Cosmos3OmniTransformer(cfg) + + try: + ckpt_shapes = read_transformer_weight_shapes(NANO_DIR) + except Exception as exc: # noqa: BLE001 — LFS pointer / missing shards + print(f" (shape check skipped: transformer shards unreadable: {exc})") + return + + model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()} + # The remapper is identity for backbone keys, so model key == checkpoint key. + mismatched = { + k: {"model": s, "ckpt": ckpt_shapes.get(k)} + for k, s in model_shapes.items() + if s != ckpt_shapes.get(k) + } + assert not mismatched, mismatched + + +def test_tokenizer_roundtrip() -> None: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(str(NANO_DIR / "text_tokenizer")) + prompt = "A red cube resting on a polished wooden table, soft daylight." + ids = tok(prompt, add_special_tokens=False)["input_ids"] + assert len(ids) > 0 + assert tok.decode(ids) == prompt + + +def _main() -> None: + failures = [] + for name, fn in [ + ("config_roundtrip", test_config_roundtrip), + ("loader_key_coverage", test_loader_key_coverage), + ("loader_shape_coverage", test_loader_shape_coverage), + ("tokenizer_roundtrip", test_tokenizer_roundtrip), + ]: + try: + fn() + print(f"PASS {name}") + except Exception as exc: # noqa: BLE001 + failures.append((name, exc)) + print(f"FAIL {name}: {exc!r}") + if failures: + raise SystemExit(1) + print("\nAll Cosmos3 structural checks passed.") + + +if __name__ == "__main__": + _main() diff --git a/mstar/model/cosmos3/tests/test_serving.py b/mstar/model/cosmos3/tests/test_serving.py new file mode 100644 index 00000000..c4704380 --- /dev/null +++ b/mstar/model/cosmos3/tests/test_serving.py @@ -0,0 +1,202 @@ +"""CPU-only checks for the Cosmos3 OpenAI-serving entry points. + +Covers the request -> model wiring that the engine relies on: prompt +tokenization into a conditional + unconditional pair, generation-parameter +resolution + step-metadata threading, and the OpenAI image adapter. No GPU and +no model weights are required. The prompt-tokenization check needs a real +tokenizer, so point ``COSMOS3_NANO_DIR`` at a Cosmos3-Nano directory to run it +(it is skipped otherwise). +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from mstar.model.cosmos3.cosmos3_model import Cosmos3Model + +NANO_DIR = Path( + os.environ.get( + "COSMOS3_NANO_DIR", + "/Users/atindrajha/Downloads/disaggregation_research/Cosmos3-Nano-hf", + ) +) + + +def test_adapter_registered_for_images() -> None: + from mstar.api_server.openai.adapters import get_adapter + + adapter = get_adapter("cosmos3") + assert adapter is not None + assert adapter.supports_images + + class _Req: + prompt = "a red cube" + size = "512x512" + seed = 7 + + def __init__(self): + self.model_extra = {"guidance_scale": 4.0} + + args = adapter.image_to_request(_Req(), upload_dir="/tmp") + assert args.text == "a red cube" + assert args.output_modalities == ["image"] + assert args.model_kwargs["size"] == "512x512" + assert args.model_kwargs["seed"] == 7 + assert args.model_kwargs["guidance_scale"] == 4.0 + + +def test_video_adapter_t2v_and_i2v(tmp_path) -> None: + from mstar.api_server.openai.adapters import get_adapter + from mstar.api_server.openai.protocol import VideoGenerationRequest + + adapter = get_adapter("cosmos3") + assert adapter is not None and adapter.supports_videos + + # text-to-video: text-only input, video output, num_frames/fps threaded. + req = VideoGenerationRequest( + prompt="a kite", size="256x256", seed=1, num_frames=17, fps=16.0, + guidance_scale=6.0, + ) + args = adapter.video_to_request(req, upload_dir=str(tmp_path)) + assert args.text == "a kite" + assert args.input_modalities == ["text"] + assert args.output_modalities == ["video"] + assert args.file_paths is None + assert args.model_kwargs["num_frames"] == 17 + assert args.model_kwargs["fps"] == 16.0 + assert args.model_kwargs["guidance_scale"] == 6.0 + + # image-to-video: the conditioning image (data URI) is persisted and routed + # in as an image input; the worker VAE-encodes it into the frame-0 anchor. + i2v = adapter.video_to_request( + VideoGenerationRequest(prompt="zoom in", image="data:image/png;base64,AAAA"), + upload_dir=str(tmp_path), + ) + assert i2v.input_modalities == ["image", "text"] + assert i2v.output_modalities == ["video"] + assert i2v.file_paths and i2v.file_paths["image"] + + +def test_gen_params_and_step_metadata() -> None: + model = Cosmos3Model(model_path_hf="unused", skip_weight_loading=True) + + # "size" parses to width/height; explicit width/height win; defaults applied. + p = model._resolve_gen_params({"size": "480x256"}, ["text"], ["image"]) + assert (p["width"], p["height"]) == (480, 256) + assert p["num_frames"] == 1 and p["has_image_condition"] is False + + # The denoise loop stops per-request (check_stop), so a per-request + # num_inference_steps is honored, clamped to [1, max_inference_steps]; + # guidance_scale is likewise per request. + p = model._resolve_gen_params( + {"num_inference_steps": 3, "guidance_scale": 2.5}, ["text"], ["image"] + ) + assert p["num_inference_steps"] == 3 + assert p["guidance_scale"] == 2.5 + # A request above the loop's upper bound is clamped; the image/video defaults + # differ by mode. + assert model._resolve_gen_params( + {"num_inference_steps": 10_000}, ["text"], ["image"] + )["num_inference_steps"] == model.config.max_inference_steps + assert model._resolve_gen_params({}, ["text"], ["image"])[ + "num_inference_steps" + ] == model.config.num_inference_steps + assert model._resolve_gen_params({"num_frames": 17}, ["text"], ["video"])[ + "num_inference_steps" + ] == model.config.num_inference_steps_video + + # i2v conditioning is inferred from the input modalities. + p = model._resolve_gen_params({}, ["image", "text"], ["image"]) + assert p["has_image_condition"] is True + + fpa = model.get_initial_forward_pass_args( + "p0", ["text"], ["image"], {"text_inputs": []}, + model_kwargs={"size": "256x256", "num_inference_steps": 7}, + ) + sm = fpa.step_metadata + assert sm["is_prefill"] is True + assert sm["height"] == 256 and sm["width"] == 256 + assert sm["num_inference_steps"] == 7 + + +def test_dynamic_loop_check_stop_and_wasted_step() -> None: + """The denoise loop stops at each request's own step count, and a step + dispatched one past that count is a no-op — so the loop's single speculative + extra iteration can't index the scheduler out of range.""" + import types + + import torch + + from mstar.model.cosmos3.submodules import ( + ACTION_GEN_LOOP, + ACTION_GEN_WALK, + Cosmos3DiTSubmodule, + IMAGE_GEN_LOOP, + IMAGE_GEN_WALK, + ) + + dit = Cosmos3DiTSubmodule(transformer=None, config=Cosmos3Model( + model_path_hf="unused", skip_weight_loading=True).config, scheduler=None) + + class _Sched: + def __init__(self, n): + self.timesteps = list(range(n)) + + n = 4 + dit._req["r"] = {"scheduler": _Sched(n), "raw_action_dim": 2} + + def info(walk, it): + return types.SimpleNamespace( + graph_walk=walk, + dynamic_loop_iter_counts={IMAGE_GEN_LOOP: it, ACTION_GEN_LOOP: it}, + ) + + # Stops only on the last real step (iter n-1), not before; routes by walk. + assert dit.check_stop("r", info(IMAGE_GEN_WALK, n - 2), {}) == set() + assert dit.check_stop("r", info(IMAGE_GEN_WALK, n - 1), {}) == {IMAGE_GEN_LOOP} + assert dit.check_stop("r", info(ACTION_GEN_WALK, n - 1), {}) == {ACTION_GEN_LOOP} + # Unknown request -> no stop. + assert dit.check_stop("missing", info(IMAGE_GEN_WALK, 0), {}) == set() + + # A forward one past the step count returns its inputs unchanged without + # touching the transformer or cache manager (both None here). + lat = torch.zeros(1, 4, 1, 2, 2) + ti = torch.tensor([n]) + out = dit._forward_image_gen(None, dit._req["r"], latents=lat, time_index=ti) + assert torch.equal(out["latents"][0], lat) and torch.equal(out["time_index"][0], ti) + + act = torch.zeros(1, 3, 5) + out = dit._forward_action_gen( + None, dit._req["r"], latents=lat, action_latents=act, time_index=ti + ) + assert torch.equal(out["latents"][0], lat) + # The action latents (the looped self-edge the loop emits on finish) pass + # through unchanged on the discarded extra step. + assert torch.equal(out["action_latents"][0], act) + + +@pytest.mark.skipif(not NANO_DIR.exists(), reason="set COSMOS3_NANO_DIR to a Cosmos3-Nano dir") +def test_process_prompt_emits_cond_and_uncond() -> None: + model = Cosmos3Model(model_path_hf=str(NANO_DIR)) + assert model.tokenizer is not None + sog = model.tokenizer.convert_tokens_to_ids("<|vision_start|>") + eos = model.tokenizer.eos_token_id + + out = model.process_prompt("a red cube on a table", ["text"], ["image"], tensors={}, size="256x256") + ti = out["text_inputs"] + assert len(ti) == 2, "t2i must emit a conditional and unconditional prompt" + cond, uncond = ti[0].tolist(), ti[1].tolist() + assert cond[-2:] == [eos, sog] + assert uncond[-2:] == [eos, sog] + assert cond != uncond + + +if __name__ == "__main__": + test_adapter_registered_for_images() + test_gen_params_and_step_metadata() + if NANO_DIR.exists(): + test_process_prompt_emits_cond_and_uncond() + print("PASS") diff --git a/mstar/model/cosmos3/tests/test_t2i.py b/mstar/model/cosmos3/tests/test_t2i.py new file mode 100644 index 00000000..969bb360 --- /dev/null +++ b/mstar/model/cosmos3/tests/test_t2i.py @@ -0,0 +1,195 @@ +"""Tests for the Cosmos3 t2i forward + packing. + +CPU-safe unit tests (tiny config) cover patchify/unpatchify, the 3D mRoPE id +helpers, the t2i packing assembly, and a full forward smoke test. An optional +GPU integration test (gated on ``COSMOS3_NANO_DIR`` + CUDA + diffusers) checks +the t2i image against the diffusers ``Cosmos3OmniPipeline``. + +Run CPU only: python3 test_t2i.py +Run with GPU: COSMOS3_NANO_DIR= python3 test_t2i.py +""" + +from __future__ import annotations + +import math +import os +from pathlib import Path + +import torch + +from mstar.model.cosmos3.components.transformer import Cosmos3OmniTransformer +from mstar.model.cosmos3.config import Cosmos3Config +from mstar.model.cosmos3.packing import ( + build_t2i_static_inputs, + get_3d_mrope_ids_text_tokens, + get_3d_mrope_ids_vae_tokens, +) + + +def _tiny_config() -> Cosmos3Config: + """A small, CPU-cheap Cosmos3 config with the same structure as Nano. + + head_dim // 2 == sum(rope_axes_dim) is required by the interleaved mRoPE; + patch_latent_dim == latent_patch_size**2 * latent_channel. + """ + return Cosmos3Config( + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + intermediate_size=128, + vocab_size=100, + rope_axes_dim=(4, 2, 2), + latent_channel=8, + latent_patch_size=2, + patch_latent_dim=32, + sound_gen=False, + action_gen=False, + ) + + +def test_patchify_unpatchify_roundtrip() -> None: + cfg = _tiny_config() + model = Cosmos3OmniTransformer(cfg) + p = cfg.latent_patch_size + x = torch.randn(1, cfg.latent_channel, 1, 4 * p, 3 * p) # [1,C,T=1,H,W], H/W divisible by p + packed, orig_shapes = model._patchify_and_pack_latents([x]) + assert packed.shape == (1 * 4 * 3, cfg.patch_latent_dim), packed.shape + assert orig_shapes == [(1, 4 * p, 3 * p)], orig_shapes + # All-noisy single frame -> unpatchify recovers x exactly. + token_shapes = [(1, 4, 3)] + recovered = model._unpatchify_and_unpack_latents( + packed, token_shapes, [torch.arange(1)], orig_shapes + )[0] + assert recovered.shape == x.shape + assert torch.allclose(recovered, x, atol=1e-6), (recovered - x).abs().max() + + +def test_mrope_ids_text() -> None: + ids, nxt = get_3d_mrope_ids_text_tokens(num_tokens=5, temporal_offset=3) + assert ids.shape == (3, 5) + assert torch.equal(ids[0], ids[1]) and torch.equal(ids[1], ids[2]) + assert ids[0].tolist() == [3, 4, 5, 6, 7] + assert nxt == 8 + + +def test_mrope_ids_vae() -> None: + # t2i: grid_t=1 -> no fps modulation; spatial reset keeps h/w as plain grids. + ids, _ = get_3d_mrope_ids_vae_tokens(grid_t=1, grid_h=2, grid_w=3, temporal_offset=10) + assert ids.shape == (3, 6) + assert ids[0].tolist() == [10] * 6 # all temporal positions == offset + assert ids[1].tolist() == [0, 0, 0, 1, 1, 1] # h grid + assert ids[2].tolist() == [0, 1, 2, 0, 1, 2] # w grid + + +def test_packing_t2i_structure() -> None: + cfg = Cosmos3Config() # Nano defaults + input_ids = list(range(7)) + latent_shape = (1, cfg.latent_channel, 1, 16, 16) + out = build_t2i_static_inputs(input_ids, latent_shape, cfg, vae_scale_factor_temporal=4, fps=24.0, device="cpu") + num_vision = 1 * 8 * 8 # patch grid 8x8 + assert out["und_len"] == 7 + assert out["sequence_length"] == 7 + num_vision + assert out["position_ids"].shape == (3, 7 + num_vision) + assert out["vision_sequence_indexes"].tolist() == list(range(7, 7 + num_vision)) + assert out["vision_token_shapes"] == [(1, 8, 8)] + # Vision temporal positions sit past the text + 15000 margin. + assert int(out["position_ids"][0, 7].item()) == 7 + cfg.unified_3d_mrope_temporal_modality_margin + + +def test_forward_smoke_cpu() -> None: + cfg = _tiny_config() + torch.manual_seed(0) + model = Cosmos3OmniTransformer(cfg).eval() + latent_shape = (1, cfg.latent_channel, 1, 4, 4) # patch grid 2x2 -> 4 vision tokens + static = build_t2i_static_inputs( + [1, 2, 3], latent_shape, cfg, vae_scale_factor_temporal=4, fps=24.0, device="cpu" + ) + fields = [ + "input_ids", "text_indexes", "position_ids", "und_len", "sequence_length", + "vision_token_shapes", "vision_sequence_indexes", "vision_mse_loss_indexes", + "vision_noisy_frame_indexes", + ] + with torch.no_grad(): + preds, sound = model( + vision_tokens=[torch.randn(latent_shape)], + vision_timesteps=torch.full((static["num_noisy_vision_tokens"],), 500.0), + **{k: static[k] for k in fields}, + ) + assert sound is None + assert preds[0].shape == latent_shape, preds[0].shape + assert torch.isfinite(preds[0]).all() + + +def test_t2i_parity_vs_diffusers() -> None: + """GPU integration: mstar DiT swapped into the diffusers pipeline yields a + bit-exact t2i image (deterministic cuBLAS). Skipped without GPU/checkpoint.""" + snap = os.environ.get("COSMOS3_NANO_DIR") + if not snap or not torch.cuda.is_available(): + print(" (skipped t2i parity: needs COSMOS3_NANO_DIR + CUDA)") + return + try: + from diffusers import AutoencoderKLWan, UniPCMultistepScheduler + from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer as DTr + from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline + from transformers import AutoTokenizer + except Exception as exc: # noqa: BLE001 + print(f" (skipped t2i parity: diffusers/transformers unavailable: {exc})") + return + + os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + torch.use_deterministic_algorithms(True, warn_only=True) + from mstar.model.cosmos3.cosmos3_model import Cosmos3Model + + snap_p = Path(snap) + dev, dtype = "cuda:0", torch.bfloat16 + pipe = Cosmos3OmniPipeline( + transformer=DTr.from_pretrained(snap_p, subfolder="transformer", torch_dtype=dtype), + text_tokenizer=AutoTokenizer.from_pretrained(str(snap_p / "text_tokenizer")), + vae=AutoencoderKLWan.from_pretrained(snap_p, subfolder="vae", torch_dtype=dtype), + scheduler=UniPCMultistepScheduler.from_pretrained(snap_p, subfolder="scheduler"), + sound_tokenizer=None, enable_safety_checker=False, + ).to(dev) + + def gen(): + return pipe(prompt="A red cube on a wooden table.", negative_prompt="", num_frames=1, + height=256, width=256, num_inference_steps=4, guidance_scale=6.0, + generator=torch.Generator(device=dev).manual_seed(0), + output_type="pt", enable_safety_check=False).video[0].float().cpu() + + img_d = gen() + mtr = Cosmos3Model(model_path_hf=snap).get_submodule("dit", device=dev).transformer + mtr.dtype = dtype + pipe.transformer = mtr + img_m = gen() + mse = (img_d - img_m).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 30, f"t2i image PSNR {psnr:.2f} < 30 (MSE {mse:.3e})" + print(f" t2i parity PSNR={psnr:.2f} dB") + + +def _main() -> None: + failures = [] + tests = [ + ("patchify_unpatchify_roundtrip", test_patchify_unpatchify_roundtrip), + ("mrope_ids_text", test_mrope_ids_text), + ("mrope_ids_vae", test_mrope_ids_vae), + ("packing_t2i_structure", test_packing_t2i_structure), + ("forward_smoke_cpu", test_forward_smoke_cpu), + ("t2i_parity_vs_diffusers", test_t2i_parity_vs_diffusers), + ] + for name, fn in tests: + try: + fn() + print(f"PASS {name}") + except Exception as exc: # noqa: BLE001 + failures.append((name, exc)) + print(f"FAIL {name}: {exc!r}") + if failures: + raise SystemExit(1) + print("\nAll Cosmos3 t2i checks passed.") + + +if __name__ == "__main__": + _main() diff --git a/mstar/model/cosmos3/tests/test_video.py b/mstar/model/cosmos3/tests/test_video.py new file mode 100644 index 00000000..297d699e --- /dev/null +++ b/mstar/model/cosmos3/tests/test_video.py @@ -0,0 +1,259 @@ +"""Tests for the Cosmos3 t2v / i2v path (video packing + conditioning). + +CPU-safe unit tests (tiny config / stub tokenizer) cover the video prompt +templates, fps-modulated temporal mRoPE, the conditioned (image-to-video) vs +all-noisy (text-to-video) frame layout, and a multi-frame forward smoke test. An +optional GPU integration test (gated on ``COSMOS3_NANO_DIR`` + CUDA + diffusers) +checks the fused t2v / i2v output against the diffusers ``Cosmos3OmniPipeline``. + +Run CPU only: python3 test_video.py +Run with GPU: COSMOS3_NANO_DIR= python3 test_video.py +""" + +from __future__ import annotations + +import math +import os + +import torch + +from mstar.model.cosmos3.components.transformer import Cosmos3OmniTransformer +from mstar.model.cosmos3.config import Cosmos3Config +from mstar.model.cosmos3.packing import ( + build_static_inputs, + get_3d_mrope_ids_vae_tokens, + tokenize_prompt, +) + + +def _tiny_config() -> Cosmos3Config: + return Cosmos3Config( + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + intermediate_size=128, + vocab_size=100, + rope_axes_dim=(4, 2, 2), + latent_channel=8, + latent_patch_size=2, + patch_latent_dim=32, + sound_gen=False, + action_gen=False, + ) + + +class _StubTokenizer: + """Records the chat-template messages so the metadata templates can be asserted.""" + + eos_token_id = 99 + + def __init__(self): + self.seen: list[list[dict]] = [] + + def convert_tokens_to_ids(self, _tok): + return 98 # stand-in for <|vision_start|> + + def apply_chat_template(self, conversations, **_kw): + self.seen.append(conversations) + return {"input_ids": [1, 2, 3]} + + +def test_video_prompt_templates() -> None: + tok = _StubTokenizer() + cond, uncond = tokenize_prompt(tok, "a cat", "bad", num_frames=48, height=720, width=1280, fps=24.0) + # Special tokens appended (eos, start-of-generation). + assert cond[-2:] == [99, 98] and uncond[-2:] == [99, 98] + # System prompt is the video one; positive prompt carries duration + video resolution. + sys_msg = tok.seen[0][0] + assert sys_msg["role"] == "system" and "videos" in sys_msg["content"] + pos_user = tok.seen[0][1]["content"] + assert "2.0 seconds long" in pos_user and "24 FPS" in pos_user + assert "This video is of 720x1280 resolution." in pos_user + # Negative prompt uses the inverse templates. + neg_user = tok.seen[1][1]["content"] + assert "is not 2.0 seconds long" in neg_user and "This video is not of" in neg_user + + +def test_image_prompt_has_no_duration() -> None: + tok = _StubTokenizer() + tokenize_prompt(tok, "a cat", "", num_frames=1, height=256, width=256) + sys_msg, user_msg = tok.seen[0][0], tok.seen[0][1]["content"] + assert "images" in sys_msg["content"] + assert "seconds long" not in user_msg + assert "This image is of 256x256 resolution." in user_msg + + +def test_video_mrope_fps_modulation() -> None: + # grid_t > 1 with fps enables float, fps-scaled temporal positions; halving the + # fps relative to base doubles the temporal spacing. + ids12, _ = get_3d_mrope_ids_vae_tokens( + grid_t=3, grid_h=1, grid_w=1, temporal_offset=100, fps=12.0, base_fps=24.0, temporal_compression_factor=4 + ) + ids24, _ = get_3d_mrope_ids_vae_tokens( + grid_t=3, grid_h=1, grid_w=1, temporal_offset=100, fps=24.0, base_fps=24.0, temporal_compression_factor=4 + ) + assert ids12.dtype == torch.float32 + assert ids12[0].tolist() == [100.0, 102.0, 104.0] + assert ids24[0].tolist() == [100.0, 101.0, 102.0] + # A single frame disables fps modulation (image mode) -> integer positions. + ids1, _ = get_3d_mrope_ids_vae_tokens(grid_t=1, grid_h=2, grid_w=2, temporal_offset=5, fps=24.0) + assert ids1.dtype == torch.long and ids1[0].tolist() == [5, 5, 5, 5] + + +def test_video_packing_t2v_vs_i2v() -> None: + cfg = Cosmos3Config() # Nano defaults + input_ids = list(range(7)) + latent_shape = (1, cfg.latent_channel, 3, 16, 16) # T_lat=3, patch grid 8x8 + per_frame = 8 * 8 + + t2v = build_static_inputs(input_ids, latent_shape, cfg, 4, 24.0, "cpu", has_image_condition=False) + assert t2v["num_vision_tokens"] == 3 * per_frame + assert t2v["num_noisy_vision_tokens"] == 3 * per_frame # all frames noisy + assert t2v["vision_noisy_frame_indexes"][0].tolist() == [0, 1, 2] + assert t2v["position_ids"].dtype == torch.float32 # fps modulation -> float positions + # Vision temporal positions sit past the text + margin. + assert int(t2v["position_ids"][0, 7].item()) == 7 + cfg.unified_3d_mrope_temporal_modality_margin + + i2v = build_static_inputs(input_ids, latent_shape, cfg, 4, 24.0, "cpu", has_image_condition=True) + assert i2v["num_vision_tokens"] == 3 * per_frame # frame 0 stays in the sequence + assert i2v["num_noisy_vision_tokens"] == 2 * per_frame # frame 0 anchored, frames 1-2 noisy + assert i2v["vision_noisy_frame_indexes"][0].tolist() == [1, 2] + # mse indexes skip frame 0 (first noisy token is und_len + one frame stride). + assert int(i2v["vision_mse_loss_indexes"][0]) == 7 + per_frame + + +def test_video_forward_smoke_cpu() -> None: + cfg = _tiny_config() + torch.manual_seed(0) + model = Cosmos3OmniTransformer(cfg).eval() + latent_shape = (1, cfg.latent_channel, 3, 4, 4) # T_lat=3, patch grid 2x2 -> 12 vision tokens + fields = [ + "input_ids", "text_indexes", "position_ids", "und_len", "sequence_length", + "vision_token_shapes", "vision_sequence_indexes", "vision_mse_loss_indexes", + "vision_noisy_frame_indexes", + ] + for has_cond in (False, True): + static = build_static_inputs([1, 2, 3], latent_shape, cfg, 4, 24.0, "cpu", has_image_condition=has_cond) + with torch.no_grad(): + preds, sound = model( + vision_tokens=[torch.randn(latent_shape)], + vision_timesteps=torch.full((static["num_noisy_vision_tokens"],), 500.0), + **{k: static[k] for k in fields}, + ) + assert sound is None + assert preds[0].shape == latent_shape, preds[0].shape + assert torch.isfinite(preds[0]).all() + if has_cond: + # The conditioning frame is anchored: the model predicts no velocity for it. + assert torch.count_nonzero(preds[0][:, :, 0]) == 0 + + +# --------------------------------------------------------------------------- +# GPU parity (gated on COSMOS3_NANO_DIR + CUDA + diffusers). +# --------------------------------------------------------------------------- + +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") +_GPU_CACHE: dict = {} +_V_FRAMES, _V_RES, _V_STEPS, _V_GS = 17, 256, 15, 6.0 + + +def _gpu_setup(): + if "ctx" in _GPU_CACHE: + return _GPU_CACHE["ctx"] + snap = os.environ.get("COSMOS3_NANO_DIR") + if not snap or not torch.cuda.is_available(): + _GPU_CACHE["ctx"] = None + return None + try: + from diffusers import AutoencoderKLWan, UniPCMultistepScheduler + from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer as DTr + from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline + from transformers import AutoTokenizer + except Exception as exc: # noqa: BLE001 + print(f" (skipped video parity: diffusers/transformers unavailable: {exc})") + _GPU_CACHE["ctx"] = None + return None + torch.use_deterministic_algorithms(True, warn_only=True) + from mstar.model.cosmos3.cosmos3_model import Cosmos3Model + from mstar.model.cosmos3.pipeline import Cosmos3Pipeline + + dev, dtype = "cuda:0", torch.bfloat16 + dpipe = Cosmos3OmniPipeline( + transformer=DTr.from_pretrained(snap, subfolder="transformer", torch_dtype=dtype), + text_tokenizer=AutoTokenizer.from_pretrained(os.path.join(snap, "text_tokenizer")), + vae=AutoencoderKLWan.from_pretrained(snap, subfolder="vae", torch_dtype=dtype), + scheduler=UniPCMultistepScheduler.from_pretrained(snap, subfolder="scheduler"), + sound_tokenizer=None, enable_safety_checker=False, + ).to(dev) + mpipe = Cosmos3Pipeline.from_model(Cosmos3Model(model_path_hf=snap), device=dev, dtype=dtype) + _GPU_CACHE["ctx"] = dict(dpipe=dpipe, mpipe=mpipe, snap=snap, device=dev, dtype=dtype) + return _GPU_CACHE["ctx"] + + +def _video_parity(mode: str) -> None: + ctx = _gpu_setup() + if ctx is None: + print(f" (skipped {mode} parity: needs COSMOS3_NANO_DIR + CUDA)") + return + import json + + from PIL import Image + + dpipe, mpipe, snap = ctx["dpipe"], ctx["mpipe"], ctx["snap"] + dev, dtype = ctx["device"], ctx["dtype"] + is_i2v = mode == "i2v" + asset = os.path.join(snap, "assets", "example_i2v_prompt.json" if is_i2v else "example_t2v_prompt.json") + with open(asset) as f: + prompt = json.load(f)["temporal_caption"] + image = ( + Image.open(os.path.join(snap, "assets", "example_i2v_input.jpg")).convert("RGB") if is_i2v else None + ) + gen = torch.Generator(device=dev).manual_seed(0) + init, _ = mpipe._prepare_latents(image, _V_FRAMES, _V_RES, _V_RES, gen, None, dev, dtype) + common = dict(prompt=prompt, negative_prompt="", num_frames=_V_FRAMES, height=_V_RES, width=_V_RES, + num_inference_steps=_V_STEPS, guidance_scale=_V_GS, fps=24.0) + lat_d = dpipe(image=image, latents=init.clone(), output_type="latent", enable_safety_check=False, **common)[0] + lat_m = mpipe(image=image, latents=init.clone(), decode=False, **common) + img_d = mpipe._decode(lat_d.reshape(lat_m.shape).to(dtype)).squeeze().float().cpu() + img_m = mpipe._decode(lat_m).squeeze().float().cpu() + mse = (img_d - img_m).pow(2).mean().item() + psnr = float("inf") if mse == 0 else -10 * math.log10(mse) + assert psnr >= 30, f"{mode} video PSNR {psnr:.2f} < 30 (MSE {mse:.3e})" + print(f" {mode} parity PSNR={psnr:.2f} dB") + + +def test_t2v_parity_vs_diffusers() -> None: + _video_parity("t2v") + + +def test_i2v_parity_vs_diffusers() -> None: + _video_parity("i2v") + + +def _main() -> None: + failures = [] + tests = [ + ("video_prompt_templates", test_video_prompt_templates), + ("image_prompt_has_no_duration", test_image_prompt_has_no_duration), + ("video_mrope_fps_modulation", test_video_mrope_fps_modulation), + ("video_packing_t2v_vs_i2v", test_video_packing_t2v_vs_i2v), + ("video_forward_smoke_cpu", test_video_forward_smoke_cpu), + ("t2v_parity_vs_diffusers", test_t2v_parity_vs_diffusers), + ("i2v_parity_vs_diffusers", test_i2v_parity_vs_diffusers), + ] + for name, fn in tests: + try: + fn() + print(f"PASS {name}") + except Exception as exc: # noqa: BLE001 + failures.append((name, exc)) + print(f"FAIL {name}: {exc!r}") + if failures: + raise SystemExit(1) + print("\nAll Cosmos3 video checks passed.") + + +if __name__ == "__main__": + _main() diff --git a/mstar/model/registry.py b/mstar/model/registry.py index fab97010..b95813f8 100644 --- a/mstar/model/registry.py +++ b/mstar/model/registry.py @@ -1,5 +1,6 @@ from mstar.model.bagel.bagel_model import BagelModel from mstar.model.base import Model +from mstar.model.cosmos3.cosmos3_model import Cosmos3Model from mstar.model.orpheus.orpheus_model import OrpheusModel from mstar.model.pi05.pi05_model import Pi05Model from mstar.model.qwen3_omni.qwen3_omni_model import Qwen3OmniModel @@ -7,6 +8,8 @@ MODEL_REGISTRY: dict[str, type[Model]] = { "bagel": BagelModel, + "cosmos3": Cosmos3Model, + "cosmos3_super": Cosmos3Model, "orpheus": OrpheusModel, "pi05": Pi05Model, "qwen3_omni": Qwen3OmniModel, @@ -16,6 +19,12 @@ HF_MODELS: dict[str, dict] = { "bagel": {"model_path_hf": "ByteDance-Seed/BAGEL-7B-MoT"}, + # NVIDIA Cosmos3-Nano generator (diffusers transformer/ + Wan VAE + UniPC). + "cosmos3": {"model_path_hf": "nvidia/Cosmos3-Nano"}, + # Cosmos3-Super (64B) — same architecture + class; dims (64 layers / 5120 + # hidden / 25600 intermediate) load from the checkpoint's config.json, so it + # needs tensor parallelism (it does not fit on one GPU). + "cosmos3_super": {"model_path_hf": "nvidia/Cosmos3-Super"}, "orpheus": {"model_path_hf": "canopylabs/orpheus-3b-0.1-ft"}, # Pi0.5 PyTorch port published by lerobot — single safetensors blob # (~14 GB). mstar/model/pi05/weight_loader.py handles the lerobot->mstar diff --git a/mstar/model/submodule_base.py b/mstar/model/submodule_base.py index c781a578..dcf9ff01 100644 --- a/mstar/model/submodule_base.py +++ b/mstar/model/submodule_base.py @@ -159,6 +159,12 @@ class NodeSubmodule(torch.nn.Module): """Base class for a model's compute units: defines the prepare_inputs → preprocess → forward(_batched) contract the engines drive.""" + # Set True on a submodule whose forward does not benefit from (or is broken + # by) torch.compile — e.g. a data-dependent denoise loop, or a one-shot + # forward where the trace cost dwarfs the win. The KV-cache / stateless + # engines skip compiling such submodules (CUDA-graph capture is unaffected). + disable_torch_compile: bool = False + def get_device(self): return next(self.parameters()).device