diff --git a/benchmark/base.py b/benchmark/base.py index fc176fbd..d1c99df4 100644 --- a/benchmark/base.py +++ b/benchmark/base.py @@ -232,6 +232,78 @@ def get_supported_modalities(self): } +class MingFlashOmni(Model): + """Ming-flash-omni-2.0 (inclusionAI), the Ling-2.0 sparse-MoE omni model + (100B total / 6B active params) released 2026-02-11. + + Reachable today via the vllm-omni server using + ``vllm_omni/deploy/ming_flash_omni.yaml`` (thinker+talker) or + ``ming_flash_omni_thinker_only.yaml`` (text-only). The native ``ours`` / + ``ours_openai`` backends will work once the mstar-side port under + ``mstar/model/ming_omni_flash/`` is finished — until then, point the + benchmark at a vllm-omni instance with ``--inference-system vllm_omni``. + + Wire shape mirrors :class:`Qwen3Omni`: standard OpenAI + ``/v1/chat/completions`` with multimodal content parts. The role remap + from OpenAI's ``user``/``assistant``/``system`` to Ming's internal + ``HUMAN``/``ASSISTANT``/``SYSTEM`` happens inside the jinja chat_template + shipped in ``tokenizer_config.json`` — vllm-omni renders prompts via + ``tokenizer.apply_chat_template`` which uses that jinja, so the benchmark + sends the standard OpenAI shape unchanged. + + Caveat: Ming ALSO ships a Python-side ``BailingMM2Processor.apply_chat_template`` + (in the Ming source repo) that is strict about uppercase roles and would + AssertionError on ``user``/``assistant``. mstar's native port uses that + processor for full multimodal preprocessing (vision/audio feature + extraction) and remaps roles in ``process_prompt`` accordingly — see + ``mstar/model/ming_omni_flash/`` and its tokenizer tests. + """ + + def get_hf_url(self): + return "inclusionAI/Ming-flash-omni-2.0" + + def get_openai_system_message(self) -> Optional[dict]: + # Ming-flash-omni-2.0's cookbook uses ``sys_prompt_exp=None`` and + # ``use_cot_system_prompt=False`` by default — there's no required + # "You are Ming…"-style preamble equivalent to Qwen3-Omni's. 
The HF + # processor's chat_template fills in any internal system text on its + # own, and vllm-omni's serving layer goes through that template via + # ``trust_remote_code``. Sending an explicit system message here only + # risks overriding the model's own defaults, so default to None. + return None + + def get_model_kwargs(self, request_type: RequestType): + # Cap thinker output at 256 tokens for cross-system fairness — same + # rationale as Qwen3Omni: comparable runs need a fixed decode budget. + # vllm-omni's released stage default is ``max_tokens: 2048`` (see + # ``vllm_omni/deploy/ming_flash_omni.yaml`` stage 0); we lower it for + # benchmark parity. Send both ``max_tokens`` (OpenAI convention) and + # ``max_output_tokens`` (mstar's native kwarg) so the cap survives + # whichever ``--inference-system`` is in use. + # + # Force greedy on the thinker (``temperature=0.0`` at payload top-level + # in VLLMOmni.send_request) for deterministic text. The talker's + # sampling defaults live server-side in the deploy yaml + # (``stage_id: 1`` → ``temperature: 0.0`` per the released config) — + # we don't override them here. + return { + "max_tokens": 256, + "max_output_tokens": 256, + } + + def get_supported_modalities(self): + return { + RequestType.T2T, + RequestType.T2S, + RequestType.I2T, + RequestType.I2S, + RequestType.A2T, + RequestType.A2S, + RequestType.V2T, + RequestType.V2S, + } + + class Pi05(Model): """Physical Intelligence Pi0.5 VLA model. 
@@ -286,6 +358,7 @@ class ModelType(Enum): BAGEL = "bagel" ORPHEUS = "orpheus" QWEN3OMNI = "qwen3omni" + MING_FLASH_OMNI = "ming_flash_omni" PI05 = "pi05" VJEPA2AC = "vjepa2ac" @@ -296,6 +369,8 @@ def inst(self, **kwargs) -> Model: return Orpheus(**kwargs) if self == ModelType.QWEN3OMNI: return Qwen3Omni(**kwargs) + if self == ModelType.MING_FLASH_OMNI: + return MingFlashOmni(**kwargs) if self == ModelType.PI05: return Pi05(**kwargs) if self == ModelType.VJEPA2AC: diff --git a/benchmark/vllm_omni_instructions.md b/benchmark/vllm_omni_instructions.md index 2934c6c9..3e534544 100644 --- a/benchmark/vllm_omni_instructions.md +++ b/benchmark/vllm_omni_instructions.md @@ -21,4 +21,93 @@ CUDA_VISIBLE_DEVICES=3 vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8000 ### for qwen3-omni: ``` vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml -``` \ No newline at end of file +``` + +### for ming-flash-omni-2.0: + +The released `inclusionAI/Ming-flash-omni-2.0` ckpt (~238 GB / 42 shards) +does NOT load cleanly into vllm-omni's `MingFlashOmniForConditionalGeneration` +class as-is. Two patches are needed (one-time setup): + +1. **Replace metadata files.** vllm-omni's model class uses + `Qwen2VLImageProcessor` + `MingWhisperFeatureExtractor` (its own + registered classes), while the inclusionAI snapshot declares the + `BailingMM2*` processor variants via `auto_map` and `trust_remote_code`. + Use `Jonathan1909/Ming-flash-omni-2.0`'s `preprocessor_config.json`, + `config.json` (auto_map stripped), and `tokenizer*.json` instead. + +2. **Replace the talker weights.** vllm-omni's `MingFlashOmniTalker` expects + weights under `audio_vae.*` but the inclusionAI talker safetensors uses + `audio.*` prefix. Jonathan1909 reshipped the talker with renamed weights + (~1.5 GB). + +Building a hybrid snapshot avoids re-downloading the 200+ GB thinker weights: + +```bash +# 1. 
Make sure the inclusionAI thinker shards are cached +huggingface-cli download inclusionAI/Ming-flash-omni-2.0 \ + --include="model-*.safetensors" --include="model.safetensors.index.json" + +# 2. Pull only Jonathan1909's metadata + talker (no thinker weights) +huggingface-cli download Jonathan1909/Ming-flash-omni-2.0 \ + --include="*.json" --include="*.py" --include="*.txt" --include="*.mvn" \ + --include="talker/**" \ + --cache-dir /dev/shm/hf-cache # or any path with ~3 GB free + +# 3. Stitch the two together +INCL=$(huggingface-cli scan-cache | grep inclusionAI/Ming-flash-omni-2.0 \ + | awk '{print $NF}')/snapshots/$(ls ~/.cache/huggingface/hub/models--inclusionAI--Ming-flash-omni-2.0/snapshots | head -1) +JONA=/dev/shm/hf-cache/models--Jonathan1909--Ming-flash-omni-2.0/snapshots/* +HYBRID=/dev/shm/ming-hybrid +mkdir -p $HYBRID +for f in $INCL/model-*.safetensors; do ln -s "$f" "$HYBRID/$(basename $f)"; done +for f in $JONA/*; do + base=$(basename "$f") + [ -L "$HYBRID/$base" ] && rm "$HYBRID/$base" + ln -s "$f" "$HYBRID/$base" +done +``` + +Then serve and benchmark: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve /dev/shm/ming-hybrid \ + --omni --port 8091 --host 0.0.0.0 --trust-remote-code \ + --stage-configs-path /tmp/vllm-omni/vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml + +# Wait for "Application startup complete" then: +MODEL=ming_flash_omni INF_SYS=vllm_omni TASK=text_to_text \ + URL=http://0.0.0.0:8091 ./benchmark/run_benchmark.sh +``` + +NOTE: vllm-omni's `/v1/chat/completions` rejects unknown model ids, so the +client must send `"model": "/dev/shm/ming-hybrid"` (the served path), not +`"inclusionAI/Ming-flash-omni-2.0"`. 
Easiest is to monkey-patch +`MingFlashOmni.get_hf_url` before calling the benchmark runner: + +```python +from benchmark.base import MingFlashOmni +MingFlashOmni.get_hf_url = lambda self: "/dev/shm/ming-hybrid" +``` + +Or pass `--served-model-name inclusionAI/Ming-flash-omni-2.0` to `vllm serve` +(untested; would also work in principle). + +#### Modalities exercised on a local 4×H100 run (2026-06-06) + +| Task | Status | Notes | +|---|---|---| +| T2T (text → text) | ✅ | offline B=1: 110 tok/s, closed-loop C=32: **1060 tok/s** (full scaling sweep in [`results/ming_t2t_sweep/SUMMARY.md`](../results/ming_t2t_sweep/SUMMARY.md)) | +| I2T (image → text) | ✅ | TTFT 87 ms, ~100 tok/s on Food101 | +| A2T (audio → text) | ✅ | English transcription + Chinese audio QA both work | +| T2S (text → speech) | ✅ | RTF 0.14, 24 kHz mono PCM via harness; 44.1 kHz via direct OpenAI path | +| V2T (video → text) | ✅ | Local Ming demo mp4s; coherent descriptions (`yoga.mp4` → yoga pose narration, `cup_change.mp4` → "shell game") | +| V2S (video → speech) | ✅ | Local Ming demo mp4s; 2-3 MB WAV/clip @ 44.1 kHz | +| I2S (image → speech) | ✅ | Food101 in, ~7 s/req for ~48 s of audio | +| A2S (audio → speech) | ✅ | Ming sample wavs; 0.5-3 MB WAV/clip @ 44.1 kHz | +| T2I / I2I (image gen) | not wired | requires `ming_flash_omni_image.yaml` + a benchmark wrapper similar to BAGEL's `/v1/images/generations` path | + +The V2T/V2S/A2S runs sidestep the bench harness's `UCF101Dataset` and +`LibriSpeechDataset` (both want fresh HF-Hub downloads) by hitting +`/v1/chat/completions` directly with base64-inlined media from local files +(Ming repo's `figures/cases/*.mp4` and `data/wavs/*.wav`). \ No newline at end of file diff --git a/configs/ming_flash_omni.yaml b/configs/ming_flash_omni.yaml new file mode 100644 index 00000000..df88e507 --- /dev/null +++ b/configs/ming_flash_omni.yaml @@ -0,0 +1,36 @@ +# Ming-flash-omni-2.0 — full omni deploy (text/image/audio/video in, text + speech out). 
+# +# Node→rank mapping for the native mstar port +# (mstar/model/ming_omni_flash/). The model registers these nodes +# (see MingFlashOmniModel.get_node_engine_types): +# +# * Thinker (KV_CACHE, TP) — Ling-2.0 sparse MoE LLM, the +# multimodal understanding core. +# * vision_encoder (STATELESS) — Qwen3-MoE ViT + projector. +# * audio_encoder (STATELESS) — Whisper encoder + projector. +# * Talker (STATELESS) — CFM talker; the AudioVAE is wrapped +# INSIDE the Talker submodule (it is +# NOT a separate graph node). +# +# Thinker runs TP=8 across all 8 H100s here to leave room for the +# colocated talker + encoders. TP=4 also fits the thinker (~57-62 GB/rank, +# verified 2026-06-12); an earlier fp32-allocation bug made it OOM at +# ~78.5 GB and is now fixed (see get_submodule). The stateless encoders + +# the talker are small (~1.5 GB each) and colocate on rank 0. +# +# The Thinker→Talker bridge passes DETOKENIZED TEXT (re-tokenized with +# the talker's own talker/llm tokenizer), so the talker is a near- +# standalone TTS partition fed by a streaming connection — see +# MingFlashOmniModel.get_partition_topology. + +model: "ming_flash_omni" +max_seq_len: 32768 +node_groups: + # Stateless encoders + the talker colocate on rank 0. + - node_names: [vision_encoder, audio_encoder, Talker] + ranks: [0] + + # Thinker sharded across all 8 GPUs. + - node_names: [Thinker] + ranks: [0, 1, 2, 3, 4, 5, 6, 7] + tp_size: 8 diff --git a/configs/ming_flash_omni_thinker_only.yaml b/configs/ming_flash_omni_thinker_only.yaml new file mode 100644 index 00000000..539201f5 --- /dev/null +++ b/configs/ming_flash_omni_thinker_only.yaml @@ -0,0 +1,21 @@ +# Ming-flash-omni-2.0 — thinker-only deploy (text out, no talker). +# +# TP=8 across 8 H100s. Per-rank shard_inter = 1024/8 = 128; +# experts.gate_up_proj is (256, 2*128, 4096) per rank, ~33 GB across +# 31 MoE layers. With embed + lm_head + attention + dense layer 0 + +# KV cache, ~40 GB per rank fits the 80 GB H100s comfortably. 
+# +# TP=4 also fits (~57-62 GB/rank, verified 2026-06-12) — see +# configs/ming_flash_omni_thinker_only_tp4.yaml. An earlier fp32-allocation +# bug made TP=4 OOM at ~78.5 GB; fixed in get_submodule (cast meta model to +# bf16 before to_empty). TP=8 still leaves the most headroom. +# +# Audio / vision / talker / image-gen are step 4+; this config is for +# text-only T2T benchmarking and the first mstar-served Ming forward. + +model: "ming_flash_omni" +max_seq_len: 32768 +node_groups: + - node_names: [Thinker] + ranks: [0, 1, 2, 3, 4, 5, 6, 7] + tp_size: 8 diff --git a/configs/ming_flash_omni_thinker_only_tp4.yaml b/configs/ming_flash_omni_thinker_only_tp4.yaml new file mode 100644 index 00000000..fc90e6bd --- /dev/null +++ b/configs/ming_flash_omni_thinker_only_tp4.yaml @@ -0,0 +1,20 @@ +# Ming-flash-omni-2.0 — thinker-only deploy, TP=4 (4-GPU layout). +# +# Pinned to 4 ranks. Launch with CUDA_VISIBLE_DEVICES=4,5,6,7 so physical +# GPUs 4-7 map to logical ranks 0-3. +# +# TP=4 fits in ~57-62 GB per rank (verified 2026-06-12). Earlier notes +# claimed TP=4 OOM'd at ~78.5/80 GB; that was a load-time bug — params were +# allocated in fp32 before the bf16 cast, doubling the allocation peak. Fixed +# in MingFlashOmniModel.get_submodule (cast the meta model to bf16 BEFORE +# to_empty, so allocation happens directly in bf16). +# +# TP=4 IS dimensionally valid: 32 heads/4=8, 4 KV heads/4=1, 256 experts/4=64, +# hidden 4096/4=1024, moe_inter 1024/4=256 — all divide. + +model: "ming_flash_omni" +max_seq_len: 32768 +node_groups: + - node_names: [Thinker] + ranks: [0, 1, 2, 3] + tp_size: 4 diff --git a/mstar/model/base.py b/mstar/model/base.py index 54a7e90d..fa58d1d3 100644 --- a/mstar/model/base.py +++ b/mstar/model/base.py @@ -253,14 +253,24 @@ def get_worker_graphs(self, config_path: str) -> list[WorkerGraph]: if node_groups is None: raise KeyError("Config must define `node_groups`.") + # Nodes this deploy actually provides. 
A graph walk referencing a + # node absent from node_groups (e.g. the encoder / talker walks in + # a thinker-only deploy) is skipped rather than KeyError'ing during + # worker-graph division — that deploy simply can't serve the walk. + available_nodes: set[str] = set() + for group in node_groups: + available_nodes.update(group["node_names"]) + # TODO: merge identical worker graphs from different graph walks - return sum( - [ + worker_graphs: list[WorkerGraph] = [] + for graph_walk, graph in self.get_graph_walk_graphs().items(): + required = set(graph.get_nodes().keys()) + if not required <= available_nodes: + continue + worker_graphs.extend( self._get_worker_graphs_for_graph_walk(graph_walk, graph, node_groups) - for graph_walk, graph in self.get_graph_walk_graphs().items() - ], - start=[], - ) + ) + return worker_graphs def get_sharding_config(self, config_path: str) -> ShardingConfig: with open(config_path, "r") as f: diff --git a/mstar/model/ming_omni_flash/PORTING_NOTES.md b/mstar/model/ming_omni_flash/PORTING_NOTES.md new file mode 100644 index 00000000..c6631733 --- /dev/null +++ b/mstar/model/ming_omni_flash/PORTING_NOTES.md @@ -0,0 +1,1056 @@ +# Ming-flash-omni-2.0 — porting notes + +Native mstar port of `inclusionAI/Ming-flash-omni-2.0`. This directory is a +scaffold today; everything below is the punch list to make it real. + +## Status + +- `benchmark/base.py` has `MingFlashOmni` + `ModelType.MING_FLASH_OMNI`. + Benchmarking against a vllm-omni server **works today** with + `--inference-system vllm_omni` (see `benchmark/vllm_omni_instructions.md`). +- Step 1 (config port) — DONE. `mstar/model/ming_omni_flash/config.py` + loads the released ckpt; 10 tests in `test/modular/test_ming_flash_omni_config.py`. +- Step 2 (tokenizer + processor wiring) — DONE. 
+ `MingFlashOmniModel.__init__` resolves the snapshot, stages Ming source + files (see "Ming source dependency" below), and loads + `BailingTokenizer` + `BailingMM2Processor` with graceful fallback; + 11 tests in `test/modular/test_ming_flash_omni_tokenizer.py`. +- Steps 3a–3f, 4a–4b and 5a–5c are marked DONE in the punch list below; + `MingFlashOmniModel.__init__` no longer raises `NotImplementedError` + (step 3d) — see the punch list for the current per-step status. + +## Ming source dependency (loading the tokenizer/processor) + +The released HF checkpoint `inclusionAI/Ming-flash-omni-2.0` ships +**only weights and sub-dir configs**. The tokenizer/processor Python +modules (`configuration_bailingmm2.py`, `tokenization_bailing.py`, +`processing_bailingmm2.py`, etc.) live in the source repo at +https://github.com/inclusionAI/Ming . To load the tokenizer/processor: + +```bash +# 1. Clone the source repo +git clone https://github.com/inclusionAI/Ming.git /path/to/Ming + +# 2. Install extra Python deps Ming's modules depend on +pip install opencv-python-headless openai-whisper + +# 3. Tell mstar where to find the source repo +export MING_CODE_DIR=/path/to/Ming +# (or pass ming_code_dir="/path/to/Ming" to MingFlashOmniModel) +``` + +`MingFlashOmniModel.__init__` (via `_prepare_tokenizer_dir`) symlinks +the required .py and .json files from `$MING_CODE_DIR` alongside the +snapshot's `config.json` so transformers' `trust_remote_code` machinery +can resolve them. The snapshot dir is also pushed onto `sys.path` so +the dynamic-module loader's sibling imports resolve. + +## Role-handling nuance (chat templates) + +Ming-flash-omni-2.0 ships **two** chat-template implementations with +**different role conventions**: + +- `tokenizer.apply_chat_template(messages)` — uses the **jinja template + in `tokenizer_config.json`**. Accepts standard OpenAI roles + (`user` / `assistant` / `system`) and remaps them to Ming's uppercase + `HUMAN` / `ASSISTANT` / `SYSTEM` inside the template. 
This is the path + vllm-omni's serving layer uses → the benchmark side works unchanged. + +- `processor.apply_chat_template(messages, sys_prompt_exp=..., use_cot_system_prompt=...)` + — uses the **Python implementation in `BailingMM2Processor`** (Ming + source repo). **Strict**: asserts `role in [HUMAN, ASSISTANT]` and + raises `AssertionError` on lowercase OpenAI roles. The native mstar + `process_prompt` (step 7) will need this path for the multimodal + preprocessing (vision feature extraction, audio padding, etc.) and + must explicitly remap roles before calling. + +## Upstream reference + +Treat the vllm-omni port as the source of truth for architecture. Files to +read (totals ~6.5 KLOC): + +| Concern | vllm-omni file | +|---|---| +| Pipeline glue | `vllm_omni/model_executor/models/ming_flash_omni/pipeline.py` (141 LOC) | +| Top-level model | `ming_flash_omni.py` (255 LOC) | +| Thinker (Ling-2.0 MoE + multimodal) | `ming_flash_omni_thinker.py` (1,164 LOC) | +| Talker (CFM + LLM) | `ming_flash_omni_talker.py` (586) + `talker_module.py` (1,145) | +| Audio VAE | `audio_vae.py` (392) | +| Audio encoder | `audio_encoder.py` (246) | +| Vision encoder | `vision_encoder.py` (125) + `projectors.py` (184) | +| Ling MoE backbone | `modeling_bailing_moe_v2.py` (892) | +| Prompt utils | `prompt_utils.py` (134) — `IMAGE_PATCH_TOKEN`, `DEFAULT_NUM_QUERY_TOKENS=256`, TTS caption template | +| Text processing | `text_processing.py` (535) | +| Speaker presets | `spk_embedding.py` (44) + `voice_presets.py` (289) | +| Config | `vllm_omni/transformers_utils/configs/ming_flash_omni.py` (420) | +| Stage input processor | `vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py` | +| ImageGen pipeline | `vllm_omni/diffusion/models/ming_flash_omni/` | +| Deploy yamls | `vllm_omni/deploy/ming_flash_omni{,_image,_thinker_only,_tts}.yaml` | + +## mstar parallels + +Mirror the structure of `mstar/model/qwen3_omni/` end-to-end. 
That model is +the closest analog (multimodal thinker + speech talker + vocoder), and the +graph-walk / partition / streaming patterns transfer 1:1. + +| mstar surface | Qwen3-Omni reference | Ming-flash-omni equivalent | +|---|---|---| +| Model class | `qwen3_omni_model.py` (1,529) | `ming_omni_flash_model.py` | +| Submodules | `submodules.py` (2,016) | `submodules.py` (TODO) | +| Config | `config.py` (544) | `config.py` | +| Talker | `components/talker.py` (549) + `code2wav.py` (534) | `components/talker.py` + `audio_vae.py` (TODO) | +| Thinker | `components/thinker.py` (259) | `components/thinker.py` (TODO) | +| Attention / RoPE | `components/attention.py` + `rope.py` | likely shareable; check Ling-2.0 attention shape | + +## Punch list (in order) + +1. **Config port — DONE.** `mstar/model/ming_omni_flash/config.py` + loads `config.json` + sibling subdir configs (talker / image-gen) into + a dataclass tree. Verified via 10 tests in + `test/modular/test_ming_flash_omni_config.py`. + +2. **Tokenizer + processor — DONE.** `MingFlashOmniModel.__init__` + resolves the snapshot, stages Ming source files alongside it (see + "Ming source dependency" above), and loads `BailingTokenizer` + + `BailingMM2Processor` with graceful fallback. The chat-template role + handling has two paths (see "Role-handling nuance" above); the native + `process_prompt` (step 7) will use the strict processor path and must + remap roles. Verified via 11 tests in + `test/modular/test_ming_flash_omni_tokenizer.py`. + +3. **Ling-2.0 thinker LLM port — DONE (sub-steps 3a–3f below).** + - **3a — DONE** (`components/router.py`, `rope.py`, `attention.py`): + architecture-novel pieces (MultiRouter group-limited top-k, partial + 3D `video_rope`, QK-norm attention). 12 tests in + `test/modular/test_ming_flash_omni_components.py`. 
+ - **3b — DONE** (`components/moe.py`, `decoder_layer.py`, `model.py`): + `LingMoeBlock` (3-router text/image/audio with `torch.where` + per-token swap), `LingDecoderLayer` (hybrid dense/MoE per + `first_k_dense_replace`), full `LingMoeModel` (embed + N layers + + RMSNorm + lm_head). 9 tests in `test_ming_flash_omni_model.py`. + - **3c — DONE** (`loader.py`): weight loader that maps the released + ckpt's `model.model.*` namespace to `LingMoeModel`'s state_dict, + with per-expert gate/up/down fusion into the packed + `experts.gate_up_proj` tensor via mstar's existing + `WeightConverter` machinery. Real-ckpt smoke test loads embed + + dense layer 0 + lm_head from the released shards and runs a + forward — output is finite bf16 logits at the expected + `(T, vocab_size)` shape. 6 tests in + `test_ming_flash_omni_loader.py` (4 pure-Python + 2 CUDA+snapshot). + - **3e — DONE** (TP-aware variants): `LingAttention` uses + `QKVParallelLinear` + `RowParallelLinear` (per-rank heads + dense + row-parallel); `LingMoeBlock` shards fused experts by + `shard_inter = moe_intermediate_size / tp_size` and uses mstar's + existing `_gate_up_weight_loader` / `_down_proj_weight_loader` + for per-rank weight slicing; dense layer-0 MLP uses + `ParallelGatedMLP`; `LingMoeModel` threads `comm_group` through + every decoder layer. Weight loader refactored onto mstar's + `load_hf_weights` + 770 `StackedParamRule`s (3 per expert × + num_experts + dense MLP + synthetic QKV). The packed + `attention.query_key_value.weight` from the checkpoint is split + into synthetic `q_proj` / `k_proj` / `v_proj` keys by + `_split_packed_qkv` so `QKVParallelLinear`'s standard weight + loader handles per-rank head slicing. + + **Verified via TP=8 mstar-serve smoke** (8 H100s): server starts, + all 8 workers load 507 thinker params each (one per packed + parameter; per-rank ~40 GB), KVCacheEngine warmup_and_capture + completes, torch.compile applies, dedicated GPU threads spin up, + port 8092 listens. 
Per-rank model + KV cache is well under 80 GB. + TP=4 was tried first and OOMed at 78.58 GB / 80 GB; TP=8 has + plenty of headroom. + + **Known gap (resolved in 3f)**: see step 3f. + + - **3d — DONE** (cache wiring + submodule + engine integration): + `LingAttention` now uses `cache_handle.run_attention` for paged + KV-cache attention (keeps the custom partial-3D rope inline); + `BailingMoeV2ThinkerSubmodule` in `submodules.py` implements + `prepare_inputs` / `preprocess` / `forward` / `check_stop` for + the prefill + decode walks; `MingFlashOmniModel.__init__` no + longer raises NotImplementedError and all Model ABC methods + (`get_kv_cache_config`, `get_graph_walk_graphs`, `get_partitions`, + `process_prompt`, `postprocess`, `get_submodule`, etc.) are + implemented for the text-only path. 12 tests in + `test_ming_flash_omni_model.py` + the existing 30+ Ming tests + still pass. + + **Verified via `mstar-serve` smoke**: the engine instantiates the + model class, calls `get_submodule("Thinker")`, and reaches + `load_thinker_weights` — failing with OOM on a single GPU + (loaded ~75 GB before exhausting the 80 GB H100). The engine + plumbing itself works; **single-GPU OOM is the expected blocker + until step 3e brings TP-aware variants**. To actually serve the + full 100B model we need TP=4 distributing the experts + attention + across 4 H100s. + + - **3f — DONE** (graph wiring for the text-only generate loop): + two model-side bugs blocked the first end-to-end `/generate` + response on top of step 3e. + + (a) `BailingMoeV2ThinkerSubmodule` had no `postprocess` hook. + The decode loop's output edge is named `text_inputs` so the + loop feeds the previous sampled token back into the next + iteration. `submodule.forward` returns `{"logits": [...]}`; + the KV-cache engine samples into `{"new_token": [...]}`; but + the graph router needs a `text_inputs` key under that name. 
+ Added `postprocess` that rebinds `new_token → text_inputs`, + mirroring :meth:`OrpheusLLMSubmodule.postprocess`. Without + this, every decode iteration hit `IndexError` at + `prepare_inputs` (`text_inputs` list arrived empty), which + is the same symptom the 3e notes called out. + + (b) The prefill / decode output edges used `EMPTY_DESTINATION` + + `conductor_new_token=True` rather than + `EMIT_TO_CLIENT` + `output_modality="text"`. With (a) fixed + the loop produced tokens, but the API server received + `{"outputs": {}}` because no edge routed `new_token` to the + client. Switched to Qwen3-Omni's pattern: prefill emits its + first token to the client and the decode-loop section emits + each subsequent sampled token via a parallel + `EMIT_TO_CLIENT, name="new_token", output_modality="text"` + edge alongside the `text_inputs` loopback. + + **Environment / dependency patches collected along the way** + (not Ming code, but required on this box to reach a working + forward): + + * `BailingTokenizer` doesn't load under transformers >= 5.0: + (i) accessor properties reference `self.verbose`, removed + in 5.x — set a class-level `verbose = False`; (ii) + `__init__` sets `self.add_bos_token` before + `super().__init__()` and the 5.x setter calls + `update_post_processor()` which dereferences the not-yet- + built `self._tokenizer`. Both patches live in + `_patch_bailing_tokenizer_for_transformers5` in + `ming_omni_flash_model.py`, applied once after the first + `AutoTokenizer.from_pretrained` raises an `AttributeError` + matching either signature. + + * `LingMoeBlock._dispatch_tp` always called + `mstar.utils.fused_moe.fused_experts`, which hard-requires + `sgl_kernel`. On boxes where the installed `sgl_kernel.so` + has an ABI mismatch against the running torch (the + importlib-level error doesn't propagate as a normal + `ImportError` until you actually call into the .so), this + crashes mid-forward. 
Added a naive fallback that calls + `dispatch_experts_fused` on each rank's expert shard then + all-reduces; math is equivalent because sum-over-TP and + sum-over-top-k commute. + + * `flashinfer-python` 0.6.6 ships a Python wrapper that + passes 10 args to the bundled `top_p_sampling_from_probs` + op while `flashinfer-jit-cache` 0.6.2 expects 8. Pin + `flashinfer-python==0.6.2` (via `pip install --no-deps`) + to match the jit-cache; the alternative would be rebuilding + the cache against 0.6.6. + + **Verified via `mstar-serve` smoke (TP=8 on 8 H100s)**: + /generate returns real model text.
+ + Note: expert layout doesn't share with Qwen3-Omni's MoE block — + `MultiRouter` (3 gates + modality masks) is Ling-specific, and + the per-expert fused weight tensor has its own shape constraints. + +4. **Vision + audio encoders.** Stateless graph nodes. Port + `vision_encoder.py` + `projectors.py` and `audio_encoder.py`. Wire into + the prefill graph walks. + + - **4a — DONE** (`components/projectors.py`, + `components/vision_encoder.py`, `components/audio_encoder.py`): + pure-port encoder + projector modules with weight-key parity + against the released ckpt's top-level prefixes + (`vision.*`, `audio.*`, `linear_proj.*`, `linear_proj_audio.*`). + + * `MingVisionProjector` / `MingAudioProjector` mirror the + `nn.Sequential` chains built inline in + `modeling_bailingmm2.py` (Linear→GELU→Linear for vision, + Conv1d→Transpose→GELU→Linear→Transpose for audio). Layer + indices match the on-disk keys (`linear_proj.{0,2}` vision, + `linear_proj_audio.{0,3}` audio). + + * `build_vision_encoder` constructs Ming's + `Qwen3MoeVisionTransformer` via dynamic import from the staged + Ming source dir (same path used by the tokenizer + processor). + Reused as-is rather than forked — no vLLM dep, ~1 GB at bf16, + runs on a single GPU. + + * `MingAudioEncoder` is a self-contained port of vllm-omni's + packed-sequence Whisper encoder (~250 LOC) — no + `openai-whisper` runtime dep, optional flash-attn varlen fast + path with a manual fallback. Param names match upstream + Whisper (`query` / `key` / `value` / `out`, + `mlp.{0,2}.{weight,bias}`) so the released ckpt's + `audio.blocks.N.*` keys load by state-dict equality. 
+ + * 17 tests in `test/modular/test_ming_flash_omni_encoders.py`: + 12 pure-Python (projector shape / layer indices / forward / + audio encoder weight-key parity / packed-attention fallback + shape) + 1 snapshot-gated (vision encoder builds from the + real `VisionEncoderConfig`) + 1 CUDA-gated (forward smoke + under eager attention — currently skipped on this box for + missing libnvrtc-builtins, not a code bug). + + - **4b — DONE** (encoder weight loading): `loader.py` now exposes + `load_vision_encoder_weights`, `load_audio_encoder_weights`, + `load_vision_projector_weights`, `load_audio_projector_weights` + on top of a shared `_load_prefixed_state_dict` helper. None of + these are TP-aware — vision + audio encoders colocate on rank 0 + in the typical topology (see `configs/ming_flash_omni.yaml`) so + a plain prefix-strip + `load_state_dict` path suffices. The + projector loaders also prepend `proj.` to the stripped key so + the on-disk `linear_proj.{0,2}.*` / `linear_proj_audio.{0,3}.*` + keys hit the `nn.Sequential` slot by integer index. + + Verified by 4 snapshot-gated tests in + `test_ming_flash_omni_encoders.py` against the real + `/dev/shm/ming-hybrid` ckpt — all four prefixes load strictly + (no missing / unexpected). The audio encoder's + `positional_embedding` is loaded as a buffer (overrides the + sinusoidal init); the vision encoder loads all 27 blocks + + merger + deepstack_merger_list cleanly. + +5. **Thinker graph walks.** `prefill_text`, `prefill_audio`, `prefill_vision`, + `prefill_video`, `thinker_decode`. Follow Qwen3-Omni's pattern for + conditional walks based on `input_modalities`. + + - **5a — DONE** (`submodules.py`, `ming_omni_flash_model.py`): the two + encoder NodeSubmodules and their construction paths. + + * `VisionEncoderSubmodule` wraps Ming's `Qwen3MoeVisionTransformer` + + `MingVisionProjector`, mirrors + `modeling_bailingmm2.extract_image_feature` (encoder → projector + → L2 norm). 
`prepare_inputs` raises clearly on missing + `pixel_values` / `image_grid_thw` and promotes 1-D + `[T, H, W]` grid_thw to `(1, 3)`. + + * `AudioEncoderSubmodule` wraps `MingAudioEncoder` + + `MingAudioProjector`. Accepts either a single `(n_mels, T)` clip + or a `(B, n_mels, T)` batched tensor and optionally trims the + padded tail using `audio_seqlens`. Per-clip embeddings are + concatenated along time; L2-norm is applied when + `audio_config.norm_query_embeds` is set (true on the released + ckpt — matches `modeling_bailingmm2.extract_audio_feature`). + + * `get_node_engine_types` now registers + `vision_encoder` / `audio_encoder` as `EngineType.STATELESS` + alongside the KV-cache Thinker. Construction routes through + new `_create_vision_encoder_submodule` / + `_create_audio_encoder_submodule` helpers that build, dtype-cast, + and weight-load via the loaders from step 4b. + + * 12 tests in `test/modular/test_ming_flash_omni_submodules.py`: + 10 pure-Python (input-validation, output shape, L2 norm, + audio batched-vs-single equivalence, audio_seqlens trim, + grid_thw promotion, node-type registration, friendly error on + unknown node) + 2 snapshot-gated (full + `_create_audio_encoder_submodule` on the real ckpt — verifies + Conv1 + projector params are non-zero post-load). + + - **5b — DONE** (Thinker prefill dispatch + position helpers): + `BailingMoeV2ThinkerSubmodule.prepare_inputs` now dispatches on + `graph_walk` and emits either `input_ids` (text-only walks) or + `input_embeds` + `custom_pos_ids` (multimodal walks). `preprocess` + and `forward` route both shapes through to `LingMoeModel`'s + existing dual input_ids/input_embeds + 1D/3D position_ids + handling — no new model.py path needed. + + Three new position-id helpers live in `components/positions.py`, + each producing `(3, T)` long tensors compatible with + `LingPartialMRotaryEmbedding`'s `video_rope` branch: + + * `get_rope_index_text(seq_len, start_pos)` — three identical + sequential rows. 
Matches `modeling_bailing_moe_v2.get_rope_index`'s + pure-text branch (`:658-675`). + * `get_rope_index_audio` — alias to the text helper (Ming + does not special-case audio in `get_rope_index`). + * `get_rope_index_vision(grid_thw, start_pos, spatial_merge_size, + second_per_grid_t=None, tokens_per_second=2)` — per-image + 3D grid math from `:625-647`. Optional video timestamp + scaling via `second_per_grid_t * tokens_per_second`. + + The Thinker dispatch: + + * `prefill` / `prefill_text` — backward-compat text path + (unchanged from step 3f). + * `prefill_audio` — wraps `audio_embeds` with `audio_start` + / `audio_end` sentinel embeddings, builds text-like positions + for the span. + * `prefill_vision` / `prefill_video` — wraps `vision_embeds` + with `image_start`/`image_end` (or `video_start`/`video_end`), + builds grid-aware 3D positions; `eos` sentinel sits at + `global_max(vision_pos) + 1` so the next walk's text positions + can resume without collision (matches Ming source's + `llm_pos_ids_list[-1].max() + 1` accounting). + * `decode` / `thinker_decode` — single-token AR step (unchanged). + + Sentinel embeds are lazily computed per device on first use. + The model.py construction now passes `config=self.config` to the + submodule so it can read `vision.spatial_merge_size`, + `thinker_llm.tokens_per_second`, and the `*_start_token` / + `*_end_token` ids. + + Step 5b restricts to single-image / single-clip requests + (multi-image splice via `Sequential` graph wiring lands in 5c). + + 21 new tests across `test_ming_flash_omni_positions.py` (11) and + `test_ming_flash_omni_submodules.py` (10): position-id shape / + offset / abs-time math, missing-input error paths, + multi-image rejection, sentinel embed correctness for audio / + image / video walks, start_pos advancement, legacy `prefill` + walk name compat. All green. 
+ + - **5c — DONE** (graph wiring + multimodal scheduling): + `get_graph_walk_graphs` now returns five walks instead of the + step 3f text-only `prefill` / `decode` pair: + + * `prefill_text` — bare `Thinker` node. + * `prefill_audio` — `Sequential([audio_encoder, Thinker])` + where the encoder emits `audio_embeds` into the Thinker. + * `prefill_vision` — `Sequential([vision_encoder, Thinker])`; + `image_grid_thw` routes to BOTH the encoder (for spatial + positions on the patches) AND the Thinker (for 3D MRoPE math + around the vision span). + * `prefill_video` — same shape as `prefill_vision` plus + `video_second_per_grid` routed into the Thinker. + * `thinker_decode` — AR loop, renamed from step 3f's `decode`. + + `get_partitions` lists all five walks under the single `Thinker` + partition with `initial_walk="prefill_text"`. Two new helpers + drive the scheduling: + + * `_build_thinker_prefill_schedule(input_modalities, input_signals)` + — one schedule step per modality, in `input_modalities` order; + each step is `(walk_name, {input_name: TensorPointerInfo})`. + Modalities listed without matching tensors in `input_signals` + are silently skipped (parity with qwen3_omni). + * `_get_thinker_prefill_inputs(metadata, input_signals)` — emits + one `GraphEdge` per input for the current step, routing each + to the right node (encoder vs Thinker), including the dual + `image_grid_thw` edge for vision walks. + + `get_initial_forward_pass_args` builds the schedule, picks the + first walk, and stashes the schedule + step counter on the + metadata. `get_partition_forward_pass_args` is the Thinker state + machine: advance schedule → transition to `thinker_decode` → + return `request_done=True` after the decode loop unwinds. Mirrors + `mstar/model/qwen3_omni/qwen3_omni_model.py:765+` minus the + Talker / Code2Wav partitions (which land in step 6+). + + Empty-schedule edge case (no usable modalities) short-circuits + to `request_done=True` so the conductor doesn't hang. 
+ + 21 tests in `test/modular/test_ming_flash_omni_graph.py`: + graph-walk structure (5 walks, encoder→Thinker chaining, dual + grid_thw edge, loop feedback edge), partition listing, prefill + schedule construction for text-only / text+audio+image / video / + unknown-modality / no-inputs cases, edge routing for each walk + type, full state-machine drive across a text+audio request + (init → audio prefill → decode → done). + +6. **Talker + Audio VAE.** Port `ming_flash_omni_talker.py` + `talker_module.py` + + `audio_vae.py`. The talker is CFM-based (continuous flow matching) rather + than discrete-codec-AR like Qwen3-Omni's — the streaming topology will + differ. Re-read `mstar/streaming/topology.py` before wiring connections. + + Broken out into sub-steps because the upstream code is ~2,100 LOC + across three files (`ming_flash_omni_talker.py` 586 LOC + + `talker_module.py` 1,145 LOC + `audio_vae.py` 392 LOC): + + - **6a — DONE** (config port): replaced the step-1 raw-dict + skeleton `TalkerConfig` with typed sub-config dataclasses so the + modeling code (CFM head + DiT blocks + Aggregator + AudioVAE) + can read dims off `config.talker.*` directly. + + New dataclasses in `components/config.py` (under `TalkerConfig`): + * `TalkerLLMConfig` — Qwen2 backbone (896-dim, 24L, 14H/2KV, + sliding-window=False, RoPE θ=1e6). Distinct from + `ThinkerLLMConfig` (different vocab, no MoE, smaller dims). + `head_dim` property computes 896/14=64. + * `DiTBlockConfig` — shared shape for `flowmodel` and + `aggregator` (depth=8, hidden_size=1024, num_heads=16, + mlp_ratio=4, in_channels=64); only `dropout` differs (0 vs + 0.1 on the released ckpt). `head_dim` / `intermediate_size` + properties for convenience. + * `AudioVAEConfig` — encoder + decoder dims (latent_dim=64, + input_dim=80, hop_size=320, output_dim=882), + `sample_rate=44100`, `patch_size=4`. Encoder/decoder Qwen2 + backbones kept as raw dicts (`enc_backbone` / + `dec_backbone`) for the eventual block-builder to lift. 
+ Discriminator + loss-weight fields retained for round-trip + fidelity but not consumed at inference. + + `TalkerConfig.from_subdir` now constructs the typed sub-configs + directly (was raw-dict assignment); `vae_sample_rate` / + `vae_patch_size` retained as `@property` accessors for backward + compat with `Model.get_output_sample_rate`. + + 8 new tests in `test_ming_flash_omni_config.py` (7 freshly + authored + 1 updated to assert the new typed shape): + - `TalkerLLMConfig` defaults / head_dim / unknown-key filter + - `DiTBlockConfig` intermediate_size / head_dim derivations + - `AudioVAEConfig` enc/dec kwarg lifting + fallback when + enc_kwargs missing latent_dim + - `TalkerConfig.from_subdir` end-to-end with synthetic tmp dirs + (round-trips all three sub-configs) + - Default-factory check that `TalkerConfig()` with no args yields + typed sub-configs + + Verified by re-running the existing snapshot-gated + `test_subdir_configs_load_when_present` against the real + `/dev/shm/ming-hybrid/talker/` tree — typed fields read + correctly (LLM hidden_size=896, VAE sample_rate=44100, + flowmodel depth=8, aggregator dropout=0.1). + + - **6b — DONE** (CFM + DiT building blocks): new + `components/talker_dit.py` ports the modeling primitives from + upstream `talker_module.py:1-402`. Module names mirror upstream + so the released ckpt's `talker/model.safetensors` keys + (`flowmodel.blocks.N.attn.to_q.weight`, + `flowmodel.blocks.N.mlp.ff.0.0.weight` etc.) will load by + state-dict equality once the loader path lands. + + Two external deps replaced with minimal in-tree ports: + * `DiTTimestepEmbedding` — sinusoidal pos-emb + Linear+SiLU+Linear + MLP, matching vllm-omni's `timestep_embedding.DiTTimestepEmbedding`. + * `RotaryEmbedding` — non-xpos 1-D RoPE matching + `x_transformers.RotaryEmbedding.forward_from_seq_len` exactly, + including the INTERLEAVED-pair `rotate_half(x1, x2) = (-x2, x1)` + layout. 
This is DIFFERENT from Ling-2.0 thinker's neox-cat + layout — adjacent freq pairs share the same value here, while + Ling's halves repeat across the split. Required so the released + weights line up with the same RoPE shape they were trained + against. + + The CFM module wraps the DiT and integrates an ODE/SDE step grid + from `get_epss_timesteps` with classifier-free guidance. + Sway-sampling-coef remap is honored (`-1.0` default packs more + steps near `t=0`). The released ckpt's `steps=10` schedule is + the predefined `[0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32] / 32`. + + Skipped from `talker_module.py`: `CFMGraphExecutor`, + `CFMGraphExecutorPool` (vllm-specific batching), `Aggregator` + (lands in 6c), the resampling / silence-trim / `build_tts_input` + / `MingAudioGenerator` orchestration utilities (lands in 6e + where the Talker submodule wires the streaming graph). + + New factory `build_talker_cfm(talker_config, llm_cond_dim=None, + dtype=..., device=...)` constructs DiT + CFM directly from a + `TalkerConfig` so 6e's `_create_talker_submodule` will be a + one-liner. `llm_cond_dim` defaults to `talker.llm.hidden_size` + (896 on the released ckpt). + + 28 tests in `test_ming_flash_omni_talker_dit.py`: + - RotaryEmbedding layout: rotate-half pair negation, + freqs.shape `(1, T, dim)`, adjacent-pair-shared-frequency + invariant, partial-rotary apply preserves passed-through tail. + - DiTTimestepEmbedding: shape, dtype-stability, even-dim guard. + - RMSNorm normalises to unit-rms per row. + - FeedForward layer indices align with the ckpt's + `ff.0.0` / `ff.0.1` / `ff.2` keys. + - Attention: `to_q/to_k/to_v/to_out.0` param names, qk_norm + branches, rope on/off shape preservation, unknown-qk_norm + rejection. + - DiTBlock + FinalLayer + CondEmbedder round-trip. + - DiT.forward output `(B, 1 + his + patch, out_channels)` for + no-spk and `(B, 2 + his + patch, ...)` with spk; CFG forward + returns trailing `x.shape[1]` rows. 
+ - CFM.sample shape preservation + length / sde_rnd validation, + sway=None branch. + - `build_talker_cfm` from real `TalkerConfig` defaults yields + the expected DiT dims (1024 hidden, 8 layers, 16 heads, + cond_embedder input = 896) + `llm_cond_dim` override. + + - **6c — DONE** (Aggregator + Qwen2 backbone + heads): + + `_Attention` / `_DiTBlock` grew a `mask` parameter to match + upstream API exactly. For the CFM path the caller passes + `mask=None`, so behaviour is unchanged; the Aggregator's mask + branch is now exercised. Mask semantics mirror upstream's + `talker_module.Attention.forward`: + * `attn_mask_enabled=True` builds an SDPA `attn_mask` from the + (B, T) key-padding mask so padded keys are excluded from + softmax. + * Regardless of `attn_mask_enabled`, the masked-out output rows + are zeroed via `masked_fill(~mask, 0)` — matches upstream's + unconditional zeroing branch. + + `Aggregator` (port of `talker_module.Aggregator:702-744`): same + DiTBlock stack as the CFM head, but the input embedder is + `nn.Linear` (audio-latent → hidden) plus a learnable [CLS]-style + `word_embedder` (`nn.Embedding(1, hidden_size)`) prepended to the + sequence. The output is the `[CLS]` row only, projected through + `final_layer` to `llm_input_dim` so the condition feedback loops + back into the talker LLM's embedding space. + + `build_aggregator(talker_config, llm_input_dim=None, ...)` and + `build_talker_cfm(...)` both honor `attn_mask_enabled` from the + respective DiTBlockConfig (False on the released ckpt). + + **Talker LLM backbone** — `build_talker_llm(talker_llm_config, + attn_implementation="sdpa", ...)` constructs a stock + `transformers.Qwen2Model` from `TalkerLLMConfig`. No custom modeling + path: the talker LLM colocates on a single rank in the typical + topology and the ckpt's `talker/model.safetensors` keys are + plain `model.*` Qwen2 keys, so reusing HF keeps the surface small + and inherits HF's KV-cache + attention impl. 
Matches what the + upstream `MingFlashOmniTalkerForConditionalGeneration.__init__` + does (line 116: `self.model = Qwen2Model(llm_config)`). + + **Talker heads** — `build_talker_heads(talker_config, + spk_embed_dim=192, ...)` returns a dict of two `nn.Linear` heads: + * `stop_head` — `Linear(hidden_size, 2, bias=True)`: binary + end-of-audio classifier consumed during the generation loop. + * `spk_head` — `Linear(192, hidden_size, bias=True)`: projects + a CAMPPlus speaker embedding into the LLM hidden space; the + projected embedding is prepended to the prompt as a voice- + condition token. + + 13 new tests appended to `test_ming_flash_omni_talker_dit.py`: + - Attention mask output-zeroing (unconditional), SDPA attn_mask + branch (attn_mask_enabled=True), no-mask no-zeroing regression + guard. + - Aggregator: `[CLS]` row only output `(B, 1, llm_input_dim)`, + single-row `word_embedder`, mask propagation through DiT + blocks, shape stability across varying T, `build_aggregator` + from real TalkerConfig + `llm_input_dim` override. + - `build_talker_llm`: returns `transformers.Qwen2Model` with + correct dims; tiny-input forward returns hidden states. + - `build_talker_heads`: stop_head (h→2) + spk_head (192→h) with + biases; `spk_embed_dim` override. + + Total talker_dit tests: 41 (28 from 6b + 13 from 6c). Full + Ming step-1..7 + 6a/6b/6c suite: **162 pass / 9 skipped / 0 fail + / 1 deselected** (deselected is pre-existing cuDNN-broken + attention forward, unrelated). + + - **6d — DONE** (AudioVAE): new `components/audio_vae.py` ports + `vllm_omni/.../audio_vae.py` (~392 LOC). Module tree mirrors + upstream so the released ckpt's `talker/vae/model.safetensors` + keys load 1:1 by state-dict equality once the loader path + lands (6f). + + Building blocks: + * `_ISTFT` — sliding-window OLA inverse-STFT. 
Two padding + modes: `"center"` wraps `torch.istft` directly; `"same"` is + the hand-rolled `F.fold` reconstruction with optional + streaming buffers (carries the trailing `win_length - hop` + samples + window envelope across chunks). + * `_ISTFTHead` — Linear → STFT mag (exp+clip) / phase → `_ISTFT`. + * `_StreamingLinearUpsample` — chunked linear upsampler with + 1-step lookahead so chunked output matches single-shot output + at chunk boundaries. + * `_Encoder` — waveform → latent params. `get_frames` windows + the waveform with stride `hop_size`, `fc1` projects to hidden, + Qwen2 backbone runs, then optional `aggregator` (4-layer + Qwen2 + `cls_embed`) summarises each patch. + * `_Decoder` — latent → waveform. `fc1` to hidden, optional + `_StreamingLinearUpsample`, Qwen2 backbone with sliding-window + bridge for streaming KV cache, `_ISTFTHead` to audio. + * `AudioVAE` — wraps encoder+decoder, exposes `encode_latent` + (with an inline `_oobleck_sample()` so we don't depend on the + broken-on-this-box `diffusers` package) and `decode`. + + **Defaults fixed**: `AudioVAEConfig.encoder_input_dim` / + `encoder_hop_size` were previously 80 / 320 (placeholder from + step 6a); updated to 882 / 882 to match the released ckpt + (`enc_kwargs: {hop_size: 882, input_dim: 882, latent_dim: 64}`). + The existing 6a tests still pass since they explicitly pass + overrides through `from_dict`. + + `build_audio_vae(audio_vae_config, dtype, device, attn_implementation=None)`: + auto-picks `"sdpa"` on CPU and FA2 when available on CUDA; + caller can pin explicitly. Mirrors vllm-omni's runtime choice + for the talker LLM (`llm_config._attn_implementation = "sdpa"`). + + 18 tests in `test_ming_flash_omni_audio_vae.py` covering: + - Oobleck sampler shape + mean-collapse-on-small-scale. + - ISTFT padding-mode validation + center / same forward paths. 
+ - StreamingLinearUpsample: single-shot path, deferred-first-chunk + path, **chunked-vs-single-shot equivalence** (the key + correctness property — proves boundary lookahead is wired + correctly so chunked streaming doesn't introduce artefacts). + - ISTFTHead output shape (audio + x_pred). + - Encoder: `get_frames` padding arithmetic, forward without + patching, forward with patching (aggregator path collapses + to per-patch latents). + - Decoder: non-streaming reconstruct shape, patching path + routes through the upsampler. + - AudioVAE: construction + encode_latent shape (incl. per-clip + frame counts) + decode end-to-end. + - **Snapshot-gated parity**: built `AudioVAE.state_dict()` keys + contain all representative entries present in + `talker/vae/model.safetensors` (fc1/fc2/fc3/norm/cls_embed, + encoder.encoder, encoder.aggregator, decoder.fc1, + decoder.head.out, decoder.head.istft.window, decoder.decoder) + and vice versa — proves the eventual loader will be a clean + prefix-strip + load_state_dict. + + - **6e — DONE** (Talker submodule + graph walks): split into + 6e-1 (orchestration helper), 6e-2 (TalkerSubmodule + construction + + node registration), and 6e-3 (graph walks + Thinker→Talker bridge). + + - **6e-1 — DONE** (`components/talker_generator.py`): port of + upstream `MingAudioGenerator` (talker_module.py:854-1146) plus + the streaming-decode utilities `silence_holder` / + `trim_trailing_silence`. Stateless-per-request `TalkerGenerator` + binds Qwen2 LLM + CFM + Aggregator + stop_head + AudioVAE and + exposes: + * `generate_latents(inputs_embeds, ...)` — the AR loop: + repeated (`llm_step` → `cfm_sample_step` → stop check). Each + step emits one `(B, patch_size, latent_dim)` latent; the + Aggregator output becomes the next step's `inputs_embeds`; + the stop_head softmax gates early termination after + `min_new_token` steps. + * `cfm_sample_step` — one CFM substep-integration + Aggregator + + stop classification. + * `llm_step` — single Qwen2 forward with `StaticCache` + `cache_position` bookkeeping on step > 0.
+ * `decode_to_waveform(latents, stream_decode=True)` — one-shot + or chunked AudioVAE decode; the streaming path threads + `silence_holder` + a sliding `decode_pad` window across chunks. + * `duration_capped_steps` — the text-length → max-steps prosody + heuristic. + * `_init_his_lat` / `_update_his_lat` — history-latent sliding + window (right-aligns a voice-prompt latent when supplied). + + Skipped from upstream: `CFMGraphExecutorPool` / `CFMGraphExecutor` + (vllm CUDA-graph batching — mstar's engine handles capture); + `build_tts_input` / `_looks_like_music_prompt` (→ step 8). + + 24 tests in `test_ming_flash_omni_talker_generator.py`: + trim_trailing_silence (empty / short-clip / silent-tail trim / + weird-shape passthrough), silence_holder (cache init, sub-frame + buffering until last_chunk), generator construction (with / + without VAE), his-lat zeros + right-align + window update + + unsupported-shape guard, cfm_sample_step output shapes + + stop-softmax-sums-to-1, llm_step step-0 path, generate_latents + per-step collection + max_steps cap, duration_capped_steps + heuristic, decode_to_waveform one-shot / streaming / empty / + no-VAE-raises, instance trim_trailing_silence. + + - **6e-2 — DONE** (TalkerSubmodule + construction + node + registration): the talker is a STATELESS node, not an AR / + streaming-codec node. Ming's thinker→talker bridge passes + DETOKENIZED TEXT (the talker re-encodes with its own + `talker/llm` tokenizer — see vllm-omni `pipeline.py`'s + `thinker2talker`), and the CFM step count is stop_head- + determined rather than a conductor decode loop. So the whole + per-request generation (LLM prefill + CFM AR decode + AudioVAE + decode) runs inside one `TalkerSubmodule.forward` call. 
+ + * `TalkerSubmodule` (`submodules.py`): `prepare_inputs` embeds + `talker_text_inputs` token ids via the talker LLM's + `embed_tokens`; `forward` runs `generate_latents` → + `decode_to_waveform` → `trim_trailing_silence` and returns + `{"audio_chunk": [waveform]}` (`(1, 1, num_samples)` at the + VAE sample rate). `get_stateless_flavor` returns + `"audio_codec"` (no autocast / no torch.compile — the CFM + ODE loop + ISTFT are numerically sensitive). + + * `get_node_engine_types` registers `Talker` as + `EngineType.STATELESS` when the snapshot ships a `talker/` + subdir; thinker-only configs omit it. + + * `_create_talker_submodule` builds the full stack + (`build_talker_llm` + `build_talker_cfm` + `build_aggregator` + + `build_talker_heads` + `build_audio_vae`), loads every + subtree via the step-6f loaders, wraps in a + `TalkerGenerator` → `TalkerSubmodule`. + + 12 tests across `test_ming_flash_omni_talker_submodule.py` (9) + + an updated `test_get_submodule_rejects_unknown_node`: + stateless flavor, prepare_inputs embed (1-D + 2-D ids) + + missing-input guard, forward returns finite audio_chunk, + node-type registration (with / without talker config), + `_create_talker_submodule` no-talker guard, plus a + snapshot-gated end-to-end that builds the full talker from + real weights and generates a finite waveform. + + - **6e-3 — DONE** (graph walks + Thinker→Talker bridge): the + talker is now a second partition wired off the Thinker, gated + entirely on `config.talker is not None` (thinker-only configs + are byte-for-byte unchanged from step 5c). + + Graph + partition additions (all in `ming_omni_flash_model.py`): + * `get_graph_walk_graphs` adds a `talker` walk — a single + `Talker` node consuming `thinker_tokens`, emitting one + `audio_chunk` `EMIT_TO_CLIENT` edge. The `thinker_decode` + loop gains a `StreamingGraphEdge(name="thinker_tokens", + target_partition="Talker")` so each decoded token streams to + the talker. 
+ * `get_partition_topology` declares the Thinker→Talker + `Connection` with a `FixedChunkPolicy(chunk_size=1, + continue_after_done=True)` — the talker needs the FULL text + before it generates, so the policy keeps the consumer alive + past the Thinker's text EOS. + * `get_partitions` adds the `Talker` partition + (`producer_partitions=["Thinker"]`, `initial_walk=None`). + * `get_output_sample_rate("audio")` returns the talker VAE + sample rate (44.1 kHz). + * `get_initial_forward_pass_args` / `get_partition_forward_pass_args` + dispatch a Talker branch: `_get_talker_forward` waits for + `producer_done`, then fires the single `talker` walk once and + reports `request_done` on the next invocation. + + Thinker→Talker text bridge: Ming passes DETOKENIZED TEXT, not + hidden states. `thinker_text_to_talker_inputs` decodes the + thinker output ids with the thinker tokenizer and re-encodes + with the talker's own `talker/llm` tokenizer (loaded lazily + + cached via `_get_talker_tokenizer`). `_create_talker_submodule` + injects this as the `TalkerSubmodule.text_bridge`, and + `TalkerSubmodule.prepare_inputs` accepts either pre-bridged + `talker_text_inputs` or raw `thinker_tokens` (running the + bridge in the latter case). + + 18 tests in `test_ming_flash_omni_talker_graph.py`: thinker-only + path unchanged (no talker walk / partition / streaming edge), + talker-enabled graph structure (walk, audio edge, streaming + edge to Talker), partition + topology + chunk-policy + continue-after-done, node-type registration, audio sample rate, + Talker state machine (waits for producer_done, fires once, + then done; audio-output gating), and the text bridge + (decode→re-encode round-trip + missing-tokenizer guard). + Updated two pre-existing tests that asserted Talker was an + unknown node/partition. 
+ + **Step 6 complete** — audio-out `/generate` is now wireable + end-to-end at the model layer (live bring-up still blocked by the + TP=4 OOM on the 4-GPU dev box; needs TP=8 thinker + talker on a + spare rank). + + - **6f — DONE** (weight loaders): `loader.py` exposes five new + entry points on top of the step-4b `_load_prefixed_state_dict` + helper. The helper grew two args: `subdir` (relative to + `local_dir` — lets us point at `talker/` or `talker/vae/` instead + of the snapshot root) and `allow_unexpected` (set of post-rename + keys allowed to appear in the ckpt without a target module slot). + + Five loaders: + * `load_talker_llm_weights` — strips `model.` from + `talker/model.safetensors` for a `transformers.Qwen2Model`. + * `load_talker_cfm_weights` — strips `cfm.` for a `CFM(DiT)`. + Allows the ckpt's `model.rotary_embed.inv_freq` (we register + it as `persistent=False` and recompute locally — deterministic + from head_dim + rope_theta). + * `load_talker_aggregator_weights` — strips `aggregator.` for + an `Aggregator`. Same `rotary_embed.inv_freq` allow. + * `load_talker_heads_weights` — loads `stop_head.*` + + `spk_head.*` into the dict produced by `build_talker_heads`. + * `load_talker_audio_vae_weights` — empty-prefix load from + `talker/vae/model.safetensors` (the ckpt's `encoder.*` / + `decoder.*` are top-level siblings with no shared prefix — + no strip needed). + + 7 snapshot-gated tests in `test_ming_flash_omni_talker_loader.py` + verify strict load against `/dev/shm/ming-hybrid/talker/`: + - Talker LLM: representative key parity + non-zero embed table + after load. + - CFM: `model.x_embedder.weight` / `model.blocks.0.attn.to_q.weight` + / `model.blocks.0.mlp.ff.0.0.weight` / `model.final_layer.linear.weight`. + - Aggregator: `x_embedder` / `word_embedder` / `blocks.0.attn.to_q` + / `final_layer.linear`. + - Heads: `stop_head` + `spk_head` weights both load; non-zero + post-load; missing-key guard fires before disk I/O. 
+ - AudioVAE: full encoder + decoder + aggregator + ISTFT window + keys loaded; CPU end-to-end decode on a real-weights latent + produces a finite waveform (catches catastrophic + dtype/layout misloads that key-name parity alone wouldn't + surface). + + Full Ming step-1..7 + 6a/6b/6c/6d/6f suite: 187 pass / 9 skipped + / 0 fail / 1 deselected. + +7. **Process_prompt — DONE.** `MingFlashOmniModel.process_prompt` now + produces the full `NameToTensorList` consumed by step 5c's prefill + scheduler. Strategy mirrors `qwen3_omni`'s `process_prompt`: apply + the chat template to TEXT-ONLY messages (so the tokenizer doesn't + insert placeholder tokens we'd later have to strip), then run the + image / video / audio sub-processors separately for each modality. + The Ming chat template path uses `tokenizer.apply_chat_template` + (jinja, accepts OpenAI roles `user`/`assistant`/`system`) rather + than `processor.apply_chat_template` (Python implementation in + `BailingMM2Processor`, asserts on lowercase OpenAI roles — see + "Role-handling nuance" above). + + Input convention (`tensors: NameToTensorList`): + * `image_inputs` — list of CHW float [0,1] tensors per image. + Internal `_image_to_processor_input` converts to HWC uint8 to + avoid the upstream's double-rescale near-zero bug + (`qwen3_omni_model.py:1033-1038` documents the same gotcha). + Single-channel inputs auto-broadcast to 3 channels. + * `audio_inputs` — list of either raw 1-D float tensors (sample + rate inferred from processor default 16 kHz) or + `(waveform, sample_rate)` tuples. + * `video_inputs` — list of (T, C, H, W) float tensors. Per-frame + `second_per_grid` defaults to 1.0; override via + `kwargs["input_metadata"]["video"][i]["second_per_grid"]`. + + Output keys consumed by `_build_thinker_prefill_schedule`: + * `text_inputs` — list of 1-D long tensors (one per text turn). + * `pixel_values`, `image_grid_thw` — one entry per image. 
+ * `pixel_values_videos`, `video_grid_thw`, + `video_second_per_grid` — one entry per video clip. + * `audio_features` (n_mels, T) + `audio_seqlens` (length-1 long) + — one entry per audio clip. Note: upstream returns audio_feats + as (B, T, n_mels); we transpose to (n_mels, T) per clip so + `AudioEncoderSubmodule.prepare_inputs` can splice without a + reshape. + + 17 tests in `test/modular/test_ming_flash_omni_process_prompt.py`: + text-only happy path, no-prompt audio-only path, image conversion + correctness (CHW float [0,1] → HWC uint8, grayscale broadcast, + uint8 pass-through), per-modality dispatch, missing-processor + error paths, multi-image / mixed-modality combinations, video + metadata override, snapshot-gated text+image E2E with the real + `BailingMM2Processor`. 16 green + 1 env-skip on this box. + + Image-gen-specific `IMAGE_PATCH_TOKEN * 256` block (the + query-token expansion for the imagegen DiT path) is deferred to + step 9 (ImageGen partition), since today's prefill schedule only + covers text-out generation. + +8. **TTS caption template — DONE.** `components/prompt_utils.py` ports + vllm-omni's `prompt_utils.py` wholesale (self-contained, no torch / + model deps): + * `create_instruction(user_input)` + `BASE_CAPTION_TEMPLATE` — the + JSON caption builder for the `ming_flash_omni_tts` talker-only + deploy. Merges only known keys (序号 / 说话人 / 方言 / 风格 / 语速 + / 基频 / 音量 / 情感 / BGM / IP) into a deep-copied template; + `ensure_ascii=False` keeps the Chinese field names readable. + * `maybe_expand_image_gen_prompt` + `IMAGE_PATCH_TOKEN` + + `DEFAULT_NUM_QUERY_TOKENS=256` — the + `IMAGE_PATCH_TOKEN * 256` query-token expansion the + ImageGen path (step 9) needs; landed here so the constants live + in one place.
+ 10 tests in `test_ming_flash_omni_prompt_utils.py`: query-token + expansion (default 256, custom count, no-op on already-expanded / + empty / non-string), caption build (defaults, known-key merge, + unknown-key ignore, no template mutation across calls, unescaped + unicode, shallow BGM merge). + + **8b — DONE (image-gen prompt wiring):** `process_prompt` now calls + `maybe_expand_image_gen_prompt` when `output_modalities` contains + `"image"` AND the deploy ships an `ImageGenConfig` (thinker-only + deploys leave the prompt untouched). The expansion count comes from + `config.image_gen.num_query_tokens` (= sum of img_gen_scales², 256 by + default), so it tracks the checkpoint rather than the hard-coded + constant. This is the thinker-side half of the step-9 handoff; the DiT + that consumes the query embeddings is still 9b. 5 tests in + `test_ming_flash_omni_process_prompt.py`: block appended on image + output, count tracks img_gen_scales, no expansion on text output, + no-op without ImageGenConfig, no double-expansion. + +9. **ImageGen partition.** Separate from the omni pipeline; lives under + vllm-omni's diffusion tree (`diffusion/models/ming_flash_omni/`, + ~1,315 LOC). Wire as a fourth partition with its own graph walk. + Needs `FlowEngine`-style integration. Multi-commit step. + + - **9a — DONE** (config port): `ImageGenConfig` fleshed out with + typed sub-config dataclasses parsed from the imagegen subdir tree: + * `ZImageDiTConfig` (transformer/config.json) — the diffusion DiT + (dim=3840, 30 layers + 2 refiner, 16-channel latents, 3D axial + RoPE via axes_dims=(32,48,48) / axes_lens=(1536,512,512)). + * `ImageVAEConfig` (vae/config.json) — AutoencoderKL, 16-channel + latent, scaling_factor=0.3611 / shift_factor=0.1159. + * `ImageGenSchedulerConfig` (scheduler/) — + FlowMatchEulerDiscreteScheduler (shift=3.0). 
+ * `ByT5MapperConfig` (byt5/byt5.json) — ByT5-small glyph encoder + + T5EncoderBlockByT5Mapper (4 layers → sdxl_channels=2560) for the + text-rendering pathway. + * `connector` — Qwen2 LLM (1536-dim, 28L) kept as a raw dict + (built via the shared Qwen2 path at construction time). + `from_subdirs` reads each subdir into the typed fields; the + `mlp/config.json` knobs (img_gen_scales, diffusion_c_input_dim, + use_identity_mlp, dit_type) stay at the top level. + 13 tests (7 new pure-Python + 6 existing, incl. updated + snapshot-gated assertions on dit.dim=3840 / vae.latent_channels=16 + / scheduler.shift=3.0 / byt5.sdxl_channels=2560 / connector qwen2). + + - **9b — DONE** (modeling + pipeline + wiring): the full image-gen + stack is ported into `components/` as native pure-torch (+ stock + transformers) modules, decoupled from vllm-omni / vllm TP / diffusers + internals: + * `t5_block_mapper.py` + `byte5_encoder.py` — ByT5 glyph mapper + + encoder. Built on **stock HF `T5Block`** (unfused q/k/v/o, + wi_0/wi_1/wo) so Ming's `byt5_mapper.pt` loads with a plain `copy_`, + no stacked-param remap. 11 mapper tests + 1 snapshot-gated encoder. + * `zimage_transformer.py` — ZImage DiT (`ZImageTransformer2DModel`) + + Ming's ref-latent subclass (`MingZImageTransformer2DModel`). Drops + vllm's TP linears / `CachedTransformer` / fused `Attention` / + `RotaryEmbedding` for plain `nn.Linear` + + `F.scaled_dot_product_attention`. Unfused param names + (`attention.to_q/to_k/to_v`, `feed_forward.w1/w3`) → direct load. + The interleaved (is_neox_style=False) RoPE, GLIDE/DiT + `timestep_embedding`, and FP32 RMSNorm match the vllm-omni reference + (RoPE parity verified maxdiff=0.0). 14 tests on a tiny config. + One intentional divergence: vllm-omni computes-but-does-not-apply the + attention pad mask; this port applies it (identical for the bsz-1 + multiple-of-32 t2i path, correct when caption padding is nonzero). 
+ * `condition_encoder.py` — Qwen2-connector condition path (proj_in → + bidirectional Qwen2 → proj_out → L2-normalize×1000). transformers + only (no diffusers). 7 tests with a stub connector. + * `imagegen_pipeline.py` — flow-matching + CFG denoise loop + (`MingImageDenoiser`, `combine_cfg`, `calculate_shift`) **decoupled + from diffusers** (DiT/scheduler/VAE injected), so the guidance math / + sign convention / scheduler stepping are unit-tested with stubs. + diffusers + transformers loading lives behind the lazy + `MingImagePipeline.from_checkpoint` classmethod (diffusers is broken + on this box — confirmed — so eager import is avoided). 11 tests. + * Wiring (`submodules.py` + `ming_omni_flash_model.py`): + `ImageGenSubmodule` (STATELESS consumer) + an `imagegen` graph walk + + `ImageGen` partition + `Thinker→ImageGen` streaming connection + (`continue_after_done`) + `_create_imagegen_submodule` factory, all + guarded on `config.image_gen`. Mirrors the Talker consumer pattern. + + **Producer↔consumer handoff — DONE.** Both ends of the Thinker→ImageGen + stream are wired: + * Producer: the thinker prefill node carries a `thinker_hidden_states` + `StreamingGraphEdge` (added when `config.image_gen` is set), and + `BailingMoeV2ThinkerSubmodule.forward` detects the image-gen patch token + in the prefill ids, runs `LingMoeModel.forward(return_hidden_states=True)`, + slices the patch positions via `extract_image_gen_hidden_states`, and + publishes them under `thinker_hidden_states`. No metadata plumbing — + the gate is the patch token's presence in the tokenized prompt. + * Consumer: `get_initial_forward_pass_args` / `get_partition_forward_pass_args` + gained an `ImageGen` branch + `_get_imagegen_forward` state machine + (fires the `imagegen` walk once the producer is done, then request_done), + mirroring `_get_talker_forward`. + ~30 graph/partition/submodule/producer tests (incl. patch-token-gated + emit, consumer state machine fire-once-then-done). 
+ + **Remaining live-bringup gap (not code):** end-to-end image output still + needs live multi-GPU serve (TP=8) + a working diffusers (broken on this + box). The full modeling + graph + producer/consumer wiring are complete + and unit-validated; only the live run is left. + +10. **Configs — DONE.** `configs/ming_flash_omni.yaml` rewritten to the + real registered node names: `vision_encoder` + `audio_encoder` + + `Talker` colocated on rank 0, `Thinker` TP=8 across all 8 GPUs. + Dropped the stale placeholders (`AudioVAE` is wrapped inside the + Talker submodule, not a separate node; TP=4 → TP=8 to match the + verified OOM finding). Node names cross-checked against + `get_node_engine_types` (a yaml-vs-registered assertion passes). + `configs/ming_flash_omni_thinker_only.yaml` unchanged (already + correct). An image-gen variant lands with step 9. + +11. **Benchmark `OursOpenAI` parity.** Once `mstar-serve` boots the model, + extend `benchmark/request.py:OursOpenAI` to route Ming TTS through the + correct endpoint (likely `/v1/chat/completions` with `modalities=["audio"]`, + matching the Qwen3-Omni path — `MingFlashOmni` declares no Orpheus-style + speech-only fallback). + +12. **Tests.** Add `test/modular/test_ming_flash_omni_*.py` covering config + load, submodule weight load on a tiny shard, and a smoke graph walk on + a single GPU. Mirror `test/modular/test_qwen3_omni_*.py` if present. + +## Things to verify against the released checkpoint (not in vllm-omni) + +- Exact `max_position_embeddings` and `rope_theta` for thinker vs talker + (read from `config.json`, not the deploy yaml). +- Whether `default_sampling_params.repetition_penalty=1.05` from the deploy + yaml is a serving default or a hard requirement — affects + `benchmark/base.py:MingFlashOmni.get_model_kwargs`. +- The output sample rate for the talker (Qwen3-Omni is 24 kHz; check + `audio_vae.py` for Ming's). Override + `Model.get_output_sample_rate` if it differs. 
diff --git a/mstar/model/ming_omni_flash/__init__.py b/mstar/model/ming_omni_flash/__init__.py new file mode 100644 index 00000000..c6152997 --- /dev/null +++ b/mstar/model/ming_omni_flash/__init__.py @@ -0,0 +1,21 @@ +from mstar.model.ming_omni_flash.components.model import ( + LingMoeModel as LingMoeModel, +) +from mstar.model.ming_omni_flash.loader import ( + load_audio_encoder_weights as load_audio_encoder_weights, +) +from mstar.model.ming_omni_flash.loader import ( + load_audio_projector_weights as load_audio_projector_weights, +) +from mstar.model.ming_omni_flash.loader import ( + load_thinker_weights as load_thinker_weights, +) +from mstar.model.ming_omni_flash.loader import ( + load_vision_encoder_weights as load_vision_encoder_weights, +) +from mstar.model.ming_omni_flash.loader import ( + load_vision_projector_weights as load_vision_projector_weights, +) +from mstar.model.ming_omni_flash.ming_omni_flash_model import ( + MingFlashOmniModel as MingFlashOmniModel, +) diff --git a/mstar/model/ming_omni_flash/components/__init__.py b/mstar/model/ming_omni_flash/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mstar/model/ming_omni_flash/components/attention.py b/mstar/model/ming_omni_flash/components/attention.py new file mode 100644 index 00000000..dbb7cac7 --- /dev/null +++ b/mstar/model/ming_omni_flash/components/attention.py @@ -0,0 +1,171 @@ +"""Ling-2.0 attention (TP-aware, packed-tokens, cache-handle-aware). + +Uses mstar's :class:`QKVParallelLinear` + :class:`RowParallelLinear` for +TP-sharded projections. Per-rank head counts come from the QKV proj — +when ``tp_size > 1``, attention runs on this rank's slice of heads and +the output `dense` projection all-reduces across ranks. + +The architecture-specific bits (per-head QK-norm, partial 3D +``video_rope`` rotation) stay inline — they only operate on this rank's +heads, no cross-rank comm. 
class LingAttention(nn.Module):
    """Tensor-parallel self-attention layer for the Ling-2.0 backbone.

    The constructor receives the TOTAL head counts. The counts this rank
    actually works with are read back from ``qkv_proj.num_heads`` /
    ``qkv_proj.num_kv_heads`` once the packed projection exists, and every
    size used by :meth:`forward` is derived from those per-rank values.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float,
        rotary: LingPartialMRotaryEmbedding,
        use_qkv_bias: bool = False,
        use_bias: bool = False,
        comm_group: TPCommGroup | None = None,
    ) -> None:
        """Build the projections and per-head norms, and keep the rotary module.

        Args:
            hidden_size: Model width; input of ``qkv_proj``, output of ``dense``.
            num_heads: Total number of query heads (before sharding).
            num_kv_heads: Total number of key/value heads (before sharding).
            head_dim: Width of one attention head.
            rms_norm_eps: Epsilon handed to both per-head ``RMSNorm`` layers.
            rotary: Rotary module applied to q/k in :meth:`forward`; its
                ``head_dim`` attribute must equal ``head_dim``.
            use_qkv_bias: Forwarded as ``bias`` to the packed QKV projection.
            use_bias: Forwarded as ``bias`` to the output projection.
            comm_group: Tensor-parallel group. ``TPCommGroup.trivial()`` is
                substituted when ``None``.

        Raises:
            ValueError: If ``num_heads`` is not a multiple of
                ``num_kv_heads``, or ``rotary.head_dim != head_dim``.
        """
        super().__init__()

        heads_divide_evenly = num_heads % num_kv_heads == 0
        if not heads_divide_evenly:
            raise ValueError(
                f"num_heads={num_heads} must be divisible by "
                f"num_kv_heads={num_kv_heads} for GQA"
            )
        if rotary.head_dim != head_dim:
            raise ValueError(
                f"rotary.head_dim={rotary.head_dim} must equal head_dim={head_dim}"
            )

        group = TPCommGroup.trivial() if comm_group is None else comm_group
        self.comm_group = group

        # Un-sharded bookkeeping, kept for callers that need the totals.
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads

        # Packed [Q, K, V] projection, sharded along the heads axis.
        self.qkv_proj = QKVParallelLinear(
            comm_group=group,
            hidden_size=hidden_size,
            head_size=head_dim,
            total_num_heads=num_heads,
            total_num_kv_heads=num_kv_heads,
            bias=use_qkv_bias,
        )

        # From here on only the per-rank head counts are used.
        local_heads = self.qkv_proj.num_heads
        local_kv_heads = self.qkv_proj.num_kv_heads
        self.num_heads = local_heads
        self.num_kv_heads = local_kv_heads
        self.kv_groups = local_heads // local_kv_heads
        self.q_size = local_heads * head_dim
        self.kv_size = local_kv_heads * head_dim
        self.scaling = head_dim ** -0.5

        # Output projection: declared with the full (pre-shard) input width
        # and configured to reduce its result across the group.
        self.dense = RowParallelLinear(
            comm_group=group,
            input_size=num_heads * head_dim,
            output_size=hidden_size,
            bias=use_bias,
            input_is_parallel=True,
            reduce_results=True,
        )

        # RMSNorm over the head_dim axis, applied to q and k before rope.
        self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps)

        self.rotary = rotary

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_handle: BatchedCacheManager,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over packed tokens.

        Args:
            hidden_states: ``(num_tokens, hidden_size)`` input, fed as-is to
                ``qkv_proj``.
            cache_handle: Object whose ``run_attention(q=, k=, v=)`` performs
                the attention itself.
            position_ids: Passed unchanged to ``self.rotary``.

        Returns:
            The result of ``self.dense`` applied to the attention output
            reshaped to ``(num_tokens, q_size)``.
        """
        num_tokens = hidden_states.shape[0]
        q_shape = (num_tokens, self.num_heads, self.head_dim)
        kv_shape = (num_tokens, self.num_kv_heads, self.head_dim)

        # Split this rank's packed projection into its q / k / v parts.
        packed = self.qkv_proj(hidden_states)
        q_flat, k_flat, v_flat = torch.split(
            packed, [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query = q_flat.view(q_shape)
        key = k_flat.view(kv_shape)
        value = v_flat.view(kv_shape)

        # Per-head RMSNorm: flatten heads into rows, normalise, restore.
        query = self.q_norm(query.reshape(-1, self.head_dim)).view(q_shape)
        key = self.k_norm(key.reshape(-1, self.head_dim)).view(kv_shape)

        # The rotary module is called with heads leading; swap back after.
        query, key = self.rotary(
            query.transpose(0, 1), key.transpose(0, 1), position_ids
        )
        query = query.transpose(0, 1).contiguous()
        key = key.transpose(0, 1).contiguous()

        context = cache_handle.run_attention(q=query, k=key, v=value)
        return self.dense(context.reshape(num_tokens, self.q_size))

    @staticmethod
    def head_norm_check(q_after_norm: torch.Tensor) -> float:
        """Diagnostic: largest absolute gap between last-axis RMS and 1."""
        rms = q_after_norm.float().pow(2).mean(dim=-1).sqrt()
        deviation = (rms - 1.0).abs()
        return deviation.max().item()
def _sinusoids(length: int, channels: int, max_timescale: int = 10000) -> torch.Tensor:
    """Build the Whisper-style sinusoidal positional table.

    The first ``channels // 2`` columns hold ``sin(pos * inv_timescale)`` and
    the remaining columns hold the matching ``cos`` values.

    Args:
        length: Number of positions (rows).
        channels: Embedding width (columns); must be even.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        ``(length, channels)`` tensor.

    Raises:
        ValueError: If ``channels`` is odd.
    """
    if channels % 2 != 0:
        raise ValueError(f"channels must be even, got {channels}")
    half = channels // 2
    # The progression has ``half - 1`` steps. With a single frequency
    # (channels == 2) that is zero steps, and dividing by it produced an
    # infinite increment and an all-NaN table. Clamp the divisor to 1 so the
    # lone timescale is exactly 1; wider tables are unaffected.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class _PackedMultiHeadAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    The projection attributes are named ``query`` / ``key`` / ``value`` /
    ``out`` (the Whisper naming) so checkpoint keys map onto them directly.
    The ``key`` projection is built without a bias.
    """

    def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True) -> None:
        """Create the four projections and decide which attention path to use.

        Args:
            n_state: Model width; must be divisible by ``n_head``.
            n_head: Number of attention heads.
            use_flash_attn: Request the flash-attn varlen kernel. It is only
                honoured when the kernel was importable.

        Raises:
            ValueError: If ``n_state`` is not a multiple of ``n_head``.
        """
        super().__init__()
        if n_state % n_head != 0:
            raise ValueError(f"n_state={n_state} not divisible by n_head={n_head}")
        self.n_head = n_head
        self.query = _AutoCastLinear(n_state, n_state)
        self.key = _AutoCastLinear(n_state, n_state, bias=False)
        self.value = _AutoCastLinear(n_state, n_state)
        self.out = _AutoCastLinear(n_state, n_state)

        kernel_missing = _FLASH_ATTN_VARLEN is None
        if use_flash_attn and kernel_missing:
            logger.warning("flash-attn not available — falling back to manual attention.")
        self.use_flash_attn = use_flash_attn and not kernel_missing

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Attend within each packed sequence and project the result.

        Args:
            x: ``(total_tokens, n_state)`` packed activations.
            cu_seqlens: ``(num_seqs + 1,)`` cumulative sequence lengths,
                e.g. ``[0, len1, len1 + len2, ...]``.

        Returns:
            ``(total_tokens, n_state)`` tensor produced by ``self.out``.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        n_tokens, n_state = queries.shape
        per_head = (n_tokens, self.n_head, n_state // self.n_head)
        queries = queries.view(per_head)
        keys = keys.view(per_head)
        values = values.view(per_head)

        # The fused kernel is only used for half-precision activations.
        half_precision = queries.dtype in (torch.float16, torch.bfloat16)
        if self.use_flash_attn and half_precision:
            longest = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            mixed = _FLASH_ATTN_VARLEN(
                queries, keys, values, cu_seqlens, cu_seqlens, longest, longest,
            )
        else:
            mixed = self._manual_packed_attention(queries, keys, values, cu_seqlens)

        return self.out(mixed.contiguous().view(n_tokens, n_state))

    @staticmethod
    def _manual_packed_attention(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Pure-PyTorch fallback: pad to a batch, attend, strip the padding.

        Args:
            q, k, v: ``(total_tokens, n_head, head_dim)`` packed tensors.
            cu_seqlens: Cumulative sequence lengths delimiting each sequence.

        Returns:
            ``(total_tokens, n_head, head_dim)`` attention output, packed in
            the same order as the inputs.
        """
        _, n_head, head_dim = q.shape
        scale = head_dim ** -0.5

        bounds = cu_seqlens.tolist()
        spans = list(zip(bounds[:-1], bounds[1:]))
        lengths = [stop - start for start, stop in spans]
        batch = len(lengths)
        longest = max(lengths)

        # Scatter every sequence into a zero-padded (B, T, H, D) buffer.
        padded_shape = (batch, longest, n_head, head_dim)
        q_pad = q.new_zeros(padded_shape)
        k_pad = q.new_zeros(padded_shape)
        v_pad = q.new_zeros(padded_shape)
        for row, (start, stop) in enumerate(spans):
            q_pad[row, : stop - start] = q[start:stop]
            k_pad[row, : stop - start] = k[start:stop]
            v_pad[row, : stop - start] = v[start:stop]

        # Move heads ahead of time: (B, H, T, D).
        q_pad, k_pad, v_pad = (t.transpose(1, 2) for t in (q_pad, k_pad, v_pad))

        # Additive bias that removes padded key columns from the softmax.
        column = torch.arange(longest, device=q.device)[None, :]
        is_padding = column >= torch.tensor(lengths, device=q.device)[:, None]
        bias = torch.zeros(batch, 1, 1, longest, dtype=q.dtype, device=q.device)
        bias = bias.masked_fill(
            is_padding[:, None, None, :], -torch.finfo(q.dtype).max,
        )

        scores = torch.matmul(q_pad, k_pad.transpose(-2, -1)) * scale + bias
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v_pad).transpose(1, 2).contiguous()  # (B, T, H, D)

        # Drop the padded rows and re-pack along the token axis.
        pieces = [context[row, :n] for row, n in enumerate(lengths)]
        return torch.cat(pieces, dim=0)
class MingAudioEncoder(nn.Module):
    """Whisper-style audio encoder that consumes a list of mel clips.

    Each clip goes through two GELU-activated convolutions (the second with
    stride 2), receives a slice of a fixed sinusoidal position table, and is
    concatenated with the others into one packed sequence. The packed
    sequence is then run through the residual attention blocks and a final
    LayerNorm.

    ``positional_embedding`` is registered as a buffer built by
    ``_sinusoids(n_ctx, n_state)`` rather than as a trainable parameter.
    """

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 15000,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        use_flash_attn: bool = True,
    ) -> None:
        """Assemble the conv stem, position table, blocks and final norm.

        Args:
            n_mels: Input channels of the first convolution.
            n_ctx: Number of rows in the sinusoidal position table.
            n_state: Encoder width (also exposed as ``audio_emb_dim``).
            n_head: Attention heads per block.
            n_layer: Number of residual attention blocks.
            use_flash_attn: Forwarded to every block.
        """
        super().__init__()
        self.n_layer = n_layer
        self.n_mels = n_mels
        self.use_flash_attn = use_flash_attn
        self.audio_emb_dim = n_state

        self.conv1 = _AutoCastConv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = _AutoCastConv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Recomputed table, stored as a buffer instead of a Parameter.
        self.register_buffer("positional_embedding", _sinusoids(n_ctx, n_state))

        layers = []
        for _ in range(n_layer):
            layers.append(
                _ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn)
            )
        self.blocks = nn.ModuleList(layers)
        self.ln_post = nn.LayerNorm(n_state)

    def _embed_clip(self, mel: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Run one ``(n_mels, T)`` clip through the conv stem and add positions."""
        hidden = mel.to(dtype).unsqueeze(0)  # (1, n_mels, T)
        hidden = F.gelu(self.conv1(hidden))
        hidden = F.gelu(self.conv2(hidden))
        hidden = hidden.squeeze(0).transpose(0, 1)  # (T', n_state)
        positions = self.positional_embedding[: hidden.shape[0], :]
        return (hidden + positions).to(hidden.dtype)

    def forward(self, x_list: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a list of variable-length mel spectrograms.

        Args:
            x_list: One ``(n_mels, T_i)`` tensor per audio clip.

        Returns:
            ``(packed, cu_seqlens)`` where ``packed`` is the
            ``(sum T'_i, n_state)`` concatenation of all encoded clips and
            ``cu_seqlens`` is an int32 tensor of ``len(x_list) + 1``
            cumulative encoded lengths, starting at 0.
        """
        weight_dtype = self.conv1.weight.dtype

        pieces: list[torch.Tensor] = []
        offsets: list[int] = [0]
        for mel in x_list:
            piece = self._embed_clip(mel, weight_dtype)
            pieces.append(piece)
            offsets.append(offsets[-1] + piece.shape[0])

        packed = torch.cat(pieces, dim=0)
        cu_seqlens = torch.tensor(offsets, device=packed.device, dtype=torch.int32)

        for layer in self.blocks:
            packed = layer(packed, cu_seqlens=cu_seqlens)
        return self.ln_post(packed), cu_seqlens
+ """ + whisper_cfg = audio_config.whisper_encoder_config + encoder = MingAudioEncoder( + n_mels=int(whisper_cfg["n_mels"]), + n_ctx=int(whisper_cfg["n_ctx"]), + n_state=int(whisper_cfg["n_state"]), + n_head=int(whisper_cfg["n_head"]), + n_layer=int(whisper_cfg["n_layer"]), + use_flash_attn=use_flash_attn, + ) + encoder = encoder.to(dtype=dtype, device=device) + encoder.eval() + return encoder + + +__all__ = ["MingAudioEncoder", "build_audio_encoder"] diff --git a/mstar/model/ming_omni_flash/components/audio_vae.py b/mstar/model/ming_omni_flash/components/audio_vae.py new file mode 100644 index 00000000..4eaadc90 --- /dev/null +++ b/mstar/model/ming_omni_flash/components/audio_vae.py @@ -0,0 +1,726 @@ +"""AudioVAE for Ming-flash-omni-2.0 (step 6d). + +Self-contained port of vllm-omni's +``vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py`` (392 LOC). +The released ckpt ships the VAE under ``talker/vae/model.safetensors`` +with the top-level prefixes ``encoder.*`` and ``decoder.*``; we mirror +the upstream module tree so the eventual loader is a plain prefix-strip ++ load_state_dict. 
+ +Topology (released ckpt): + + AudioVAE + .encoder (Encoder) # waveform → latent + .encoder (Qwen2Model, sliding-window=64) # main backbone + .aggregator (Qwen2Model, 4 layers) # patch-summarisation + .fc1 (Linear 882 → 896) + .fc2 (Linear 896 → 896) + .fc3 (Linear 896 → 128) # latent_dim*2 (mean+scale) + .norm (LayerNorm 896) + .cls_embed (Parameter (1, 1, 896)) + .decoder (Decoder) # latent → waveform + .decoder (Qwen2Model, sliding-window=64) + .fc1 (Linear 64 → 896) + .head (ISTFTHead) + .out (Linear 896 → 3530 = n_fft + 2) + .istft (ISTFT, n_fft=3528, hop=882, win=3528) + .upsampling (StreamingLinearUpsample) # only when patch_size != -1 + +Two simplifications vs vllm-omni: + + * `encode_latent` uses an inline `_oobleck_sample()` instead of + `diffusers.OobleckDiagonalGaussianDistribution` — same math + (mean/scale split, softplus on scale, reparameterised sample) but + no diffusers dep. The full diffusers class also exposes + `kl_divergence` / `mode` for training; we only need `sample` at + inference, so the minimal helper is enough. + + * `Decoder.low_level_reconstruct`'s streaming KV-cache fill path uses + HF `Cache` instances; the upstream's `past_key_values` tuple + fallback isn't needed on transformers >= 4.43. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn.functional as F + +if TYPE_CHECKING: + from transformers import Qwen2Config +from torch import nn + +logger = logging.getLogger(__name__) + + +# =========================================================================== +# Inline Oobleck-style Gaussian sampler (replaces diffusers dep) +# =========================================================================== + + +def _oobleck_sample(parameters: torch.Tensor) -> torch.Tensor: + """Sample from a diagonal Gaussian parameterised by ``[mean, scale]``. 
class _ISTFT(nn.Module):
    """Inverse STFT with an optional chunk-by-chunk (streaming) mode.

    Padding modes:

    * ``"center"`` — delegates to ``torch.istft`` with ``center=True``.
    * ``"same"`` — overlap-adds the windowed inverse-FFT frames with
      ``F.fold`` and divides by the folded squared window. In streaming
      mode the last ``win_length - hop_length`` samples of both the audio
      and the window envelope are handed back to the caller so the next
      chunk can add them to its head.
    """

    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        padding: str = "same",
    ) -> None:
        """Store the STFT geometry and register a Hann window buffer.

        Raises:
            ValueError: If ``padding`` is neither ``"center"`` nor ``"same"``.
        """
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError(f"Padding must be 'center' or 'same'; got {padding!r}.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Number of trailing samples carried from one chunk to the next.
        self.buffer_len = win_length - hop_length
        self.register_buffer("window", torch.hann_window(win_length))

    def _buffer_process(
        self,
        x: torch.Tensor,
        buffer: torch.Tensor | None,
        pad: int,
        last_chunk: bool,
        streaming: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Trim or overlap the edges of a folded ``(rows, samples)`` signal.

        Non-streaming: drop ``pad`` samples from both ends; ``buffer`` is
        returned untouched. Streaming: the first chunk (no buffer yet) loses
        ``pad`` leading samples, later chunks get the previous tail added to
        their head. The new tail becomes the returned buffer; it is cut from
        the output unless this is the last chunk, which loses only ``pad``
        trailing samples.
        """
        if not streaming:
            return x[:, pad:-pad], buffer

        carry = self.buffer_len
        if buffer is None:
            x = x[:, pad:]
        else:
            # Clone first so the fold output is not modified in place.
            x = x.clone()
            x[:, :carry] = x[:, :carry] + buffer
        buffer = x[:, -carry:]
        x = x[:, :-pad] if last_chunk else x[:, :-carry]
        return x, buffer

    def forward(
        self,
        spec: torch.Tensor,
        audio_buffer: torch.Tensor | None = None,
        window_buffer: torch.Tensor | None = None,
        streaming: bool = False,
        last_chunk: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Reconstruct a waveform from a complex spectrogram.

        Args:
            spec: ``(B, n_fft // 2 + 1, T)`` complex spectrogram.
            audio_buffer: Audio tail returned by the previous streaming call.
            window_buffer: Envelope tail returned by the previous streaming call.
            streaming: Enable chunked overlap handling (``"same"`` mode only).
            last_chunk: Marks the final chunk of a stream.

        Returns:
            ``(waveform, audio_buffer, window_buffer)``. Both buffers are
            ``None`` in ``"center"`` mode; in non-streaming ``"same"`` mode
            the buffers passed in are returned unchanged.

        Raises:
            RuntimeError: If the window envelope is not strictly above
                ``1e-11`` everywhere in the retained region.
        """
        if self.padding == "center":
            waveform = torch.istft(
                spec,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=True,
            )
            return waveform, None, None

        # ---- "same" padding ------------------------------------------------
        pad = (self.win_length - self.hop_length) // 2
        _, _, num_frames = spec.shape
        total_len = (num_frames - 1) * self.hop_length + self.win_length
        # Both overlap-adds below share the same fold geometry.
        fold_args = {
            "output_size": (1, total_len),
            "kernel_size": (1, self.win_length),
            "stride": (1, self.hop_length),
        }

        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]
        waveform = F.fold(frames, **fold_args)[:, 0, 0, :]
        waveform, audio_buffer = self._buffer_process(
            waveform, audio_buffer, pad, last_chunk=last_chunk, streaming=streaming,
        )

        # Overlap-add the squared window the same way to get the
        # per-sample normaliser.
        squared = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        envelope = F.fold(squared, **fold_args).squeeze(0).squeeze(0)
        envelope, window_buffer = self._buffer_process(
            envelope, window_buffer, pad,
            last_chunk=last_chunk, streaming=streaming,
        )
        envelope = envelope.squeeze()

        if not (envelope > 1e-11).all():
            raise RuntimeError(
                "ISTFT window envelope has near-zero positions; "
                "check hop_length / win_length / window choice."
            )

        return waveform / envelope, audio_buffer, window_buffer
class _StreamingLinearUpsample(nn.Module):
    """Linear upsampling whose chunked output matches the one-shot result.

    Non-streaming (first call is also the last): ``upsampler(x)`` directly.

    Streaming: each chunk is held back until the next one arrives, so one
    frame of right-hand lookahead is available when it is upsampled. One
    frame of left-hand history is kept as well. The ``state`` dict carries:

    * ``prev_chunk`` — the chunk still waiting to be emitted,
    * ``history_last`` — last frame of the chunk emitted before it,
    * ``is_first`` — ``True`` until the first chunk has been buffered.

    A stream can be closed either by passing the final chunk with
    ``is_last=True`` or by a separate flush call ``forward(None, state,
    is_last=True)``; both produce the same concatenated output.
    """

    def __init__(self, scale_factor: int = 4) -> None:
        super().__init__()
        self.scale_factor = scale_factor
        self.upsampler = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=False,
        )

    def _flush(self, state: dict[str, Any]) -> torch.Tensor | None:
        """Emit the chunk still buffered in ``state`` without any lookahead."""
        pending = state["prev_chunk"]
        if pending is None:
            # Nothing was ever buffered, so there is nothing to emit.
            return None
        p = pending.transpose(1, 2)
        history = state["history_last"]
        if history is None:
            # The stream held a single chunk: same as the one-shot path.
            up = self.upsampler(p)
        else:
            # Same computation as the ``is_last`` tail of a regular call:
            # prepend the history frame, then drop its upsampled samples.
            up = self.upsampler(torch.cat([history, p], dim=2))
            up = up[:, :, self.scale_factor :]
        return up.transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor | None,
        state: dict[str, Any] | None = None,
        is_last: bool = False,
    ) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
        """Upsample one chunk along its second axis.

        Args:
            x: Chunk to upsample along dim 1, or ``None``. ``None`` with
                ``is_last=False`` is a no-op; ``None`` with ``is_last=True``
                flushes the buffered chunk.
            state: State dict returned by the previous call (``None`` to
                start a new stream). It is updated in place.
            is_last: Marks the end of the stream.

        Returns:
            ``(output, state)``. ``output`` is ``None`` when nothing can be
            emitted yet; ``state`` is ``None`` once the stream is finished.
        """
        if state is None:
            state = {"prev_chunk": None, "history_last": None, "is_first": True}

        if x is None:
            if not is_last:
                return None, state
            # Flush request. Previously this fell through and crashed on
            # ``x.transpose`` / ``x[:, :1, :]`` with ``x`` being None.
            return self._flush(state), None

        # Single-chunk fast path: first AND last.
        if state["is_first"] and is_last:
            out = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
            return out, None

        if state["is_first"]:
            # First chunk of a longer stream: buffer it and wait for lookahead.
            state["prev_chunk"] = x
            state["is_first"] = False
            return None, state

        output_chunks: list[torch.Tensor] = []

        # Emit the deferred chunk now that ``x`` supplies a right lookahead.
        if state["prev_chunk"] is not None:
            p = state["prev_chunk"].transpose(1, 2)
            lookahead = x[:, :1, :].transpose(1, 2)
            emitted = p.size(2) * self.scale_factor
            if state["history_last"] is None:
                up = self.upsampler(torch.cat([p, lookahead], dim=2))
                out_prev = up[:, :, :emitted]
            else:
                up = self.upsampler(
                    torch.cat([state["history_last"], p, lookahead], dim=2)
                )
                # Skip the samples that belong to the history frame.
                start = self.scale_factor
                out_prev = up[:, :, start : start + emitted]
            output_chunks.append(out_prev.transpose(1, 2))
            state["history_last"] = p[:, :, -1:]
            state["prev_chunk"] = x

        if is_last:
            # No further lookahead will come: emit ``x`` itself too.
            p = state["prev_chunk"].transpose(1, 2)
            up = self.upsampler(torch.cat([state["history_last"], p], dim=2))
            output_chunks.append(up[:, :, self.scale_factor :].transpose(1, 2))
            state = None

        final = torch.cat(output_chunks, dim=1) if output_chunks else None
        return final, state
def _resolve_attn_implementation() -> str:
    """Choose the attention backend: FlashAttention-2 if usable, otherwise SDPA.

    The probe is best-effort: if ``transformers`` cannot be imported or the
    availability check raises for any reason, ``"sdpa"`` is returned.
    """
    backend = "sdpa"
    try:
        from transformers.utils import is_flash_attn_2_available

        if is_flash_attn_2_available():
            backend = "flash_attention_2"
    except Exception:
        # Deliberately broad: a failed probe must never block model
        # construction, it only means we stay on the safe default.
        backend = "sdpa"
    return backend
    def low_level_reconstruct(
        self,
        x: torch.Tensor,
        past_key_values=None,
        use_cache: bool = False,
        stream_state: tuple[Any, Any, Any] = (None, None, None),
        last_chunk: bool = False,
    ):
        """Reconstruct ``(B, 1, T_samples)`` waveform from latent ``(B, T, latent_dim)``.

        Non-streaming path (``use_cache=False``) runs the full upsample +
        backbone + head in one go.

        Streaming path (``use_cache=True``) threads ``stream_state =
        (upsample_state, audio_buffer, window_buffer)`` and the backbone's
        ``past_key_values`` across chunks.  When the backbone config defines
        ``sliding_window`` and the cache holds fewer than
        ``sliding_window - 1`` tokens while this chunk would reach
        ``sliding_window``, the chunk is split: the first part tops the cache
        up to ``sliding_window - 1`` tokens, the remainder is run in a second
        backbone call, and both hidden-state pieces are concatenated before
        the head.

        Args:
            x: latent chunk; projected by ``fc1`` first.
            past_key_values: backbone cache from the previous chunk, or None.
            use_cache: enables the streaming path (also passed to the head
                as ``streaming``).
            stream_state: ``(upsample_state, audio_buffer, window_buffer)``
                returned by the previous call; all-None starts a stream.
            last_chunk: forwarded to the upsampler (``is_last``) and the head.

        Returns:
            ``(waveform, stream_state, past_key_values)``.  If the streaming
            upsampler has nothing to emit yet, ``waveform`` is an empty
            ``(B, 1, 0)`` tensor and the cache is returned unchanged.
        """
        upsample_state, audio_buffer, window_buffer = stream_state
        # Captured before fc1 so the "nothing to emit" placeholder below
        # matches the caller's input batch size / device / dtype.
        bsz, device, dtype = x.size(0), x.device, x.dtype
        x = self.fc1(x)
        if self.patch_size != -1:
            if use_cache:
                # Streaming upsample defers output until it has lookahead,
                # so it may return None for this call.
                x, upsample_state = self.upsampling(
                    x, state=upsample_state, is_last=last_chunk,
                )
                if x is None:
                    stream_state = (upsample_state, audio_buffer, window_buffer)
                    return torch.empty(bsz, 1, 0, device=device, dtype=dtype), stream_state, past_key_values
            else:
                # One-shot: bypass the streaming wrapper and call nn.Upsample
                # directly on the channels-first view.
                x = self.upsampling.upsampler(x.transpose(1, 2)).transpose(1, 2)

        hidden_states_list: list[torch.Tensor] = []

        # Sliding-window bridge: when the cache is still shorter than
        # (sliding_window - 1) and this chunk would reach `sliding_window`,
        # first fill the cache up to (sw_size - 1) tokens so the second
        # pass benefits from the cached prefix.
        if use_cache and getattr(self.decoder.config, "sliding_window", None) is not None:
            sw_size = self.decoder.config.sliding_window
            target_len = sw_size - 1
            past_len = _get_past_len(past_key_values)
            curr_len = x.shape[1]
            if past_len < target_len and (past_len + curr_len) >= sw_size:
                # fill_len >= 1 here, and at least one token is left over
                # for the second backbone call below.
                fill_len = target_len - past_len
                x_fill = x[:, :fill_len, :]
                outputs = self.decoder(
                    inputs_embeds=x_fill, past_key_values=past_key_values, use_cache=True,
                )
                hidden_states_list.append(outputs.last_hidden_state)
                past_key_values = outputs.past_key_values
                x = x[:, fill_len:, :]

        outputs = self.decoder(
            inputs_embeds=x, past_key_values=past_key_values, use_cache=use_cache,
        )
        hidden_states_list.append(outputs.last_hidden_state)
        past_key_values = outputs.past_key_values

        # Re-join the bridge prefix (if any) with the main pass along time.
        full_hidden = (
            torch.cat(hidden_states_list, dim=1)
            if len(hidden_states_list) > 1
            else hidden_states_list[0]
        )
        # The head also returns its raw projection; it is not needed here.
        x_out, _x_pred, audio_buffer, window_buffer = self.head(
            full_hidden,
            streaming=use_cache,
            audio_buffer=audio_buffer,
            window_buffer=window_buffer,
            last_chunk=last_chunk,
        )
        stream_state = (upsample_state, audio_buffer, window_buffer)
        return x_out, stream_state, past_key_values
def _get_past_len(past_key_values) -> int:
    """Return the number of cached positions for the supported cache layouts.

    * ``None`` -> 0
    * any object exposing ``get_seq_length()`` -> that value
    * a non-empty tuple of per-layer ``(key, value)`` pairs -> the
      second-to-last dimension of the first layer's key
    * anything else -> 0
    """
    cache = past_key_values
    if cache is not None:
        if hasattr(cache, "get_seq_length"):
            return int(cache.get_seq_length())
        if isinstance(cache, tuple) and len(cache) > 0:
            first_layer_key = cache[0][0]
            return int(first_layer_key.shape[-2])
    return 0
+ agg_cfg = _build_vae_qwen2_config( + {**encoder_args, "num_hidden_layers": 4}, + attn_implementation=attn_implementation, + ) + self.aggregator = Qwen2Model(agg_cfg) + # Learnable CLS embedding prepended to each patch. + self.cls_embed = nn.Parameter(torch.empty(1, 1, cfg.hidden_size)) + # Match upstream's normal_(0, 0.02) init so eager-init + # weights match if the loader is bypassed in tests. + nn.init.normal_(self.cls_embed, mean=0.0, std=0.02) + + # ------------------------------------------------------------------ + # Waveform → frames windowed slicing + # ------------------------------------------------------------------ + + def get_frames(self, x: torch.Tensor) -> torch.Tensor: + """Slide a ``(input_dim,)`` window over the waveform with stride hop_size. + + Pads the right edge so the final window doesn't overshoot. + Returns ``(B, num_frames, input_dim)``. + """ + num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size + expected_len = (num_frames_total - 1) * self.hop_size + self.input_dim + padding_needed = expected_len - x.size(-1) + waveform = F.pad(x, (0, padding_needed), value=0.0) + frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) + return frames + + def pad_patch_insert_cls(self, x: torch.Tensor) -> torch.Tensor: + """Group frames into patches of ``patch_size`` and append a CLS row to each.""" + bsz, num_frame, dim = x.size() + r = num_frame % self.patch_size + pad_num = self.patch_size - r if r else 0 + x = F.pad(x, (0, 0, 0, pad_num), value=0.0) + x = x.reshape(-1, self.patch_size, dim) + cls = self.cls_embed.expand(x.size(0), -1, -1) + x = torch.cat((x, cls), dim=1) + x = x.reshape(bsz, -1, dim) + return x + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Returns ``(latent_params, waveform.unsqueeze(1))``. 
class AudioVAE(nn.Module):
    """Top-level Audio VAE: an ``_Encoder`` / ``_Decoder`` pair built from one config.

    Plain nn.Module (not PreTrainedModel) so we don't inherit HF
    config machinery — the dataclass `AudioVAEConfig` carries the dims
    and the loader handles weights directly.
    """

    def __init__(
        self,
        audio_vae_config,
        attn_implementation: str | None = None,
    ) -> None:
        """Build encoder and decoder from ``audio_vae_config``.

        The config must expose ``enc_backbone``, ``dec_backbone``,
        ``encoder_input_dim``, ``encoder_hop_size``, ``decoder_output_dim``,
        ``latent_dim``, ``patch_size`` and ``sample_rate``.
        ``attn_implementation`` is passed through to both halves.
        """
        super().__init__()
        self.config = audio_vae_config
        self.encoder = _Encoder(
            encoder_args=audio_vae_config.enc_backbone,
            input_dim=audio_vae_config.encoder_input_dim,
            hop_size=audio_vae_config.encoder_hop_size,
            latent_dim=audio_vae_config.latent_dim,
            patch_size=audio_vae_config.patch_size,
            attn_implementation=attn_implementation,
        )
        self.decoder = _Decoder(
            decoder_args=audio_vae_config.dec_backbone,
            output_dim=audio_vae_config.decoder_output_dim,
            latent_dim=audio_vae_config.latent_dim,
            patch_size=audio_vae_config.patch_size,
            attn_implementation=attn_implementation,
        )

    @property
    def sample_rate(self) -> int:
        """Sample rate taken from the config."""
        return self.config.sample_rate

    def encode_latent(
        self,
        waveform: torch.Tensor,
        waveform_length: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and sample a latent (Gaussian re-parameterised).

        Returns ``(latent, frame_num)``.  ``latent`` is
        ``(B, T_latents, latent_dim)`` — the same layout ``decode`` takes.
        ``frame_num`` is the per-clip latent count (after patching when
        ``patch_size != -1``) as an ``int32`` tensor.
        """
        # NOTE(review): the count is derived from ``encoder_input_dim`` while
        # the encoder frames the signal with ``encoder_hop_size``; the two
        # agree only when those config values are equal — confirm.
        frame_num = torch.ceil(
            waveform_length / self.config.encoder_input_dim,
        ).to(torch.int32)
        if self.config.patch_size != -1:
            # True division promotes to float; cast back so the count has
            # the same integer dtype on both the patched and unpatched path.
            frame_num = torch.ceil(
                frame_num / self.config.patch_size,
            ).to(torch.int32)
        h, _y = self.encoder(waveform)
        # encoder.fc3 emits (B, T, latent_dim*2) — transpose to channels-second
        # for `_oobleck_sample` (chunks on dim=1), then back to time-second.
        h = h.transpose(1, 2)
        latent = _oobleck_sample(h)
        latent = latent.transpose(1, 2)
        return latent, frame_num

    def decode(
        self,
        latent: torch.Tensor,
        past_key_values=None,
        use_cache: bool = False,
        stream_state: tuple[Any, Any, Any] = (None, None, None),
        last_chunk: bool = False,
    ):
        """Decode latent → waveform; threads the streaming state for chunked TTS.

        Thin wrapper over ``decoder.low_level_reconstruct``; returns its
        ``(waveform, stream_state, past_key_values)`` triple unchanged.
        """
        waveform, stream_state, past_key_values = self.decoder.low_level_reconstruct(
            latent,
            past_key_values=past_key_values,
            use_cache=use_cache,
            stream_state=stream_state,
            last_chunk=last_chunk,
        )
        return waveform, stream_state, past_key_values
def _add_multilingual_special_tokens(
    tokenizer,
    text_encoder: nn.Module,
    font_ann_path: Path,
    color_ann_path: Path,
    add_font: bool,
    add_color: bool,
    add_align: bool = False,
) -> None:
    """Extend the byt5 vocab with per-language font + color markers.

    Mirrors ``add_special_token_multilingual`` in Ming's bizgen utils. The token
    set must match what the checkpoint was trained with, otherwise the resized
    embedding table won't line up with the shipped weights.

    Args:
        tokenizer: object providing ``add_tokens(list, special_tokens=...)``
            and ``len()``.
        text_encoder: module providing ``resize_token_embeddings(n)``.
        font_ann_path: JSON file mapping font code -> font index.
        color_ann_path: JSON file whose number of entries is the number of
            colour markers.
        add_font / add_color / add_align: which marker families to add.
            Tokens are appended in the order colour, font, align.

    NOTE(review): every marker is an angle-bracketed name.  The marker text
    for un-prefixed fonts, colours and alignment was empty in the previous
    revision (all tokens collapsed to ``""``); it is reconstructed here as
    ``font-N`` / ``color-N`` / ``align-N`` following the Glyph-ByT5
    convention.  ``align-N`` in particular is an assumption — confirm
    against the upstream token list before relying on ``add_align``.
    """
    idx_font_dict = json.loads(Path(font_ann_path).read_text())
    idx_color_dict = json.loads(Path(color_ann_path).read_text())

    # Wrap a marker name in angle brackets.
    marker = "<{}>".format

    font_tokens: list[str] = []
    for font_code, font_idx in idx_font_dict.items():
        prefix = font_code[:3]
        if prefix in ("cn-", "en-", "jp-", "kr-"):
            # Language-specific font: the prefix already ends with "-".
            font_tokens.append(marker(f"{prefix}font-{font_idx}"))
        else:
            font_tokens.append(marker(f"font-{font_idx}"))
    color_tokens = [marker(f"color-{i}") for i in range(len(idx_color_dict))]
    align_tokens = [marker(f"align-{i}") for i in range(3)]

    # Order matters: it fixes the ids the new tokens receive.
    extra: list[str] = []
    if add_color:
        extra += color_tokens
    if add_font:
        extra += font_tokens
    if add_align:
        extra += align_tokens
    tokenizer.add_tokens(extra, special_tokens=True)
    text_encoder.resize_token_embeddings(len(tokenizer))
+ """ + cuda_devs = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else [] + with torch.random.fork_rng(devices=cuda_devs, enabled=True): + return cls._from_checkpoint_impl(byte5_dir, device=device, dtype=dtype) + + @classmethod + def _from_checkpoint_impl( + cls, + byte5_dir: Path, + *, + device: torch.device, + dtype: torch.dtype, + ) -> MingByT5Encoder: + from transformers import AutoTokenizer, T5ForConditionalGeneration + + byte5_dir = Path(byte5_dir) + # Ming checkpoint uses ``byt5`` (no 'e') in filenames and JSON keys; + # the ``byte5_`` variable spelling below is kept for readability. + cfg_raw = json.loads((byte5_dir / "byt5.json").read_text()) + cfg = SimpleNamespace(**cfg_raw) + byte5_config = cfg.byt5_config + mapper_config = cfg.byt5_mapper_config + max_length = int(cfg.byt5_max_length) + + # ---- Tokenizer + T5 encoder (base). + ckpt_key = byte5_config.get("byt5_ckpt_path") + byte5_ckpt_path = byte5_dir / ckpt_key.lstrip("./") + tokenizer = AutoTokenizer.from_pretrained(byte5_ckpt_path, local_files_only=True) + text_encoder = T5ForConditionalGeneration.from_pretrained( + byte5_ckpt_path, local_files_only=True + ).get_encoder() + + # ---- Extend vocab with font/color markers so the shipped weights load. + if byte5_config.get("special_token"): + if not byte5_config.get("multilingual", True): + raise NotImplementedError( + "Non-multilingual byt5 vocab extension is not ported; " + "the released Ming checkpoint uses multilingual=True." + ) + _add_multilingual_special_tokens( + tokenizer, + text_encoder, + font_ann_path=byte5_dir / byte5_config["font_ann_path"].lstrip("./"), + color_ann_path=byte5_dir / byte5_config["color_ann_path"].lstrip("./"), + add_font=bool(byte5_config.get("font_special_token")), + add_color=bool(byte5_config.get("color_special_token")), + ) + + # ---- Load byt5 text-encoder weights. 
base.pt wraps the backbone in a + # trainable-module container (module.text_tower.encoder.*); byt5_model.pt + # carries the top-level encoder state. Follow Ming's two-step load. + base_state = torch.load(byte5_dir / "byt5_model" / "base.pt", map_location="cpu", weights_only=False) + prefix = "module.text_tower.encoder." + base_filtered = { + name[len(prefix):]: state + for name, state in base_state["state_dict"].items() + if name.startswith(prefix) + } + text_encoder.load_state_dict(base_filtered, strict=True) + del base_state, base_filtered + + encoder_state = torch.load(byte5_dir / "byt5_model" / "byt5_model.pt", map_location="cpu", weights_only=False) + text_encoder.load_state_dict(encoder_state) + del encoder_state + + text_encoder.to(device=device, dtype=dtype).eval() + + # ---- Mapper (stock HF T5Block layout ⇒ direct state_dict load). + mapper = T5EncoderBlockByT5Mapper( + byte5_config=text_encoder.config, + num_layers=int(mapper_config["num_layers"]), + sdxl_channels=int(mapper_config["sdxl_channels"]), + ) + mapper_state = torch.load(byte5_dir / "byt5_mapper" / "byt5_mapper.pt", map_location="cpu", weights_only=False) + mapper.load_weights(mapper_state.items()) + del mapper_state + mapper.to(device=device, dtype=dtype).eval() + + logger.info( + "[MingByT5Encoder] ready: d_model=%d mapper_layers=%d sdxl_channels=%d max_length=%d vocab=%d", + text_encoder.config.d_model, + mapper_config["num_layers"], + mapper_config["sdxl_channels"], + max_length, + len(tokenizer), + ) + return cls(tokenizer, text_encoder, mapper, max_length) + + @torch.inference_mode() + def forward(self, texts: list[str]) -> torch.Tensor: + """Tokenize → T5 encode → mapper; zeroes padded positions. + + Returns ``[B, max_length, sdxl_channels]``. 
+ """ + device = next(self.text_encoder.parameters()).device + dtype = next(self.text_encoder.parameters()).dtype + + tokens = self.tokenizer( + texts, + padding="max_length", + max_length=self.max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokens.input_ids.to(device) + attention_mask = tokens.attention_mask.to(device) + + encoder_out = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask.float(), + ) + hidden_states = encoder_out[0] + feats = self.mapper(hidden_states, attention_mask) + feats = feats * attention_mask.unsqueeze(-1).to(dtype=feats.dtype) + return feats.to(dtype=dtype) + + +__all__ = ["MingByT5Encoder"] diff --git a/mstar/model/ming_omni_flash/components/condition_encoder.py b/mstar/model/ming_omni_flash/components/condition_encoder.py new file mode 100644 index 00000000..ac60969a --- /dev/null +++ b/mstar/model/ming_omni_flash/components/condition_encoder.py @@ -0,0 +1,243 @@ +"""Ming-flash-omni-2.0 condition encoder for image generation (step 9b). + +Native mstar port of vllm-omni's ``condition_encoder.py``. Encodes the thinker +hidden states (sliced at the learnable ```` query-token positions) +into the DiT's ``cap_feats`` conditioning: + + thinker hidden states [B, N, 4096] + │ proj_in (Linear, bias) -> [B, N, 1536] + │ Qwen2 connector (bidirectional, non-causal) + │ proj_out (Linear, bias) -> [B, N, 2560] + │ F.normalize(dim=-1) × 1000 (text_encoder_norm) + ▼ + cap_feats consumed by ZImageTransformer2DModel + +Only transformers is required (the connector is a small Qwen2 backbone loaded +via ``Qwen2ForCausalLM.from_pretrained``); there is no diffusers dependency, so +the forward path is unit-testable with a stub connector. 
class MingConditionEncoder(nn.Module):
    """Qwen2 connector + proj_in/out + L2-normalize×1000 → DiT condition embeds.

    The connector runs bidirectionally (``is_causal=False``) since it encodes a
    fixed block of query-token hidden states rather than decoding
    autoregressively. ``proj_in`` / ``proj_out`` / connector are populated by
    :meth:`load_from_checkpoint`; before that the module is cheap to construct
    (Identity projections), which keeps dummy-init and unit tests light.

    Args:
        image_gen_config: an ``ImageGenConfig`` (mstar) exposing
            ``connector_subfolder`` / ``mlp_subfolder`` /
            ``diffusion_c_input_dim`` / ``text_encoder_norm`` /
            ``use_identity_mlp``.
        thinker_hidden_size: hidden size of the thinker (BailingMoeV2); 4096 on
            the released checkpoint.
        device / dtype: optional placement applied after loading.
    """

    def __init__(
        self,
        image_gen_config,
        *,
        thinker_hidden_size: int = 4096,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the unloaded shell; no weights are read here."""
        super().__init__()
        self.config = image_gen_config
        self.thinker_hidden_size = thinker_hidden_size
        # Placement is only remembered here; it is applied at the end of
        # load_from_checkpoint().
        self._target_device = torch.device(device) if device is not None else None
        self._target_dtype = dtype

        # Placeholders until load_from_checkpoint() replaces them.
        # forward() refuses to run while ``connector`` is None.
        self.connector: nn.Module | None = None
        self.connector_hidden_size: int | None = None
        self.proj_in: nn.Module = nn.Identity()
        self.proj_out: nn.Module = nn.Identity()
        self.norm: nn.Module = nn.Identity()

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------

    def load_from_checkpoint(self, model_path: str | Path) -> None:
        """Load the Qwen2 connector + proj_in/proj_out weights from disk.

        Steps: load the connector from ``<model_path>/<connector_subfolder>``
        (local files only), switch every module exposing ``is_causal`` to
        non-causal, create fresh ``proj_in`` / ``proj_out`` Linear layers,
        fill them from ``<model_path>/<mlp_subfolder>`` when present, then
        move the whole module to the device / dtype given at construction.

        Raises:
            NotImplementedError: ``<mlp_subfolder>/config.json`` exists and
                its ``use_identity_mlp`` is false or absent.
        """
        from transformers import AutoConfig, Qwen2ForCausalLM

        model_path = Path(model_path)
        connector_path = model_path / self.config.connector_subfolder
        logger.info("[MingConditionEncoder] loading connector from %s", connector_path)

        connector_cfg = AutoConfig.from_pretrained(connector_path, trust_remote_code=True, local_files_only=True)
        connector_cfg.is_decoder = False
        self.connector_hidden_size = int(connector_cfg.hidden_size)

        connector = Qwen2ForCausalLM.from_pretrained(
            connector_path,
            config=connector_cfg,
            torch_dtype=self._target_dtype,
            local_files_only=True,
        )
        # Force bidirectional attention defensively — some transformers versions
        # read ``self_attn.is_causal`` in forward.
        for module in connector.modules():
            if hasattr(module, "is_causal"):
                module.is_causal = False

        self.connector = getattr(connector, "model", connector)  # base encoder, no LM head

        # Freshly initialised here; real weights are copied in by
        # _load_optional_mlp_weights() below.
        self.proj_in = nn.Linear(self.thinker_hidden_size, self.connector_hidden_size, bias=True)
        # text_encoder_norm = L2 normalize on the final cap_feats (NOT an
        # intermediate RMSNorm); applied explicitly in forward(). Keep
        # self.norm as Identity.
        self.norm = nn.Identity()
        self.proj_out = nn.Linear(self.connector_hidden_size, self.config.diffusion_c_input_dim, bias=True)

        mlp_path = model_path / self.config.mlp_subfolder
        mlp_cfg_path = mlp_path / "config.json"
        # A config.json without the key is treated like use_identity_mlp=False
        # (the .get default) and is therefore rejected as well.
        if mlp_cfg_path.exists() and not json.loads(mlp_cfg_path.read_text()).get("use_identity_mlp", False):
            raise NotImplementedError(f"{mlp_cfg_path} has use_identity_mlp=False; ToClipMLP path not implemented.")
        self._load_optional_mlp_weights(mlp_path)

        if self._target_device is not None:
            self.to(self._target_device)
        if self._target_dtype is not None:
            self.to(dtype=self._target_dtype)

    def _load_optional_mlp_weights(self, mlp_path: Path) -> None:
        """Copy proj_in / proj_out (+ optional norm) weights from ``mlp/``.

        Expected keys (inclusionAI/Ming-flash-omni-2.0): ``proj_in.{weight,bias}``
        [1536,4096]/[1536], ``proj_out.{weight,bias}`` [2560,1536]/[2560], and
        ``query_tokens_dict.16x16`` [256,4096] which is consumed on the thinker
        side (skipped here). Missing proj weights are logged as errors — the
        conditioning is meaningless without them.

        This method never raises for a missing directory, missing files,
        missing keys or shape mismatches; it only logs.  ``*.safetensors``
        files are preferred; ``*.bin`` files are used only when no
        safetensors file exists.
        """
        if not mlp_path.exists():
            logger.warning("[MingConditionEncoder] mlp/ missing at %s — proj/norm stay random-init", mlp_path)
            return

        from safetensors.torch import load_file

        candidates = sorted(mlp_path.glob("*.safetensors")) or sorted(mlp_path.glob("*.bin"))
        if not candidates:
            logger.warning("[MingConditionEncoder] no weight files under %s", mlp_path)
            return

        # Merge all shards into one flat dict; later files win on key clashes.
        state: dict[str, torch.Tensor] = {}
        for p in candidates:
            if p.suffix == ".safetensors":
                state.update(load_file(str(p)))
            else:
                state.update(torch.load(str(p), map_location="cpu"))

        # Keys that were consumed (or knowingly skipped); used for the
        # leftover report at the end.
        handled: set[str] = set()

        def _copy(dst: torch.Tensor, src_key: str) -> bool:
            # Copy state[src_key] into dst in place; log and return False on
            # a missing key or a shape mismatch instead of raising.
            src = state.get(src_key)
            if src is None:
                logger.error("[MingConditionEncoder] mlp/ missing key %r", src_key)
                return False
            if tuple(src.shape) != tuple(dst.shape):
                logger.error(
                    "[MingConditionEncoder] mlp/%s shape mismatch: ckpt=%s module=%s",
                    src_key,
                    tuple(src.shape),
                    tuple(dst.shape),
                )
                return False
            with torch.no_grad():
                dst.copy_(src.to(dtype=dst.dtype, device=dst.device))
            handled.add(src_key)
            return True

        # A list (not a generator) so that all four copies are attempted and
        # every problem is logged, rather than stopping at the first failure.
        ok = all(
            [
                _copy(self.proj_in.weight, "proj_in.weight"),
                _copy(self.proj_in.bias, "proj_in.bias"),
                _copy(self.proj_out.weight, "proj_out.weight"),
                _copy(self.proj_out.bias, "proj_out.bias"),
            ]
        )
        if not ok:
            logger.error("[MingConditionEncoder] proj_in/proj_out NOT fully loaded; conditioning will be garbage.")

        # self.norm is an Identity when called from load_from_checkpoint(),
        # so this only fires if a caller installed a norm with a weight.
        if "norm.weight" in state and hasattr(self.norm, "weight"):
            _copy(self.norm.weight, "norm.weight")

        for k in state:
            if k.startswith("query_tokens_dict"):
                handled.add(k)  # thinker-side; not loaded here

        leftover = set(state.keys()) - handled
        if leftover:
            logger.warning("[MingConditionEncoder] mlp/ unhandled keys: %s", sorted(leftover))

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        thinker_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode ``[B, N, thinker_hidden_size]`` → ``[B, N, diffusion_c_input_dim]``.

        Pipeline: ``proj_in`` → connector (last hidden state) → ``proj_out``
        → L2-normalise over the last dim → multiply by 1000 when
        ``config.text_encoder_norm`` is truthy.

        Args:
            thinker_hidden_states: 3-D tensor ``[B, N, H]``.
            attention_mask: ``None`` (an all-ones ``[B, 1, N, N]`` mask is
                built), a 2-D ``[B, N]`` mask (broadcast to 4-D), or any
                other tensor, which is passed to the connector unchanged.

        Raises:
            RuntimeError: called before :meth:`load_from_checkpoint`.
            ValueError: input is not 3-D.
        """
        if self.connector is None:
            raise RuntimeError("MingConditionEncoder.load_from_checkpoint() must run before forward().")
        if thinker_hidden_states.dim() != 3:
            raise ValueError(f"expected [B, N, H], got shape {tuple(thinker_hidden_states.shape)}")

        b, n, _ = thinker_hidden_states.shape
        x = self.proj_in(thinker_hidden_states)

        # Ming passes a 4D all-ones mask [B, 1, N, N] to force full bidirectional
        # self-attention over the query positions.
        # NOTE(review): a 2-D mask is only cast to the activation dtype and
        # broadcast — its 0/1 values are not converted to an additive bias.
        # Whether the connector interprets a 4-D float mask as keep/drop or as
        # an additive bias depends on the installed transformers version;
        # confirm before passing masks that contain zeros.
        if attention_mask is None:
            attention_mask = torch.ones((b, 1, n, n), dtype=x.dtype, device=x.device)
        elif attention_mask.dim() == 2:
            attention_mask = attention_mask.to(x.dtype)[:, None, None, :].expand(b, 1, n, n)

        out = self.connector(
            inputs_embeds=x,
            attention_mask=attention_mask,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = out.hidden_states[-1]
        cap_feats = self.proj_out(hidden)

        # Normalisation is unconditional; only the ×1000 scale is gated.
        cap_feats = F.normalize(cap_feats, dim=-1)
        if self.config.text_encoder_norm:
            cap_feats = cap_feats * 1000.0
        return cap_feats

    @torch.no_grad()
    def zero_negative(self, cap_feats: torch.Tensor) -> torch.Tensor:
        """Zero tensor shaped like ``cap_feats`` for CFG negatives."""
        return torch.zeros_like(cap_feats)

    def extra_repr(self) -> str:
        """Key dimensions shown in the module's ``repr``."""
        return (
            f"thinker_hidden_size={self.thinker_hidden_size}, "
            f"connector_hidden_size={self.connector_hidden_size}, "
            f"diffusion_c_input_dim={self.config.diffusion_c_input_dim}, "
            f"text_encoder_norm={self.config.text_encoder_norm}"
        )
class LingDecoderLayer(nn.Module):
    """Pre-norm Ling-2.0 transformer block: self-attention then a dense or MoE FFN.

    Layers with ``layer_idx < first_k_dense_replace`` use a dense
    :class:`ParallelGatedMLP`; every later layer uses a :class:`LingMoeBlock`.
    The same ``comm_group`` is handed to attention and to the FFN; when it is
    omitted a trivial (single-rank) group is created.
    """

    def __init__(
        self,
        layer_idx: int,
        first_k_dense_replace: int,
        hidden_size: int,
        intermediate_size: int,
        moe_intermediate_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float,
        num_experts: int,
        num_experts_per_tok: int,
        num_shared_experts: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        rotary: LingPartialMRotaryEmbedding,
        use_qkv_bias: bool = False,
        use_bias: bool = False,
        comm_group: TPCommGroup | None = None,
    ) -> None:
        """Build the two norms, the attention module and the layer's FFN.

        ``intermediate_size`` only matters for dense layers and the
        ``moe_*`` / expert / group arguments only for MoE layers; both sets are
        accepted so every layer can be built from one flat config.
        """
        super().__init__()
        tp_group = TPCommGroup.trivial() if comm_group is None else comm_group

        self.layer_idx = layer_idx
        # The first ``first_k_dense_replace`` layers keep a dense FFN.
        self.is_moe = layer_idx >= first_k_dense_replace

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

        attention_kwargs = {
            "hidden_size": hidden_size,
            "num_heads": num_attention_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
            "rms_norm_eps": rms_norm_eps,
            "rotary": rotary,
            "use_qkv_bias": use_qkv_bias,
            "use_bias": use_bias,
            "comm_group": tp_group,
        }
        self.self_attn = LingAttention(**attention_kwargs)

        if not self.is_moe:
            # Dense FFN: ParallelGatedMLP receives the TP group so it can do
            # its own sharding.
            self.mlp: nn.Module = ParallelGatedMLP(
                comm_group=tp_group,
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                bias=False,
            )
        else:
            self.mlp = LingMoeBlock(
                hidden_size=hidden_size,
                num_experts=num_experts,
                num_experts_per_tok=num_experts_per_tok,
                moe_intermediate_size=moe_intermediate_size,
                num_shared_experts=num_shared_experts,
                n_group=n_group,
                topk_group=topk_group,
                routed_scaling_factor=routed_scaling_factor,
                comm_group=tp_group,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_handle: BatchedCacheManager,
        position_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        audio_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply attention and FFN, each wrapped in a pre-norm residual.

        ``image_mask`` / ``audio_mask`` are forwarded only to the MoE FFN; the
        dense FFN is called with the hidden states alone.
        """
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), cache_handle, position_ids
        )
        after_attn = hidden_states + attn_out

        ffn_in = self.post_attention_layernorm(after_attn)
        if not self.is_moe:
            return after_attn + self.mlp(ffn_in)
        return after_attn + self.mlp(
            ffn_in, image_mask=image_mask, audio_mask=audio_mask
        )
This port: + + * keeps the **denoise loop pure** (``MingImageDenoiser.denoise``) — it takes + the DiT, scheduler, latents and prompt embeds as plain arguments, so the + flow-matching + classifier-free-guidance math is unit-testable with stubs + and has no diffusers dependency; + * pushes diffusers/transformers loading behind + :meth:`MingImagePipeline.from_checkpoint` (lazy import) so the module + imports cleanly even where diffusers is unavailable. + +Flow-matching denoise (Z-Image convention): + - latents start as Gaussian noise; timesteps come from + FlowMatchEulerDiscreteScheduler with dynamic shifting (``mu`` from + :func:`calculate_shift`); + - per step the DiT predicts velocity; CFG combines pos/neg; the prediction is + negated before ``scheduler.step`` (Z-Image sign convention); + - final latents are un-shifted/un-scaled and VAE-decoded to ``[-1, 1]``. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import torch + +logger = logging.getLogger(__name__) + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> float: + """Dynamic-shift ``mu`` for FlowMatchEulerDiscreteScheduler (Z-Image).""" + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + +@dataclass +class MingImageGenSamplingParams: + """Resolved sampling knobs for one image-gen request.""" + + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 50 + guidance_scale: float = 2.0 + seed: int | None = None + cfg_truncation: float = 1.0 + cfg_normalization: float = 0.0 + + +def combine_cfg( + pos: torch.Tensor, + neg: torch.Tensor, + guidance_scale: float, + cfg_normalization: float = 0.0, +) -> torch.Tensor: + """Classifier-free-guidance combination with optional renormalization. 
+ + ``pred = pos + scale * (pos - neg)``; when ``cfg_normalization > 0`` the + result is rescaled so its norm does not exceed ``cfg_normalization`` × the + positive prediction's norm (Z-Image's renorm trick). Operates in fp32. + """ + pos = pos.float() + neg = neg.float() + pred = pos + guidance_scale * (pos - neg) + if cfg_normalization and float(cfg_normalization) > 0.0: + ori = torch.linalg.vector_norm(pos) + new = torch.linalg.vector_norm(pred) + max_new = ori * float(cfg_normalization) + scale = torch.where( + new > max_new, + (max_new / new.clamp(min=1e-12)).to(pred.dtype), + pred.new_tensor(1.0), + ) + pred = pred * scale + return pred + + +class MingImageDenoiser: + """Pure flow-matching + CFG denoise loop (no diffusers coupling). + + Holds references to the DiT transformer and a diffusers-style scheduler + (anything exposing ``.step(model_output, t, sample) -> (prev_sample, ...)`` + and ``.timesteps``). The loop math mirrors ZImagePipeline.forward steps 6. + """ + + def __init__(self, transformer, scheduler, dtype: torch.dtype = torch.float32) -> None: + self.transformer = transformer + self.scheduler = scheduler + self.dtype = dtype + + def denoise( + self, + latents: torch.Tensor, + timesteps, + prompt_embeds: list[torch.Tensor], + negative_prompt_embeds: list[torch.Tensor] | None, + guidance_scale: float, + cfg_truncation: float = 1.0, + cfg_normalization: float = 0.0, + ) -> torch.Tensor: + """Run the denoising loop and return the final ``[B, C, H, W]`` latents. + + Args: + latents: initial noise ``[B, C, H, W]`` (fp32). + timesteps: iterable of scheduler timesteps (1-D tensor). + prompt_embeds / negative_prompt_embeds: list[Tensor] one per item. + guidance_scale: CFG scale; ``> 0`` enables CFG (needs negatives). + cfg_truncation: disable CFG once normalized time exceeds this. + cfg_normalization: optional CFG renorm factor (0 = off). 
+ """ + actual_batch = latents.shape[0] + do_cfg = guidance_scale > 0 and negative_prompt_embeds is not None + + ts = timesteps if isinstance(timesteps, torch.Tensor) else torch.as_tensor(timesteps) + norm_ts = ((1000 - ts.float()) / 1000).tolist() + + for i, t in enumerate(timesteps): + if isinstance(t, torch.Tensor): + timestep = t.expand(latents.shape[0]) + else: + timestep = torch.tensor([t] * latents.shape[0]) + timestep = (1000 - timestep) / 1000 + t_norm = norm_ts[i] + + current_scale = guidance_scale + if do_cfg and cfg_truncation is not None and float(cfg_truncation) <= 1 and t_norm > cfg_truncation: + current_scale = 0.0 + apply_cfg = do_cfg and current_scale > 0 + + latents_typed = latents.to(self.dtype) + if apply_cfg: + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + embeds_input = prompt_embeds + negative_prompt_embeds + timestep_input = timestep.repeat(2) + else: + latent_model_input = latents_typed + embeds_input = prompt_embeds + timestep_input = timestep + + # DiT expects a list of [C, F, H, W] (frame axis inserted at dim 2). + latent_model_input = latent_model_input.unsqueeze(2) + model_out = self.transformer( + list(latent_model_input.unbind(dim=0)), + timestep_input, + embeds_input, + )[0] + + if apply_cfg: + pos_out = model_out[:actual_batch] + neg_out = model_out[actual_batch:] + noise_pred = torch.stack( + [ + combine_cfg(pos_out[j], neg_out[j], current_scale, cfg_normalization) + for j in range(actual_batch) + ], + dim=0, + ) + else: + noise_pred = torch.stack([o.float() for o in model_out], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred # Z-Image sign convention + + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + + return latents + + +class MingImagePipeline: + """Text-to-image / img2img pipeline for Ming-flash-omni-2.0. 
+ + Construct via :meth:`from_checkpoint` (loads VAE / scheduler / DiT / + condition encoder / optional ByT5 — diffusers + transformers required) or + inject components directly (used by tests). The conditioning path is Ming's + own (Qwen2 connector), so there is no Z-Image text encoder / tokenizer. + """ + + def __init__( + self, + *, + transformer, + scheduler, + vae, + condition_encoder, + image_gen_config, + byte5=None, + device: torch.device | str = "cpu", + dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.transformer = transformer + self.scheduler = scheduler + self.vae = vae + self.condition_encoder = condition_encoder + self.image_gen_config = image_gen_config + self.byte5 = byte5 + self.device = torch.device(device) + self.dtype = dtype + self.denoiser = MingImageDenoiser(transformer, scheduler, dtype=dtype) + self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8 + + @classmethod + def from_checkpoint(cls, model_path, image_gen_config, *, device="cuda", dtype=torch.bfloat16): + """Load all components from the checkpoint (lazy diffusers import). + + Kept separate from ``__init__`` so the module imports without diffusers; + only this path needs it. 
+ """ + from pathlib import Path + + from diffusers import AutoencoderKL + from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + + from mstar.model.ming_omni_flash.components.byte5_encoder import MingByT5Encoder + from mstar.model.ming_omni_flash.components.condition_encoder import MingConditionEncoder + from mstar.model.ming_omni_flash.components.zimage_transformer import MingZImageTransformer2DModel + + model_path = Path(model_path) + cfg = image_gen_config + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_path, subfolder=cfg.scheduler_subfolder, local_files_only=True + ) + scheduler.config["use_dynamic_shifting"] = True + + vae = AutoencoderKL.from_pretrained( + model_path, subfolder=cfg.vae_subfolder, local_files_only=True, torch_dtype=dtype + ).to(device).eval() + + transformer = MingZImageTransformer2DModel( + all_patch_size=tuple(cfg.dit.all_patch_size), + all_f_patch_size=tuple(cfg.dit.all_f_patch_size), + dim=cfg.dit.dim, + n_layers=cfg.dit.n_layers, + n_refiner_layers=cfg.dit.n_refiner_layers, + n_heads=cfg.dit.n_heads, + n_kv_heads=cfg.dit.n_kv_heads, + in_channels=cfg.dit.in_channels, + norm_eps=cfg.dit.norm_eps, + rope_theta=cfg.dit.rope_theta, + t_scale=cfg.dit.t_scale, + axes_dims=tuple(cfg.dit.axes_dims), + axes_lens=tuple(cfg.dit.axes_lens), + cap_feat_dim=cfg.diffusion_c_input_dim, + ).to(device, dtype=dtype).eval() + + condition_encoder = MingConditionEncoder( + cfg, thinker_hidden_size=4096, device=device, dtype=dtype + ) + condition_encoder.load_from_checkpoint(model_path) + + byte5_dir = model_path / "byt5" + byte5 = None + if (byte5_dir / "byt5.json").exists(): + byte5 = MingByT5Encoder.from_checkpoint(byte5_dir, device=torch.device(device), dtype=dtype) + + return cls( + transformer=transformer, + scheduler=scheduler, + vae=vae, + condition_encoder=condition_encoder, + image_gen_config=cfg, + byte5=byte5, + device=device, + dtype=dtype, + ) + + def prepare_latents(self, batch_size, height, width, 
generator=None) -> torch.Tensor: + """Gaussian init latents ``[B, C, H/vae, W/vae]`` (fp32).""" + c = self.transformer.in_channels + vae_scale = self.vae_scale_factor * 2 + shape = (batch_size, c, height // vae_scale, width // vae_scale) + return torch.randn(shape, generator=generator, device=self.device, dtype=torch.float32) + + def build_cap_feats( + self, + thinker_hidden_states: torch.Tensor, + negative_hidden: torch.Tensor | None = None, + byte5_texts: list[str] | None = None, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Run the condition encoder (+ optional ByT5) → (pos, neg) embed lists. + + Negatives default to zeros (Ming's CFG convention) unless explicit + negative thinker states are supplied. ByT5 glyph features are appended + along the sequence dim; the negative side gets zeros for that span so + CFG doesn't push away from rendered text. + """ + if thinker_hidden_states.dim() == 2: + thinker_hidden_states = thinker_hidden_states.unsqueeze(0) + cap_feats = self.condition_encoder(thinker_hidden_states) + + negative_cap_feats = None + if negative_hidden is not None: + if negative_hidden.dim() == 2: + negative_hidden = negative_hidden.unsqueeze(0) + negative_cap_feats = self.condition_encoder(negative_hidden) + + if byte5_texts and self.byte5 is not None: + byte5_feats = self.byte5(byte5_texts).to(cap_feats.dtype) + cap_feats = torch.cat((cap_feats, byte5_feats), dim=1) + if negative_cap_feats is not None: + negative_cap_feats = torch.cat((negative_cap_feats, torch.zeros_like(byte5_feats)), dim=1) + + prompt_embeds = [cap_feats[i] for i in range(cap_feats.shape[0])] + if negative_cap_feats is not None: + negative_prompt_embeds = [negative_cap_feats[i] for i in range(negative_cap_feats.shape[0])] + else: + negative_prompt_embeds = [self.condition_encoder.zero_negative(e) for e in prompt_embeds] + return prompt_embeds, negative_prompt_embeds + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Un-shift/un-scale then 
VAE-decode to a ``[B, 3, H, W]`` image in [-1,1].""" + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + return self.vae.decode(latents, return_dict=False)[0] + + @torch.inference_mode() + def generate( + self, + thinker_hidden_states: torch.Tensor, + params: MingImageGenSamplingParams, + *, + negative_hidden: torch.Tensor | None = None, + byte5_texts: list[str] | None = None, + ) -> torch.Tensor: + """End-to-end text-to-image: condition → denoise → VAE decode.""" + generator = None + if params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(int(params.seed)) + + prompt_embeds, negative_prompt_embeds = self.build_cap_feats( + thinker_hidden_states, negative_hidden, byte5_texts + ) + latents = self.prepare_latents(len(prompt_embeds), params.height, params.width, generator) + + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + self.scheduler.set_timesteps(params.num_inference_steps, device=self.device, mu=mu) + + latents = self.denoiser.denoise( + latents, + self.scheduler.timesteps, + prompt_embeds, + negative_prompt_embeds, + guidance_scale=params.guidance_scale, + cfg_truncation=params.cfg_truncation, + cfg_normalization=params.cfg_normalization, + ) + return self.decode_latents(latents) + + +__all__ = [ + "MingImageDenoiser", + "MingImageGenSamplingParams", + "MingImagePipeline", + "calculate_shift", + "combine_cfg", +] diff --git a/mstar/model/ming_omni_flash/components/model.py b/mstar/model/ming_omni_flash/components/model.py new file mode 100644 index 00000000..86b7bf69 --- /dev/null +++ b/mstar/model/ming_omni_flash/components/model.py @@ -0,0 +1,215 @@ 
class LingMoeModel(nn.Module):
    """Ling-2.0 thinker: token embedding, decoder stack, final norm, lm_head.

    Every shape-relevant setting is a keyword-only constructor argument, so a
    small-dimension instance can be built without a model-config object. One
    rotary embedding is created and shared by all decoder layers.

    Args:
        vocab_size, hidden_size: embedding / lm_head dimensions.
        intermediate_size: FFN width of the dense layers.
        moe_intermediate_size: per-expert FFN width of the MoE layers.
        num_hidden_layers: number of decoder layers.
        num_attention_heads, num_kv_heads, head_dim: attention geometry.
        rms_norm_eps: epsilon for every RMSNorm.
        rope_theta, max_position_embeddings, partial_rotary_factor,
            mrope_section: forwarded to the rotary embedding.
        num_experts, num_experts_per_tok, num_shared_experts, n_group,
            topk_group, routed_scaling_factor: forwarded to the MoE layers.
        first_k_dense_replace: layers below this index use a dense FFN.
        tie_word_embeddings: when True, ``lm_head`` shares the
            ``embed_tokens`` weight.
        use_qkv_bias, use_bias: forwarded to attention.
        comm_group: TP group handed to every layer; a trivial group is
            created when omitted.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        moe_intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float,
        rope_theta: float,
        max_position_embeddings: int,
        partial_rotary_factor: float,
        mrope_section: list[int],
        num_experts: int,
        num_experts_per_tok: int,
        num_shared_experts: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        first_k_dense_replace: int,
        tie_word_embeddings: bool = False,
        use_qkv_bias: bool = False,
        use_bias: bool = False,
        comm_group: TPCommGroup | None = None,
    ) -> None:
        super().__init__()
        self.comm_group = TPCommGroup.trivial() if comm_group is None else comm_group
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers

        # Plain (unsharded) embedding table.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # One rotary module, passed to every layer.
        shared_rotary = LingPartialMRotaryEmbedding(
            head_dim=head_dim,
            partial_rotary_factor=partial_rotary_factor,
            mrope_section=mrope_section,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
        )

        # Everything except ``layer_idx`` is identical across layers.
        per_layer_kwargs = {
            "first_k_dense_replace": first_k_dense_replace,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "moe_intermediate_size": moe_intermediate_size,
            "num_attention_heads": num_attention_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
            "rms_norm_eps": rms_norm_eps,
            "num_experts": num_experts,
            "num_experts_per_tok": num_experts_per_tok,
            "num_shared_experts": num_shared_experts,
            "n_group": n_group,
            "topk_group": topk_group,
            "routed_scaling_factor": routed_scaling_factor,
            "rotary": shared_rotary,
            "use_qkv_bias": use_qkv_bias,
            "use_bias": use_bias,
            "comm_group": self.comm_group,
        }
        decoder_layers = []
        for idx in range(num_hidden_layers):
            decoder_layers.append(LingDecoderLayer(layer_idx=idx, **per_layer_kwargs))
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_word_embeddings = tie_word_embeddings
        if tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

    def forward(
        self,
        cache_handle,
        input_ids: torch.Tensor | None = None,
        input_embeds: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        audio_mask: torch.Tensor | None = None,
        return_hidden_states: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the packed thinker forward pass.

        Args:
            cache_handle: object whose ``set_layer_idx(i)`` is called before
                layer ``i`` runs; it is then passed into that layer.
            input_ids: ``(T,)`` token ids, embedded with ``embed_tokens``.
            input_embeds: ``(T, hidden_size)`` embeddings used as-is. Exactly
                one of ``input_ids`` / ``input_embeds`` must be given.
            position_ids: passed through to every layer; defaults to
                ``torch.arange(T)`` on the hidden states' device.
            image_mask, audio_mask: per-token modality masks forwarded to
                every layer.
            return_hidden_states: also return the post-norm hidden states,
                which the image-generation path reads at its query-token
                positions instead of the logits.

        Returns:
            The ``lm_head`` output for all ``T`` positions, or
            ``(logits, hidden_states)`` when ``return_hidden_states`` is True.

        Raises:
            ValueError: if both or neither input is given, or the hidden
                states are not 2-D.
        """
        ids_given = input_ids is not None
        embeds_given = input_embeds is not None
        if ids_given == embeds_given:
            raise ValueError(
                "Exactly one of input_ids / input_embeds must be provided"
            )

        hidden = input_embeds if embeds_given else self.embed_tokens(input_ids)

        if hidden.dim() != 2:
            raise ValueError(
                f"LingMoeModel expects packed (T, hidden) input; got "
                f"shape {tuple(hidden.shape)}."
            )

        if position_ids is None:
            num_tokens = hidden.shape[0]
            position_ids = torch.arange(num_tokens, device=hidden.device)

        for idx, decoder_layer in enumerate(self.layers):
            # Point the cache at this layer's slot before the layer uses it.
            cache_handle.set_layer_idx(idx)
            hidden = decoder_layer(
                hidden,
                cache_handle,
                position_ids,
                image_mask=image_mask,
                audio_mask=audio_mask,
            )

        hidden = self.norm(hidden)
        logits = self.lm_head(hidden)
        return (logits, hidden) if return_hidden_states else logits
+""" + +from __future__ import annotations + +from functools import partial + +import torch +from torch import nn + +from mstar.distributed.communication import TPCommGroup +from mstar.distributed.utils import divide +from mstar.model.components.distributed.mlp import ParallelGatedMLP +from mstar.model.components.mlp import GatedMLP +from mstar.model.components.moe import ( + _dispatch, + _down_proj_weight_loader, + _gate_up_weight_loader, + dispatch_experts_fused, +) +from mstar.model.ming_omni_flash.components.router import LingMoeRouter + + +def _normalize_modality_mask( + mask: torch.Tensor | None, num_tokens: int, name: str, +) -> torch.Tensor | None: + """Reshape a modality mask to ``(num_tokens, 1)`` bool, or pass through None.""" + if mask is None: + return None + if mask.dim() == 1: + if mask.shape[0] != num_tokens: + raise ValueError( + f"{name} length {mask.shape[0]} != num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 1).bool() + if mask.dim() == 2: + if mask.numel() != num_tokens: + raise ValueError( + f"{name} shape {tuple(mask.shape)} has {mask.numel()} elements; " + f"expected num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 1).bool() + if mask.dim() == 3: + if mask.shape[-1] != 1 or mask.numel() != num_tokens: + raise ValueError( + f"{name} shape {tuple(mask.shape)} not compatible with " + f"num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 1).bool() + raise ValueError( + f"{name} must be 1D, 2D, or 3D; got shape {tuple(mask.shape)}" + ) + + +class LingMoeBlock(nn.Module): + """Ling-2.0 MoE FFN with text/image/audio gate selection per token. + + Constructor takes the FULL ``moe_intermediate_size``; the per-rank + ``shard_inter`` is computed from ``comm_group.world_size``. + + Args: + hidden_size: model hidden dim. + num_experts: total routed experts. + num_experts_per_tok: top-k experts per token. + moe_intermediate_size: per-expert intermediate dim (FULL — + sharding handled internally). 
+ num_shared_experts: number of shared experts (1 on the released + ckpt). The shared expert is a ``ParallelGatedMLP`` of width + ``moe_intermediate_size * num_shared_experts``. + n_group, topk_group, routed_scaling_factor: passed to the + :class:`LingMoeRouter`s. + comm_group: TP comm group; defaults to single-rank trivial. + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + num_experts_per_tok: int, + moe_intermediate_size: int, + num_shared_experts: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float = 1.0, + comm_group: TPCommGroup | None = None, + ) -> None: + super().__init__() + if comm_group is None: + comm_group = TPCommGroup.trivial() + self.comm_group = comm_group + tp_size = comm_group.world_size + tp_rank = comm_group.rank + + self.hidden_size = hidden_size + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + + router_kwargs = dict( + hidden_size=hidden_size, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + n_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, + ) + # Routers — replicated. All ranks must agree on which experts a + # given token routes to, so gate weights are loaded identically + # per rank (default weight_loader, no shard_id). + self.gate = LingMoeRouter(**router_kwargs) + self.image_gate = LingMoeRouter(**router_kwargs) + self.audio_gate = LingMoeRouter(**router_kwargs) + + # Fused expert tensors with per-rank intermediate shard. + shard_inter = divide(moe_intermediate_size, tp_size) + self.experts = nn.Module() + self.experts.gate_up_proj = nn.Parameter( + torch.empty(num_experts, 2 * shard_inter, hidden_size) + ) + self.experts.down_proj = nn.Parameter( + torch.empty(num_experts, hidden_size, shard_inter) + ) + + # Shared expert: ParallelGatedMLP. 
Its down_proj all-reduces, so + # the shared output already lives on the full hidden state at + # every rank. + if num_shared_experts <= 0: + raise ValueError( + "LingMoeBlock requires num_shared_experts >= 1; released " + "Ming-flash-omni-2.0 has 1." + ) + self.shared_expert = ParallelGatedMLP( + comm_group=comm_group, + hidden_size=hidden_size, + intermediate_size=moe_intermediate_size * num_shared_experts, + bias=False, + ) + + self._attach_weight_loaders(tp_rank, tp_size, moe_intermediate_size) + + # ------------------------------------------------------------------ + # Weight loader plumbing — mirrors ParallelSparseMoeBlock + # ------------------------------------------------------------------ + + def _attach_weight_loaders( + self, tp_rank: int, tp_size: int, full_inter: int, + ) -> None: + """Attach mstar's per-rank fused-expert weight loaders. + + The loaders accept shard ids ``"gate:N"``, ``"up:N"``, ``"down:N"`` + and slice along the intermediate dim per rank, then write into + the right expert slot. ``load_hf_weights`` dispatches based on + the ``StackedParamRule.shard_id`` we configure in the loader. 
# NOTE(review): the three functions below are methods of the MoE block whose
# ``class`` statement precedes this chunk (the module exports it as
# ``LingMoeBlock``).  They sit at column 0 only because the surrounding text
# has lost its indentation; indent them one level when placing them back
# inside the class body.


def _apply(self, fn, recurse=True):
    """Run the parent ``_apply`` and then re-attach the expert weight loaders.

    Operations such as ``to_empty`` / ``.to(...)`` go through ``_apply`` and
    re-allocate Parameters, which drops attributes (the ``weight_loader``
    callables) that were attached to the old Parameter objects.

    Args:
        fn: the tensor transformation forwarded to the parent ``_apply``.
        recurse: forwarded unchanged to the parent ``_apply``.

    Returns:
        Whatever the parent ``_apply`` returned.
    """
    result = super()._apply(fn, recurse=recurse)
    self._attach_weight_loaders(
        self.comm_group.rank,
        self.comm_group.world_size,
        self.moe_intermediate_size,
    )
    return result


# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------


def forward(
    self,
    hidden_states: torch.Tensor,
    image_mask: torch.Tensor | None = None,
    audio_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens, dispatch them to the experts and add the shared expert.

    The text gate always runs; where ``image_mask`` / ``audio_mask`` select a
    token, that token's routing is replaced by the image / audio gate's
    routing.  With a single TP rank the module-level ``_dispatch`` helper is
    used, otherwise :meth:`_dispatch_tp`.

    Args:
        hidden_states: ``(..., hidden)`` activations; flattened to
            ``(num_tokens, hidden)`` internally.
        image_mask: optional mask of image tokens, normalised by
            ``_normalize_modality_mask``.
        audio_mask: optional mask of audio tokens, normalised the same way.

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    input_shape = hidden_states.shape
    # ``reshape`` (not ``view``): a non-contiguous input would make ``view``
    # raise before the ``.contiguous()`` call could ever help.
    flat = hidden_states.reshape(-1, input_shape[-1]).contiguous()
    num_tokens = flat.shape[0]

    # Text-gate baseline routing (always computed).
    _, topk_weight, topk_idx = self.gate(flat)

    image_mask = _normalize_modality_mask(image_mask, num_tokens, "image_mask")
    audio_mask = _normalize_modality_mask(audio_mask, num_tokens, "audio_mask")

    # NOTE(review): ``torch.where`` needs the normalised masks to broadcast
    # against the (num_tokens, top_k) routing tensors -- presumably
    # ``_normalize_modality_mask`` returns (num_tokens, 1); confirm there.
    if image_mask is not None:
        _, img_w, img_idx = self.image_gate(flat)
        topk_idx = torch.where(image_mask, img_idx, topk_idx)
        topk_weight = torch.where(image_mask, img_w, topk_weight)
    if audio_mask is not None:
        _, aud_w, aud_idx = self.audio_gate(flat)
        topk_idx = torch.where(audio_mask, aud_idx, topk_idx)
        topk_weight = torch.where(audio_mask, aud_w, topk_weight)

    if self.comm_group.world_size == 1:
        routed = _dispatch(
            flat,
            self.experts.gate_up_proj,
            self.experts.down_proj,
            self.num_experts,
            topk_idx,
            topk_weight,
        )
    else:
        routed = self._dispatch_tp(flat, topk_weight, topk_idx)

    shared = self.shared_expert(flat)
    # Routed and shared outputs are summed without an additional gate; per
    # the original author's note this mirrors the upstream
    # BailingMoeV2SparseMoeBlock.forward, with routed_scaling_factor already
    # folded into ``topk_weight`` by the router.
    return (routed + shared).view(input_shape)


def _dispatch_tp(
    self,
    flat: torch.Tensor,
    routing_weights: torch.Tensor,
    selected_experts: torch.Tensor,
) -> torch.Tensor:
    """Expert dispatch for the TP>1 case.

    When ``sgl_kernel`` is loadable: run ``fused_experts`` with
    ``reduce_results=False``, all-reduce the per-rank partial results across
    the TP group, then sum-reduce across top-k into a ``flat``-shaped output.

    Otherwise: fall back to ``dispatch_experts_fused`` and all-reduce its
    result directly.  Per the original author's note that path already
    returns ``(tokens, hidden)`` summed over top-k, and the two reductions
    commute, so the math is equivalent.

    Args:
        flat: ``(num_tokens, hidden)`` activations.
        routing_weights: per-token top-k routing weights.
        selected_experts: per-token top-k expert indices.

    Returns:
        Routed expert output with the same shape as ``flat``.
    """
    from mstar.utils.fused_moe.align import has_sgl_kernel

    if has_sgl_kernel():
        from mstar.utils.fused_moe import fused_experts, moe_sum_reduce_triton

        cache3 = fused_experts(
            flat,
            self.experts.gate_up_proj,
            self.experts.down_proj,
            routing_weights,
            selected_experts,
            reduce_results=False,
        )
        self.comm_group.all_reduce(cache3)
        output = torch.empty_like(flat)
        moe_sum_reduce_triton(cache3, output, routed_scaling_factor=1.0)
        return output

    # Named ``routed_partial`` (not ``partial``) so it cannot be confused
    # with the ``functools.partial`` used by ``_attach_weight_loaders``.
    routed_partial = dispatch_experts_fused(
        flat,
        self.experts.gate_up_proj,
        self.experts.down_proj,
        self.experts.gate_up_proj.shape[0],
        selected_experts,
        routing_weights,
    )
    self.comm_group.all_reduce(routed_partial)
    return routed_partial
# mstar/model/ming_omni_flash/components/positions.py
# 3D MRoPE position-id helpers: every helper returns a ``(3, seq_len)``
# tensor whose rows are the ``[t, h, w]`` position ids.

import torch


def get_rope_index_text(
    seq_len: int,
    start_pos: int | float,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    """Return 3D MRoPE positions for a pure-text span.

    All three rows hold the same sequential positions
    ``[start_pos, start_pos + 1, ..., start_pos + seq_len - 1]``.

    Args:
        seq_len: number of tokens in the span.
        start_pos: position of the first token (truncated with ``int``).
        device: target device.
        dtype: integer dtype of the position ids.

    Returns:
        Contiguous ``(3, seq_len)`` tensor with identical rows.
    """
    positions = torch.arange(seq_len, dtype=dtype, device=device) + int(start_pos)
    return positions.unsqueeze(0).expand(3, -1).contiguous()


def get_rope_index_audio(
    num_audio_tokens: int,
    start_pos: int | float,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    """Return 3D MRoPE positions for an audio span.

    Audio embeddings are positioned exactly like text: one position per
    embedding, identical T/H/W rows.  This simply delegates to
    :func:`get_rope_index_text`.

    Args:
        num_audio_tokens: number of audio embeddings in the span.
        start_pos: position of the first audio embedding.
        device: target device.
        dtype: integer dtype of the position ids.

    Returns:
        ``(3, num_audio_tokens)`` tensor with identical rows.
    """
    return get_rope_index_text(num_audio_tokens, start_pos, device=device, dtype=dtype)


def _split_grid(grid_thw: torch.Tensor, spatial_merge_size: int) -> tuple[int, int, int]:
    """Validate ``grid_thw`` and return ``(grid_t, llm_grid_h, llm_grid_w)``.

    Shared by :func:`get_rope_index_vision` and
    :func:`vision_span_max_position` so both reject the same inputs.

    Raises:
        ValueError: if ``grid_thw`` is not a 1-D tensor of length 3, or if
            H / W is not divisible by ``spatial_merge_size``.
    """
    if grid_thw.dim() != 1 or grid_thw.numel() != 3:
        raise ValueError(
            f"grid_thw must be a 1-D tensor of length 3 (T, H, W); "
            f"got shape {tuple(grid_thw.shape)}"
        )
    # One host transfer instead of three separate ``.item()`` calls.
    grid_t, grid_h, grid_w = (int(v) for v in grid_thw.tolist())
    if grid_h % spatial_merge_size != 0 or grid_w % spatial_merge_size != 0:
        raise ValueError(
            f"grid_h={grid_h} / grid_w={grid_w} not divisible by "
            f"spatial_merge_size={spatial_merge_size}."
        )
    return grid_t, grid_h // spatial_merge_size, grid_w // spatial_merge_size


def _scale_frame_ids(
    frame_ids: torch.Tensor,
    second_per_grid_t: float | None,
    tokens_per_second: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn raw frame indices into temporal position ids.

    With ``second_per_grid_t`` set, the indices are multiplied in float32 by
    ``second_per_grid_t * tokens_per_second`` and cast back to ``dtype``;
    otherwise they are returned unchanged.
    """
    if second_per_grid_t is None:
        return frame_ids
    scaled = frame_ids.float() * float(second_per_grid_t) * float(tokens_per_second)
    return scaled.to(dtype)


def get_rope_index_vision(
    grid_thw: torch.Tensor,
    start_pos: int | float,
    spatial_merge_size: int,
    device: torch.device | str | None = None,
    second_per_grid_t: float | None = None,
    tokens_per_second: int = 2,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    """Return 3D MRoPE positions for one vision span (image or video clip).

    Token order is T-major, then H, then W:

      * Temporal: frame index repeated across ``llm_grid_h * llm_grid_w``,
        optionally scaled by ``second_per_grid_t * tokens_per_second``.
      * Height:   ``arange(llm_grid_h)``, each value repeated ``llm_grid_w``
        times, tiled over frames.
      * Width:    ``arange(llm_grid_w)`` tiled over frames and rows.

    ``llm_grid_h = grid_h // spatial_merge_size`` (same for W).  All three
    rows are offset by ``start_pos``.  For several images or clips, call this
    once per item and concatenate the results along the last dimension.

    Args:
        grid_thw: ``(3,)`` tensor of (T, H, W) grid sizes.
        start_pos: position of this span's first token (truncated with ``int``).
        spatial_merge_size: spatial merge factor of the vision encoder.
        device: target device.
        second_per_grid_t: when set, scale the temporal row by
            ``second_per_grid_t * tokens_per_second``; ``None`` keeps the raw
            frame index.
        tokens_per_second: temporal-resolution multiplier.
        dtype: integer dtype of the position ids.

    Returns:
        ``(3, grid_t * llm_grid_h * llm_grid_w)`` tensor of T/H/W positions.

    Raises:
        ValueError: if ``grid_thw`` has the wrong shape or H / W is not
            divisible by ``spatial_merge_size``.
    """
    grid_t, llm_grid_h, llm_grid_w = _split_grid(grid_thw, spatial_merge_size)

    frame_ids = torch.arange(grid_t, dtype=dtype, device=device)
    frame_ids = _scale_frame_ids(frame_ids, second_per_grid_t, tokens_per_second, dtype)
    t_index = frame_ids.view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

    h_index = (
        torch.arange(llm_grid_h, dtype=dtype, device=device)
        .view(1, -1, 1)
        .expand(grid_t, -1, llm_grid_w)
        .flatten()
    )
    w_index = (
        torch.arange(llm_grid_w, dtype=dtype, device=device)
        .view(1, 1, -1)
        .expand(grid_t, llm_grid_h, -1)
        .flatten()
    )
    return torch.stack([t_index, h_index, w_index], dim=0) + int(start_pos)


def vision_span_max_position(
    grid_thw: torch.Tensor,
    start_pos: int | float,
    spatial_merge_size: int,
    second_per_grid_t: float | None = None,
    tokens_per_second: int = 2,
) -> int:
    """Return one past the largest position id the vision span produces.

    Used to advance the caller's global position counter past a vision span
    so the following text resumes at the right position.  The input is
    validated exactly like :func:`get_rope_index_vision`, and the temporal
    maximum uses the same float32 multiply-then-truncate arithmetic, so the
    result always equals ``get_rope_index_vision(...).max() + 1`` for a
    non-empty span.

    Args:
        grid_thw: ``(3,)`` tensor of (T, H, W) grid sizes.
        start_pos: position of the span's first token.
        spatial_merge_size: spatial merge factor of the vision encoder.
        second_per_grid_t: optional temporal scale, as in
            :func:`get_rope_index_vision`.
        tokens_per_second: temporal-resolution multiplier.

    Raises:
        ValueError: if ``grid_thw`` has the wrong shape or H / W is not
            divisible by ``spatial_merge_size``.
    """
    grid_t, llm_grid_h, llm_grid_w = _split_grid(grid_thw, spatial_merge_size)

    if second_per_grid_t is not None:
        last_frame = torch.tensor([float(grid_t - 1)], dtype=torch.float32)
        max_t = int(
            _scale_frame_ids(
                last_frame, second_per_grid_t, tokens_per_second, torch.long
            ).item()
        )
    else:
        max_t = grid_t - 1
    return int(start_pos) + max(max_t, llm_grid_h - 1, llm_grid_w - 1) + 1


__all__ = [
    "get_rope_index_text",
    "get_rope_index_audio",
    "get_rope_index_vision",
    "vision_span_max_position",
]
# mstar/model/ming_omni_flash/components/projectors.py
# Vision + audio projectors.  Sub-module ordering inside each ``proj``
# Sequential is load-bearing: checkpoint keys are matched by integer index.

import torch
from torch import nn


class _Transpose(nn.Module):
    """Parameter-free module that swaps two dims, usable inside ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.transpose(x, self.dim0, self.dim1)


class MingVisionProjector(nn.Module):
    """MLP that maps vision-encoder features into the LLM hidden space.

    Layout of ``self.proj``: ``Linear(vision_dim, llm_dim)`` followed by
    ``mlp_depth - 1`` repetitions of ``GELU, Linear(llm_dim, llm_dim)``.

    Args:
        vision_dim: feature size produced by the vision encoder.
        llm_dim: hidden size of the LLM.
        mlp_depth: number of Linear layers (>= 1).

    Raises:
        ValueError: if ``mlp_depth < 1``.
    """

    def __init__(self, vision_dim: int, llm_dim: int, mlp_depth: int = 2) -> None:
        super().__init__()
        if mlp_depth < 1:
            raise ValueError(f"mlp_depth must be >= 1, got {mlp_depth}")
        modules: list[nn.Module] = [nn.Linear(vision_dim, llm_dim)]
        for _ in range(mlp_depth - 1):
            modules.extend((nn.GELU(), nn.Linear(llm_dim, llm_dim)))
        # Kept under the attribute name ``proj`` so parameter names are
        # ``proj.<index>.weight`` / ``proj.<index>.bias``.
        self.proj = nn.Sequential(*modules)

    def forward(self, vision_embeds: torch.Tensor) -> torch.Tensor:
        """Project vision tokens; only the last dim changes (to ``llm_dim``)."""
        return self.proj(vision_embeds)


class MingAudioProjector(nn.Module):
    """Conv1d down-sample followed by an MLP, mapping audio features to the LLM space.

    Layout of ``self.proj``: ``Conv1d`` (index 0), ``_Transpose`` (index 1),
    then ``mlp_depth - 1`` repetitions of ``GELU, Linear(llm_dim, llm_dim)``,
    and a final ``_Transpose``.  With ``mlp_depth=2`` the parametrised layers
    therefore sit at indices 0 and 3.

    Args:
        audio_dim: channel size of the incoming audio features.
        llm_dim: hidden size of the LLM.
        ds_kernel_size: kernel size of the down-sampling conv.
        ds_stride: stride of the down-sampling conv.
        mlp_depth: 1 means conv only; each extra level adds GELU + Linear.

    Raises:
        ValueError: if ``mlp_depth < 1``.
    """

    def __init__(
        self,
        audio_dim: int,
        llm_dim: int,
        ds_kernel_size: int = 3,
        ds_stride: int = 2,
        mlp_depth: int = 2,
    ) -> None:
        super().__init__()
        if mlp_depth < 1:
            raise ValueError(f"mlp_depth must be >= 1, got {mlp_depth}")
        self.ds_kernel_size = ds_kernel_size
        self.ds_stride = ds_stride
        self.audio_dim = audio_dim
        self.llm_dim = llm_dim

        downsample = nn.Conv1d(
            audio_dim,
            llm_dim,
            kernel_size=ds_kernel_size,
            stride=ds_stride,
            padding=ds_kernel_size // 2,
        )
        # The conv emits (B, llm_dim, T'); the Linear layers act on the last
        # dim, so channels are moved last before them ...
        modules: list[nn.Module] = [downsample, _Transpose(-1, -2)]
        for _ in range(mlp_depth - 1):
            modules.extend((nn.GELU(), nn.Linear(llm_dim, llm_dim)))
        # ... and moved back afterwards, so the output is (B, llm_dim, T').
        modules.append(_Transpose(-1, -2))
        self.proj = nn.Sequential(*modules)

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Project channels-last audio features.

        Args:
            audio_embeds: ``(B, T, audio_dim)`` tensor.

        Returns:
            ``(B, llm_dim, T')`` tensor with
            ``T' = (T - ds_kernel_size + 2*(ds_kernel_size//2)) // ds_stride + 1``.
        """
        channels_first = audio_embeds.transpose(-1, -2)  # Conv1d wants (B, C, T)
        return self.proj(channels_first)

    def compute_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Return the sequence length after the encoder stem and this projector.

        Applies the length formula of one ``kernel=3, pad=1, stride=2`` conv
        (``(L - 3 + 2) // 2 + 1``) for the audio-encoder stem, then the same
        formula with this projector's kernel / padding / stride.

        NOTE(review): which encoder conv carries the stride of 2 is not
        visible from this module -- confirm against the encoder.
        """
        after_stem = (input_length - 3 + 2 * 1) // 2 + 1
        kernel, stride = self.ds_kernel_size, self.ds_stride
        return (after_stem - kernel + 2 * (kernel // 2)) // stride + 1


__all__ = ["MingVisionProjector", "MingAudioProjector"]
# mstar/model/ming_omni_flash/components/prompt_utils.py
#
# Image-gen query-token block.  The three token literals reached this file as
# empty strings (their angle-bracket markup was stripped), which made
# ``IMAGE_PATCH_TOKEN in prompt`` always true and turned the expansion into a
# permanent no-op.  They are restored here and spelled with ``\x3c`` / ``\x3e``
# escapes (the "less-than" / "greater-than" characters) so markup-stripping
# tooling cannot blank them again.
# NOTE(review): values follow the Ming tokenizer's image special tokens --
# verify against the released ``tokenizer_config.json``.
_IMAGE_OPEN_TOKEN = "\x3cimage\x3e"
_IMAGE_CLOSE_TOKEN = "\x3c/image\x3e"
IMAGE_PATCH_TOKEN = "\x3cimagePatch\x3e"

# Number of query tokens appended by default (16 * 16 = 256 per the original
# author's note on the released checkpoint's image-gen scale).
DEFAULT_NUM_QUERY_TOKENS = 256


def maybe_expand_image_gen_prompt(
    prompt: str,
    num_query_tokens: int = DEFAULT_NUM_QUERY_TOKENS,
) -> str:
    """Append the image-generation query-token block to ``prompt``.

    The appended suffix is the open token, ``num_query_tokens`` copies of
    :data:`IMAGE_PATCH_TOKEN`, then the close token.

    The input is returned unchanged when it is not a non-empty string, or
    when it already contains :data:`IMAGE_PATCH_TOKEN` (avoids expanding the
    same prompt twice).

    Args:
        prompt: raw user prompt text.
        num_query_tokens: number of patch tokens to emit.

    Returns:
        The prompt, with the query-token block appended when applicable.
    """
    if not isinstance(prompt, str) or not prompt:
        return prompt
    if IMAGE_PATCH_TOKEN in prompt:
        return prompt
    suffix = _IMAGE_OPEN_TOKEN + (IMAGE_PATCH_TOKEN * num_query_tokens) + _IMAGE_CLOSE_TOKEN
    return prompt + suffix


# ============================================================
# TTS / talker caption builder (talker-only deploy)
# ============================================================

DEFAULT_PROMPT = "Please generate speech based on the following description.\n"

# Base caption schema.  Keys are the Ming-native Chinese field names
# (序号 = index, 说话人 = speaker, 方言 = dialect, 风格 = style, 语速 = speed,
# 基频 = pitch, 音量 = volume, 情感 = emotion, BGM = background music block,
# IP = persona).
BASE_CAPTION_TEMPLATE: dict[str, Any] = {
    "audio_sequence": [
        {
            "序号": 1,
            "说话人": "speaker_1",
            "方言": None,
            "风格": None,
            "语速": None,
            "基频": None,
            "音量": None,
            "情感": None,
            "BGM": {
                "Genre": None,
                "Mood": None,
                "Instrument": None,
                "Theme": None,
                "ENV": None,
                "SNR": None,
            },
            "IP": None,
        }
    ]
}


def create_instruction(user_input: dict[str, Any]) -> str:
    """Return the JSON caption string with ``user_input`` merged in.

    Only top-level keys that already exist on ``audio_sequence[0]`` of
    :data:`BASE_CAPTION_TEMPLATE` are overwritten; unknown keys are ignored
    so the emitted field set never changes.  A value supplied for ``"BGM"``
    replaces the whole nested block.  The template itself is never mutated
    (a deep copy is edited).

    Args:
        user_input: partial caption controls, e.g.
            ``{"说话人": "speaker_2", "情感": "happy"}``.

    Returns:
        JSON string produced with ``ensure_ascii=False`` so the Chinese field
        names stay readable.
    """
    caption = copy.deepcopy(BASE_CAPTION_TEMPLATE)
    item = caption["audio_sequence"][0]
    for key, value in user_input.items():
        if key in item:
            item[key] = value
    return json.dumps(caption, ensure_ascii=False)


__all__ = [
    "IMAGE_PATCH_TOKEN",
    "DEFAULT_NUM_QUERY_TOKENS",
    "maybe_expand_image_gen_prompt",
    "DEFAULT_PROMPT",
    "BASE_CAPTION_TEMPLATE",
    "create_instruction",
]
# mstar/model/ming_omni_flash/components/rope.py
# Partial rotary embedding with 3D (time / height / width) positions in the
# ``video_rope`` layout:  [ H W H W ... H W | T T ... T ]  per half.

import torch
from torch import nn


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``[-x2, x1]`` where ``x1`` / ``x2`` are the two halves of the last dim."""
    mid = x.shape[-1] // 2
    return torch.cat((-x[..., mid:], x[..., :mid]), dim=-1)


def _build_inv_freq(rotary_dim: int, theta: float) -> torch.Tensor:
    """Return ``theta ** (-2i / rotary_dim)`` for ``i`` in ``[0, rotary_dim / 2)`` (float32)."""
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    return 1.0 / (theta ** exponents)


class LingPartialMRotaryEmbedding(nn.Module):
    """Partial rotary embedding with ``video_rope`` 3D MRoPE.

    Only the first ``rotary_dim = int(head_dim * partial_rotary_factor)``
    dims of each head are rotated; the rest pass through unchanged.

    Args:
        head_dim: full head dim of the attention layer.
        partial_rotary_factor: fraction of ``head_dim`` that is rotated.
        mrope_section: ``[Nt, Nh, Nw]`` frequency-pair counts per axis; must
            sum to ``rotary_dim // 2``.
        rope_theta: rotary base frequency.
        max_position_embeddings: stored on the module for reference; no
            method of this class reads it (cos/sin are computed from
            ``position_ids`` on every call, nothing is precomputed).

    Raises:
        ValueError: if ``rotary_dim`` is odd, ``mrope_section`` does not have
            exactly three entries, or its sum differs from ``rotary_dim // 2``.
    """

    def __init__(
        self,
        head_dim: int,
        partial_rotary_factor: float,
        mrope_section: list[int],
        rope_theta: float,
        max_position_embeddings: int,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.rotary_dim = int(head_dim * partial_rotary_factor)
        if self.rotary_dim % 2 != 0:
            raise ValueError(
                f"rotary_dim must be even (got {self.rotary_dim}); check "
                f"partial_rotary_factor."
            )
        self.mrope_section = list(mrope_section)
        # Length first, so a malformed section list gets the precise message.
        if len(self.mrope_section) != 3:
            raise ValueError(
                f"mrope_section must be length-3 [Nt, Nh, Nw]; got {self.mrope_section}"
            )
        if sum(self.mrope_section) != self.rotary_dim // 2:
            raise ValueError(
                f"sum(mrope_section)={sum(self.mrope_section)} must equal "
                f"rotary_dim//2={self.rotary_dim // 2}"
            )
        self.hw_size = self.mrope_section[1] + self.mrope_section[2]

        self.rope_theta = float(rope_theta)
        self.max_position_embeddings = int(max_position_embeddings)

        # Only inv_freq is kept; non-persistent so it stays out of state_dict.
        self.register_buffer(
            "inv_freq",
            _build_inv_freq(self.rotary_dim, self.rope_theta),
            persistent=False,
        )

    # ------------------------------------------------------------------
    # cos / sin
    # ------------------------------------------------------------------

    def _compute_cos_sin(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)``, each ``(num_tokens, rotary_dim)``.

        ``position_ids`` is ``(num_tokens,)`` for plain 1D rope or
        ``(3, num_tokens)`` (rows t / h / w) for the video_rope layout.

        Raises:
            ValueError: for any other shape.
        """
        if position_ids.dim() == 1:
            return self._cos_sin_1d(position_ids)
        if position_ids.dim() != 2 or position_ids.shape[0] != 3:
            raise ValueError(
                f"position_ids must be (num_tokens,) or (3, num_tokens); "
                f"got shape {tuple(position_ids.shape)}"
            )
        return self._cos_sin_3d_video_rope(position_ids)

    def _cos_sin_1d(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Standard 1D rotary cos/sin (neox layout: frequencies repeated twice)."""
        freqs = position_ids.float().unsqueeze(-1) * self.inv_freq.unsqueeze(0)
        emb = torch.cat((freqs, freqs), dim=-1)  # (num_tokens, rotary_dim)
        return emb.cos(), emb.sin()

    def _cos_sin_3d_video_rope(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute per-axis cos/sin and remap them into the video_rope layout.

        ``position_ids`` is ``(3, num_tokens)`` with row 0 = time,
        row 1 = height, row 2 = width.
        """
        # (3, num_tokens, rotary_dim / 2)
        freqs = position_ids.float().unsqueeze(-1) * self.inv_freq.view(1, 1, -1)
        # Build the neox-style (3, num_tokens, rotary_dim) table once and
        # derive both cos and sin from it.
        emb = torch.cat((freqs, freqs), dim=-1)
        return self._remap_video_rope(emb.cos(), emb.sin())

    def _remap_video_rope(
        self, cos_3d: torch.Tensor, sin_3d: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick, for every frequency slot, the axis the video_rope layout assigns.

        Within each half of ``rotary_dim`` (both halves hold the same
        frequencies), slot ``i`` takes its value from:

          * height (row 1) if ``i < hw_size`` and ``i`` is even,
          * width  (row 2) if ``i < hw_size`` and ``i`` is odd,
          * time   (row 0) if ``i >= hw_size``.

        Args:
            cos_3d, sin_3d: ``(3, num_tokens, rotary_dim)`` per-axis tables.

        Returns:
            ``(cos, sin)``, each ``(num_tokens, rotary_dim)``.
        """
        half = self.rotary_dim // 2
        hw = self.hw_size

        result_cos = torch.empty_like(cos_3d[0])
        result_sin = torch.empty_like(sin_3d[0])

        for table, out in ((cos_3d, result_cos), (sin_3d, result_sin)):
            for offset in (0, half):
                h_slots = slice(offset, offset + hw, 2)
                w_slots = slice(offset + 1, offset + hw, 2)
                t_slots = slice(offset + hw, offset + half)
                out[:, h_slots] = table[1, :, h_slots]
                out[:, w_slots] = table[2, :, w_slots]
                out[:, t_slots] = table[0, :, t_slots]
        return result_cos, result_sin

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate the first ``rotary_dim`` dims of ``q`` and ``k``.

        The inputs are not modified; new tensors are returned.

        Args:
            q, k: ``(..., num_tokens, head_dim)``.  The token axis must be
                the second-to-last one: cos/sin are ``(num_tokens,
                rotary_dim)`` and are left-padded with singleton dims to
                ``q.dim()`` before broadcasting.
                NOTE(review): a ``(num_tokens, num_heads, head_dim)`` layout
                does not broadcast here in general -- confirm what the
                attention caller passes.
            position_ids: ``(num_tokens,)`` for 1D rope or
                ``(3, num_tokens)`` for video_rope.

        Returns:
            ``(q, k)`` with the rotation applied to the leading
            ``rotary_dim`` dims and the remaining dims copied through.

        Raises:
            ValueError: if the last dim of ``q`` or ``k`` is not ``head_dim``.
        """
        if q.shape[-1] != self.head_dim or k.shape[-1] != self.head_dim:
            raise ValueError(
                f"q/k last dim {q.shape[-1]}/{k.shape[-1]} != "
                f"head_dim {self.head_dim}"
            )

        cos, sin = self._compute_cos_sin(position_ids)
        while cos.dim() < q.dim():
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)

        rotary_dim = self.rotary_dim

        def rotate(x: torch.Tensor) -> torch.Tensor:
            x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
            x_rot = (x_rot * cos.to(x.dtype)) + (_rotate_half(x_rot) * sin.to(x.dtype))
            return torch.cat([x_rot, x_pass], dim=-1)

        return rotate(q), rotate(k)
# mstar/model/ming_omni_flash/components/router.py
# Group-limited top-k MoE router with sigmoid scores and a per-expert bias.

import torch
import torch.nn.functional as F
from torch import nn


class LingMoeRouter(nn.Module):
    """Group-limited top-k router (sigmoid scores + selection bias).

    Routing, as implemented below:

      1. ``scores = sigmoid(gate(x))``.
      2. ``scores + expert_bias`` is used for *selection only*.
      3. Experts are split contiguously into ``n_group`` groups; a group's
         score is the sum of its two best biased scores; only experts of the
         ``topk_group`` best groups are eligible.
      4. The ``top_k`` best eligible experts are chosen; their *un-biased*
         sigmoid scores are renormalised to sum to 1 (when ``top_k > 1``)
         and multiplied by ``routed_scaling_factor``.

    Args:
        hidden_size: input hidden dimension.
        num_experts: total routed experts; must be divisible by ``n_group``.
        num_experts_per_tok: experts selected per token (``top_k``).
        n_group: number of contiguous expert groups.
        topk_group: number of groups a token may route into.
        routed_scaling_factor: scale applied to the final weights.

    Raises:
        ValueError: if ``num_experts`` is not divisible by ``n_group``, if
            ``topk_group > n_group``, or if ``num_experts_per_tok`` exceeds
            the number of experts available in ``topk_group`` groups.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        super().__init__()
        if num_experts % n_group != 0:
            raise ValueError(
                f"num_experts={num_experts} must be divisible by n_group={n_group}"
            )
        if topk_group > n_group:
            raise ValueError(
                f"topk_group={topk_group} cannot exceed n_group={n_group}"
            )
        experts_per_group = num_experts // n_group
        # Without this, the final top-k would have to pick experts that were
        # masked to -inf, i.e. silently route outside the selected groups.
        if num_experts_per_tok > topk_group * experts_per_group:
            raise ValueError(
                f"num_experts_per_tok={num_experts_per_tok} exceeds the "
                f"{topk_group * experts_per_group} experts available in "
                f"topk_group={topk_group} groups of {experts_per_group}"
            )
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.experts_per_group = experts_per_group
        self.routed_scaling_factor = routed_scaling_factor

        # Gate projection (no bias).
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

        # Selection bias: not trained by gradient, but registered as a
        # Parameter so it appears in state_dict.
        self.expert_bias = nn.Parameter(
            torch.zeros(num_experts), requires_grad=False,
        )

    def _group_limited_topk(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick ``top_k`` experts restricted to the ``topk_group`` best groups.

        Args:
            scores: ``(num_tokens, num_experts)`` selection scores.

        Returns:
            ``(num_tokens, top_k)`` int64 expert indices (unsorted).
        """
        num_tokens = scores.size(0)
        grouped = scores.view(num_tokens, self.n_group, self.experts_per_group)
        # Group score = sum of the group's two best experts (or its only
        # expert when groups have a single member).
        per_group_k = min(2, self.experts_per_group)
        group_scores = grouped.topk(per_group_k, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1.0)
        # Expand the per-group mask to one entry per expert.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.experts_per_group)
            .reshape(num_tokens, -1)
        )
        # Experts of non-selected groups can never win the final top-k.
        masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
        return torch.topk(masked, k=self.top_k, dim=-1, sorted=False)[1]

    def forward(
        self, hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Args:
            hidden_states: ``(..., hidden_size)``; flattened to
                ``(N, hidden_size)`` internally.

        Returns:
            ``(router_logits, routing_weights, selected_experts)``:
              - ``router_logits``: ``(N, num_experts)`` float32 pre-sigmoid
                gate logits.
              - ``routing_weights``: ``(N, top_k)`` float32 weights,
                renormalised (for ``top_k > 1``) and scaled.
              - ``selected_experts``: ``(N, top_k)`` int64 expert indices.
        """
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        logits = F.linear(hidden_states, self.gate.weight).float()
        # Per-expert sigmoid (not softmax).  The bias only steers selection;
        # the returned weights come from the un-biased scores.
        sigmoid_scores = torch.sigmoid(logits)
        selected_experts = self._group_limited_topk(sigmoid_scores + self.expert_bias)

        chosen_scores = torch.gather(sigmoid_scores, dim=1, index=selected_experts)
        if self.top_k > 1:
            chosen_scores = chosen_scores / (
                chosen_scores.sum(dim=-1, keepdim=True) + 1e-20
            )
        routing_weights = chosen_scores * self.routed_scaling_factor

        return logits, routing_weights, selected_experts
class T5EncoderBlockByT5Mapper(nn.Module):
    """Refine byt5 features with HF ``T5Block`` layers, then normalise/project.

    Pipeline: ``num_layers`` encoder blocks -> ``layer_norm`` -> optional
    ``channel_mapper`` Linear -> optional ``final_layer_norm``.

    Args:
        byte5_config: configuration object providing ``d_model`` and
            ``layer_norm_epsilon`` and accepted by ``T5Block``.
        num_layers: how many blocks to stack; ``0`` keeps only the
            normalisation / projection tail. Only block 0 is built with
            ``has_relative_attention_bias=True``; later blocks receive the
            ``position_bias`` returned by the previous block.
        sdxl_channels: projection width. ``None`` disables both the
            projection and the second norm, so the output stays
            ``d_model`` wide.
    """

    def __init__(self, byte5_config, num_layers: int, sdxl_channels: int | None = None) -> None:
        super().__init__()
        width = byte5_config.d_model
        eps = byte5_config.layer_norm_epsilon

        # Submodules are created in the same order as before: blocks,
        # layer_norm, channel_mapper, final_layer_norm.
        if num_layers > 0:
            stack = [
                T5Block(byte5_config, has_relative_attention_bias=(idx == 0))
                for idx in range(num_layers)
            ]
            self.blocks = nn.ModuleList(stack)
        else:
            self.blocks = None
        self.layer_norm = T5LayerNorm(width, eps=eps)

        if sdxl_channels is None:
            self.channel_mapper = None
            self.final_layer_norm = None
        else:
            self.channel_mapper = nn.Linear(width, sdxl_channels)
            self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=eps)

    @staticmethod
    def get_extended_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Convert a {0,1} keep-mask into an additive attention bias.

        A 2-D ``[B, S]`` mask becomes ``[B, 1, 1, S]``; a 3-D ``[B, S, S]``
        mask becomes ``[B, 1, S, S]``. Kept positions (1) map to ``0`` and
        dropped positions (0) map to ``torch.finfo(dtype).min``.

        Raises:
            ValueError: if the mask is neither 2-D nor 3-D.
        """
        rank = attention_mask.dim()
        if rank not in (2, 3):
            raise ValueError(f"Unexpected attention_mask shape {tuple(attention_mask.shape)}")
        if rank == 2:
            broadcastable = attention_mask[:, None, None, :]
        else:
            broadcastable = attention_mask[:, None, :, :]
        keep = broadcastable.to(dtype=dtype)
        return (1.0 - keep) * torch.finfo(dtype).min

    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the block stack and the norm/projection tail on ``inputs_embeds``."""
        bias_mask = self.get_extended_attention_mask(attention_mask, dtype=inputs_embeds.dtype)

        features = inputs_embeds
        shared_bias = None
        layers = self.blocks if self.blocks is not None else ()
        for layer in layers:
            # Each block is unpacked as (hidden_states, position_bias); the
            # bias from one block is handed to the next.
            features, shared_bias = layer(
                features,
                attention_mask=bias_mask,
                position_bias=shared_bias,
                use_cache=False,
            )

        features = self.layer_norm(features)
        if self.channel_mapper is None:
            return features
        return self.final_layer_norm(self.channel_mapper(features))

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy matching checkpoint tensors into this module's parameters.

        Checkpoint names that are not parameters of this module are skipped.
        The returned set holds the names actually copied, so callers can
        compare it against the expected parameter names.

        Raises:
            ValueError: if a matching name has a different shape.
        """
        targets = dict(self.named_parameters())
        copied: set[str] = set()
        for name, tensor in weights:
            target = targets.get(name)
            if target is None:
                continue
            if target.shape != tensor.shape:
                raise ValueError(
                    f"Shape mismatch loading byt5 mapper weight {name}: "
                    f"param {tuple(target.shape)} vs checkpoint {tuple(tensor.shape)}"
                )
            with torch.no_grad():
                target.copy_(tensor)
            copied.add(name)
        return copied
class _SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding for a 1-D tensor of scalar timesteps.

    For ``half = dim // 2`` the output row for timestep ``x`` is
    ``concat(sin(a), cos(a))`` with
    ``a_k = scale * x * exp(-log(10000) * k / (half - 1))`` for
    ``k in [0, half)``.

    Args:
        dim: embedding width. Must be even and at least 4, because the
            frequency spacing divides by ``half - 1``.

    Raises:
        ValueError: if ``dim`` is odd or smaller than 4.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"freq_embed_dim must be even, got {dim}")
        # dim == 2 would make ``half - 1`` zero and crash in forward().
        if dim < 4:
            raise ValueError(f"freq_embed_dim must be at least 4, got {dim}")
        self.dim = dim

    def forward(self, x: torch.Tensor, scale: float = 1000.0) -> torch.Tensor:
        """Embed timesteps ``x`` of shape ``(B,)`` into ``(B, dim)``.

        The result is cast back to ``x.dtype`` when ``x`` is floating point.
        For integer timesteps it stays float32, since casting sin/cos values
        to an integer dtype would truncate them.
        """
        half = self.dim // 2
        # Log-spaced inverse frequencies from 1 down to 1/10000.
        log_step = math.log(10000.0) / (half - 1)
        inv_freq = torch.exp(torch.arange(half, device=x.device).float() * -log_step)
        angles = scale * x.unsqueeze(1).float() * inv_freq.unsqueeze(0)
        out = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if x.is_floating_point():
            return out.to(x.dtype)
        return out
def _rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent pairs of the last axis: ``(a, b) -> (-b, a)``.

    The last dimension is viewed as ``(d, 2)`` pairs, each pair is rotated,
    and the result is flattened back to the original shape.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first = pairs[..., 0]
    second = pairs[..., 1]
    return torch.stack((-second, first), dim=-1).flatten(-2)


def _apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply (possibly partial) rotary embeddings with the interleaved layout.

    Args:
        t: queries or keys whose second-to-last axis is the sequence and
            whose last axis is the head dimension, e.g. ``(B, H, T, D)``.
        freqs: rotary angle table ``(1, T', R)``. Only its last
            ``t.shape[-2]`` positions are used, and only the first ``R``
            channels of ``t`` are rotated; remaining channels pass through.

    Returns:
        Tensor with the same shape as ``t``.
    """
    n_rot = freqs.shape[-1]
    n_pos = t.shape[-2]
    freqs = freqs[:, -n_pos:, :]
    # Insert a heads axis so a (1, T, R) table broadcasts over (B, H, T, D).
    if t.ndim == 4 and freqs.ndim == 3:
        freqs = freqs.unsqueeze(1)

    head = t[..., :n_rot]
    tail = t[..., n_rot:]
    # cos/sin are cast to the input dtype before multiplying.
    cos = freqs.cos().to(head.dtype)
    sin = freqs.sin().to(head.dtype)
    head = (head * cos) + (_rotate_half_interleaved(head) * sin)
    return torch.cat([head, tail], dim=-1)
class _Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm, RoPE and padding mask.

    Submodule names (``to_q``, ``to_k``, ``to_v``, ``to_out.0``; ``to_out.1``
    is a parameter-free Dropout) are part of the state_dict layout and must
    not change. ``q_norm`` / ``k_norm`` exist only for
    ``qk_norm="rms_norm"``; otherwise they are ``None``.

    Mask semantics in ``forward``:
      * ``mask`` is a boolean ``(B, T)`` key-padding mask, True = valid.
      * With ``attn_mask_enabled=True`` it is also passed to SDPA so that
        padded keys are excluded from the softmax.
      * Independently of that flag, output rows at masked-out positions are
        zeroed before returning.

    Raises:
        ValueError: for a ``qk_norm`` other than ``None`` / ``"rms_norm"``.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        qk_norm: str | None = None,
        attn_mask_enabled: bool = True,
    ) -> None:
        super().__init__()
        inner = dim_head * heads
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.inner_dim = inner
        self.dropout = dropout
        self.attn_mask_enabled = attn_mask_enabled

        self.to_q = nn.Linear(dim, inner)
        self.to_k = nn.Linear(dim, inner)
        self.to_v = nn.Linear(dim, inner)

        if qk_norm == "rms_norm":
            self.q_norm = _RMSNorm(dim_head)
            self.k_norm = _RMSNorm(dim_head)
        elif qk_norm is None:
            self.q_norm = None
            self.k_norm = None
        else:
            raise ValueError(f"Unimplemented qk_norm: {qk_norm!r}")

        # [Linear, Dropout] kept as a ModuleList so the projection's keys are
        # ``to_out.0.weight`` / ``to_out.0.bias``.
        out_proj = nn.Linear(inner, dim)
        out_drop = nn.Dropout(dropout)
        self.to_out = nn.ModuleList([out_proj, out_drop])

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        rope: tuple[torch.Tensor, torch.Tensor | None] | None = None,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, T, dim)`` and return ``(B, T, dim)``.

        Args:
            x: input hidden states.
            mask: optional boolean ``(B, T)`` key-padding mask (True = valid).
            rope: optional ``(freqs, xpos_scale)`` pair; only ``freqs`` is
                used and it is applied to both queries and keys.
        """
        batch = x.shape[0]

        def split_heads(projection: nn.Module) -> torch.Tensor:
            # (B, T, H*Dh) -> (B, H, T, Dh)
            return projection(x).view(batch, -1, self.heads, self.dim_head).transpose(1, 2)

        q = split_heads(self.to_q)
        k = split_heads(self.to_k)
        v = split_heads(self.to_v)

        if self.q_norm is not None:
            q = self.q_norm(q)
        if self.k_norm is not None:
            k = self.k_norm(k)

        if rope is not None:
            freqs = rope[0]  # second element (xpos scale) is not used here
            q = _apply_rotary_pos_emb(q, freqs)
            k = _apply_rotary_pos_emb(k, freqs)

        # Boolean (B, H, Tq, Tk) mask for SDPA: a key column is attendable
        # only where the key-padding mask is True.
        sdpa_mask = None
        if mask is not None and self.attn_mask_enabled:
            sdpa_mask = mask[:, None, None, :].expand(
                batch, self.heads, q.shape[-2], k.shape[-2]
            )

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=sdpa_mask, dropout_p=0.0, is_causal=False,
        )
        merged = attended.transpose(1, 2).reshape(batch, -1, self.inner_dim)
        projected = self.to_out[1](self.to_out[0](merged))

        if mask is None:
            return projected
        # Zero rows at padded positions even when the SDPA mask is disabled.
        return projected.masked_fill(~mask[:, :, None], 0.0)
class DiT(nn.Module):
    """Diffusion transformer over audio-latent patches.

    The input sequence is ``[spk?] + [t + c] + latent_history + x``: an
    optional speaker token, one token that sums the timestep and condition
    embeddings, then the embedded history and current latents. All tokens go
    through ``depth`` ``_DiTBlock`` layers (with RoPE, no padding mask) and
    ``final_layer`` projects back to ``in_channels``.

    Args:
        in_channels: latent channel count (also the output width).
        hidden_size: transformer width.
        depth: number of ``_DiTBlock`` layers.
        num_heads: attention heads; head dim is ``hidden_size // num_heads``.
        mlp_ratio: FFN expansion factor passed to each block.
        llm_cond_dim: width of the conditioning vector fed to ``c_embedder``.
        dropout: dropout probability passed to each block.
        qk_norm: ``None`` or ``"rms_norm"``, forwarded to the attention.
        spk_dim: if given, a ``spk_embedder`` Linear is created and
            ``spk_emb`` becomes mandatory in ``forward``.
        attn_mask_enabled: forwarded to each block's attention.
    """

    def __init__(
        self,
        in_channels: int = 64,
        hidden_size: int = 1024,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        llm_cond_dim: int = 896,
        dropout: float = 0.0,
        qk_norm: str | None = None,
        spk_dim: int | None = None,
        attn_mask_enabled: bool = False,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.t_embedder = DiTTimestepEmbedding(hidden_size)
        self.x_embedder = nn.Linear(in_channels, hidden_size)
        self.c_embedder = _CondEmbedder(llm_cond_dim, hidden_size)
        self.spk_embedder = (
            nn.Linear(spk_dim, hidden_size) if spk_dim is not None else None
        )

        self.rotary_embed = RotaryEmbedding(hidden_size // num_heads)
        self.blocks = nn.ModuleList([
            _DiTBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                qk_norm=qk_norm,
                attn_mask_enabled=attn_mask_enabled,
            )
            for _ in range(depth)
        ])
        self.final_layer = _FinalLayer(hidden_size, self.out_channels)

    def forward(
        self,
        x: torch.Tensor,               # (B, patch_size, in_channels)
        t: torch.Tensor,               # (B,) timesteps
        c: torch.Tensor,               # (B, 1, llm_cond_dim)
        latent_history: torch.Tensor,  # (B, his_patch_size, in_channels)
        spk_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return hidden states of shape ``(B, prefix + T, out_channels)``.

        ``prefix`` is 1 (the summed t+c token) plus 1 when a speaker token is
        prepended; ``T`` is the combined length of ``latent_history`` and
        ``x``. Callers slice the trailing rows they need.

        Raises:
            AssertionError: if ``spk_emb`` is missing although the model has
                a ``spk_embedder``, or supplied although it has none.
        """
        x = torch.cat([latent_history, x], dim=1)
        x = self.x_embedder(x)
        t_h = self.t_embedder(t).unsqueeze(1)
        c_h = self.c_embedder(c)
        y = t_h + c_h

        if spk_emb is None:
            if self.spk_embedder is not None:
                raise AssertionError(
                    "DiT was built with spk_embedder but spk_emb was None at forward."
                )
            prefix = [y]
        else:
            # Explicit raise instead of ``assert`` so the check survives -O.
            if self.spk_embedder is None:
                raise AssertionError("spk_emb provided but spk_embedder=None")
            prefix = [self.spk_embedder(spk_emb), y]
        x = torch.cat([*prefix, x], dim=1)

        rope = self.rotary_embed.forward_from_seq_len(x.shape[1])
        for block in self.blocks:
            # No padding mask on this path; only RoPE is supplied.
            x = block(x, None, rope)
        return self.final_layer(x)

    def forward_with_cfg(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c: torch.Tensor,
        latent_history: torch.Tensor,
        spk_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Classifier-free-guidance pass: conditional and null-cond in one batch.

        Every input is duplicated along the batch axis; the second copy gets
        an all-zero condition. ``t`` may be a 0-D scalar (repeated for the
        doubled batch) or a per-sample ``(B,)`` tensor (duplicated like the
        other inputs). Any other shape is passed through unchanged.

        Returns:
            ``(2 * B, x.shape[1], out_channels)``: the last ``x.shape[1]``
            rows of the forward output, conditional half first, null-cond
            half second.
        """
        batch = x.shape[0]
        x_cat = torch.cat([x, x], dim=0)
        lh_cat = torch.cat([latent_history, latent_history], dim=0)
        c_cat = torch.cat([c, torch.zeros_like(c)], dim=0)
        if t.ndim == 0:
            t = t.repeat(x_cat.shape[0])
        elif t.shape[0] == batch:
            # A per-sample t must follow the doubled batch; otherwise the
            # (B, 1, H) timestep embedding cannot be added to the (2B, 1, H)
            # condition embedding when B > 1.
            t = torch.cat([t, t], dim=0)
        spk_cat = None if spk_emb is None else torch.cat([spk_emb, spk_emb], dim=0)
        out = self.forward(x_cat, t, c_cat, lh_cat, spk_cat)
        return out[:, -x.shape[1]:, :]
class CFM(nn.Module):
    """Euler sampler for conditional flow matching around a wrapped model.

    The wrapped ``model`` must expose ``forward_with_cfg(x, t, c,
    latent_history, spk_emb)`` returning the conditional and null-condition
    predictions stacked along the batch axis.

    Args:
        model: the velocity network.
        steps: number of integration steps performed by ``sample``.
        sway_sampling_coef: if not ``None``, timesteps are warped as
            ``t + coef * (cos(pi/2 * t) - 1 + t)`` before integrating.
    """

    def __init__(
        self,
        model: nn.Module,
        steps: int = 10,
        sway_sampling_coef: float | None = -1.0,
    ) -> None:
        super().__init__()
        self.model = model
        self.steps = steps
        self.sway_sampling_coef = sway_sampling_coef

    @torch.no_grad()
    def sample(
        self,
        llm_cond: torch.Tensor,   # (B, 1, llm_cond_dim)
        lat_cond: torch.Tensor,   # (B, his_patch_size, latent_dim)
        y0: torch.Tensor,         # (B, patch_size, latent_dim) initial noise
        t: torch.Tensor,          # (steps + 1,) integration timesteps
        sde_args: torch.Tensor,   # (3,) [cfg_strength, sigma, temperature]
        sde_rnd: torch.Tensor,    # (steps, B, patch_size, latent_dim)
    ) -> torch.Tensor:
        """Integrate from ``y0`` over ``t`` and return the final latent.

        Each step adds the guided velocity times ``dt`` and then a noise term
        ``sigma * sqrt(temperature) * sqrt(|dt|) * sde_rnd[step]``.

        Raises:
            ValueError: if ``t`` does not have ``steps + 1`` entries or
                ``sde_rnd`` does not have ``steps`` leading entries.
        """
        n_steps = self.steps
        if t.shape[0] != n_steps + 1:
            raise ValueError(
                f"CFM.sample: expected t of length steps+1 = {self.steps + 1}, got {t.shape[0]}"
            )
        if sde_rnd.shape[0] != n_steps:
            raise ValueError(
                f"CFM.sample: expected sde_rnd[0] = {self.steps}, got {sde_rnd.shape[0]}"
            )

        coef = self.sway_sampling_coef
        if coef is not None:
            t = t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        latent = y0
        for idx in range(n_steps):
            t_now = t[idx]
            dt = t[idx + 1] - t_now

            # Guided velocity: cond + (cond - null) * cfg_strength.
            stacked = self.model.forward_with_cfg(latent, t_now, llm_cond, lat_cond, None)
            cond_pred, null_pred = torch.chunk(stacked, 2, dim=0)
            drift = cond_pred + (cond_pred - null_pred) * sde_args[0]

            latent = latent + drift * dt
            # Stochastic term scaled by sigma, sqrt(temperature) and sqrt(|dt|).
            latent = latent + sde_args[1] * (sde_args[2] ** 0.5) * (dt.abs() ** 0.5) * sde_rnd[idx]
        return latent
class Aggregator(nn.Module):
    """Summarise a patch of audio latents into one LLM-space vector.

    Latents are embedded with ``x_embedder``, a learnable [CLS] token (row 0
    of ``word_embedder``) is prepended, the sequence runs through ``depth``
    ``_DiTBlock`` layers with RoPE, and ``final_layer`` projects to
    ``llm_input_dim``. Only the [CLS] row is returned.

    Args:
        in_channels: latent channel count.
        hidden_size: transformer width.
        depth: number of ``_DiTBlock`` layers.
        num_heads: attention heads; head dim is ``hidden_size // num_heads``.
        mlp_ratio: FFN expansion factor passed to each block.
        llm_input_dim: output width of ``final_layer``.
        dropout: dropout probability passed to each block.
        qk_norm: ``None`` or ``"rms_norm"``, forwarded to the attention.
        attn_mask_enabled: forwarded to each block's attention.
    """

    def __init__(
        self,
        in_channels: int = 64,
        hidden_size: int = 1024,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        llm_input_dim: int = 896,
        dropout: float = 0.1,
        qk_norm: str | None = None,
        attn_mask_enabled: bool = False,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # Single-row embedding table used as the learnable [CLS] token.
        self.word_embedder = nn.Embedding(1, hidden_size)
        self.x_embedder = nn.Linear(in_channels, hidden_size)

        self.rotary_embed = RotaryEmbedding(hidden_size // num_heads)
        self.blocks = nn.ModuleList([
            _DiTBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                qk_norm=qk_norm,
                attn_mask_enabled=attn_mask_enabled,
            )
            for _ in range(depth)
        ])
        self.final_layer = _FinalLayer(hidden_size, llm_input_dim)

    def forward(
        self,
        x: torch.Tensor,                   # (B, T, in_channels) audio latents
        mask: torch.Tensor | None = None,  # (B, T) key-padding mask, True = valid
    ) -> torch.Tensor:
        """Return the [CLS] row only: ``(B, 1, llm_input_dim)``.

        When ``mask`` is given, a column of True is prepended for the [CLS]
        position so that token is always treated as valid, even if the first
        latent position of a row is padding.
        """
        B = x.shape[0]
        h = self.x_embedder(x)
        cls_ids = torch.zeros((B, 1), dtype=torch.long, device=h.device)
        cls_embed = self.word_embedder(cls_ids)
        h = torch.cat([cls_embed, h], dim=1)

        rope = self.rotary_embed.forward_from_seq_len(h.shape[1])
        if mask is not None:
            # Copying ``mask[:, :1]`` would mask (and zero) the [CLS] row
            # whenever the first latent is padded; use an explicit
            # all-valid column with the mask's own dtype and device.
            cls_valid = torch.ones_like(mask[:, :1])
            mask = torch.cat([cls_valid, mask], dim=-1)

        for block in self.blocks:
            h = block(h, mask, rope)
        h = self.final_layer(h)
        return h[:, :1, :]
+ + Args: + talker_llm_config: `TalkerLLMConfig` instance. + attn_implementation: passed through to Qwen2Config so the + model can use FA2 / SDPA. The upstream vllm-omni talker + uses ``"sdpa"`` (the ckpt's Qwen2 has + `_attn_implementation: flash_attention_2` baked into its + config dict but the vllm-omni runtime forcibly overrides + to sdpa to play nicely with vLLM's attention machinery + — we follow the same default). + dtype: cast the model to this dtype after construction. + device: device to materialise the model on. + + Returns: + A `transformers.models.qwen2.modeling_qwen2.Qwen2Model` + instance with all parameters allocated (weights are still + random; the loader populates them later). + """ + try: + from transformers import Qwen2Config, Qwen2Model + except ImportError as e: + raise ImportError( + "build_talker_llm requires transformers >= 4.43 (Qwen2 support). " + f"Original error: {e}" + ) from e + + llm_cfg = Qwen2Config( + vocab_size=talker_llm_config.vocab_size, + hidden_size=talker_llm_config.hidden_size, + intermediate_size=talker_llm_config.intermediate_size, + num_hidden_layers=talker_llm_config.num_hidden_layers, + num_attention_heads=talker_llm_config.num_attention_heads, + num_key_value_heads=talker_llm_config.num_key_value_heads, + hidden_act=talker_llm_config.hidden_act, + max_position_embeddings=talker_llm_config.max_position_embeddings, + rms_norm_eps=talker_llm_config.rms_norm_eps, + rope_theta=talker_llm_config.rope_theta, + use_sliding_window=talker_llm_config.use_sliding_window, + sliding_window=talker_llm_config.sliding_window, + max_window_layers=talker_llm_config.max_window_layers, + tie_word_embeddings=talker_llm_config.tie_word_embeddings, + attention_dropout=talker_llm_config.attention_dropout, + use_cache=talker_llm_config.use_cache, + bos_token_id=talker_llm_config.bos_token_id, + eos_token_id=talker_llm_config.eos_token_id, + attn_implementation=attn_implementation, + ) + model = Qwen2Model(llm_cfg) + model = 
def build_talker_heads(
    talker_config,
    spk_embed_dim: int = 192,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device = "cpu",
) -> dict[str, nn.Module]:
    """Create the talker's two auxiliary ``nn.Linear`` heads.

    * ``stop_head`` maps ``talker_config.llm.hidden_size`` features to 2
      logits (with bias).
    * ``spk_head`` maps a ``spk_embed_dim``-wide vector to
      ``talker_config.llm.hidden_size`` features (with bias).

    Both modules are moved to ``dtype`` / ``device`` and put in eval mode.

    Args:
        talker_config: Object exposing ``llm.hidden_size``.
        spk_embed_dim: Input width of ``spk_head``.
        dtype: Parameter dtype of the returned heads.
        device: Device the heads are placed on.

    Returns:
        ``{"stop_head": ..., "spk_head": ...}``.
    """
    llm_width = talker_config.llm.hidden_size
    # Construction order (stop head first) is kept so the default random
    # initialisation consumes the RNG stream in the same order as before.
    fresh_heads = {
        "stop_head": nn.Linear(llm_width, 2, bias=True),
        "spk_head": nn.Linear(spk_embed_dim, llm_width, bias=True),
    }
    return {
        name: head.to(dtype=dtype, device=device).eval()
        for name, head in fresh_heads.items()
    }
def trim_trailing_silence(
    waveform: torch.Tensor,
    sample_rate: int,
    sil_th: float = 1e-3,
    tail_silence_s: float = 0.3,
) -> torch.Tensor:
    """Cut low-energy frames from the end, keeping a short silent tail.

    The signal is scored in non-overlapping 0.1 s frames (mean absolute
    amplitude, averaged over the leading axis). Everything after the last
    frame whose score is not ``<= sil_th`` is dropped, except for
    ``tail_silence_s`` seconds of trailing audio.

    Args:
        waveform: 2-D ``(C, T)`` or 3-D ``(B, C, T)`` tensor. For 3-D input
            only channel 0 is used and the result is ``(B, 1, T')``. Empty
            tensors and any other rank are returned unchanged.
        sample_rate: Samples per second.
        sil_th: Frame score at or below which a frame counts as silent.
        tail_silence_s: Seconds of audio kept after the last loud frame
            (also the maximum kept when every frame is silent).

    Returns:
        The trimmed waveform (a slice of the input along the last axis).
    """
    if waveform.numel() == 0:
        return waveform

    rank = waveform.dim()
    if rank not in (2, 3):
        # Defensive pass-through for unexpected shapes.
        return waveform
    audio = waveform[:, 0, :] if rank == 3 else waveform

    def _restore_rank(signal: torch.Tensor) -> torch.Tensor:
        # Re-insert the channel axis that was indexed away for 3-D input.
        return signal.unsqueeze(1) if rank == 3 else signal

    frame_len = int(sample_rate * 0.1)  # frame size and hop are identical
    tail_len = int(tail_silence_s * sample_rate)

    if audio.shape[-1] < frame_len:
        # Shorter than one frame: nothing to score, only cap the length.
        return _restore_rank(audio[..., : min(audio.shape[-1], tail_len)])

    # Drop the ragged remainder so the signal is a whole number of frames.
    n_frames = (audio.shape[-1] - frame_len) // frame_len + 1
    audio = audio[..., : (n_frames - 1) * frame_len + frame_len]

    energy = audio.unfold(-1, frame_len, frame_len).abs().mean(dim=-1)
    energy = energy.mean(dim=list(range(energy.dim() - 1)))

    # Indices of frames that are NOT silent; the last one marks the cut.
    loud_frames = (~(energy <= sil_th)).nonzero()
    if loud_frames.numel() == 0:
        keep = min(audio.shape[-1], tail_len)
    else:
        last_loud = int(loud_frames[-1, 0])
        keep = min(last_loud * frame_len + frame_len + tail_len, audio.shape[-1])
    return _restore_rank(audio[..., :keep])
The leading-silence + holder lets you defer emission of long silent regions; the + short-chunk buffer concatenates chunks smaller than one frame. + """ + if speech.numel() == 0: + return speech, sil_cache or {"holder": [], "buffer": []} + + frame_step = int(sample_rate * 0.1) + frame_size = int(sample_rate * 0.1) + if sil_cache is None: + sil_cache = {"holder": [], "buffer": []} + + if sil_cache["buffer"]: + speech = torch.cat([*sil_cache["buffer"], speech], dim=-1) + sil_cache["buffer"] = [] + + if speech.shape[-1] < frame_size: + sil_cache["buffer"].append(speech) + if last_chunk: + out = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1) + return out[..., : int(last_sil * sample_rate)], sil_cache + return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache + + num_frame = (speech.shape[-1] - frame_size) // frame_step + 1 + cur_len = (num_frame - 1) * frame_step + frame_size + if speech.shape[-1] > cur_len: + sil_cache["buffer"].append(speech[..., cur_len:]) + speech = speech[..., :cur_len] + + spe_frames = speech.unfold(-1, frame_size, frame_step) + scores = spe_frames.abs().mean(dim=-1) + scores = scores.mean(dim=list(range(scores.dim() - 1))) + idx = int(scores.shape[0]) - 1 + while idx >= 0 and scores[idx] <= sil_th: + idx -= 1 + + if idx < 0: + sil_cache["holder"].append(speech) + if last_chunk: + out = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1) + return out[..., : int(last_sil * sample_rate)], sil_cache + return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache + + non_sil_len = idx * frame_step + frame_size + if last_chunk: + non_sil_len += int(last_sil * sample_rate) + non_sil_len = min(non_sil_len, speech.shape[-1]) + speech_out = torch.cat([*sil_cache["holder"], speech[..., :non_sil_len]], dim=-1) + sil_cache["holder"] = [] + if non_sil_len < speech.shape[-1]: + sil_cache["holder"].append(speech[..., non_sil_len:]) + return speech_out, sil_cache + + 
class TalkerGenerator:
    """Run prefill -> autoregressive decode -> VAE decode for one TTS request.

    The model components are bound once in ``__init__``; the methods keep no
    per-request state on the instance, so each `generate_latents` /
    `decode_to_waveform` call starts a fresh flow.

    Attribute names follow upstream `MingAudioGenerator.__init__` (per the
    original port notes) so wiring code and tests can share one surface.
    """

    def __init__(
        self,
        talker_config: "TalkerConfig",
        llm: "Qwen2Model",
        cfm: CFM,
        aggregator: Aggregator,
        stop_head: nn.Module,
        audio_vae: "AudioVAE | None" = None,
        cfg_strength: float | None = None,
    ) -> None:
        """Bind components and cache the sizes read from ``talker_config``.

        ``cfg_strength`` falls back to ``talker_config.cfg_strength`` when
        it is ``None``.
        """
        self.config = talker_config
        self.llm = llm
        self.cfm = cfm
        self.aggregator = aggregator
        self.stop_head = stop_head
        self.audio_vae = audio_vae
        self.patch_size = talker_config.patch_size
        self.his_patch_size = talker_config.history_patch_size
        self.latent_dim = talker_config.vae.latent_dim
        if cfg_strength is None:
            cfg_strength = talker_config.cfg_strength
        self.cfg_strength = cfg_strength
        # Number of trailing latent frames carried over and prepended to the
        # next chunk in `_stream_decode` (context for the chunked decode).
        self._vae_decode_pad_frames = 32

    # ------------------------------------------------------------------
    # Step entry points
    # ------------------------------------------------------------------

    def llm_step(
        self,
        inputs_embeds: torch.Tensor,
        *,
        step: int,
        past_key_values=None,
        use_static_cache: bool,
    ) -> torch.Tensor:
        """Run one LLM forward and return the last hidden-state row.

        ``cache_position`` is passed only for ``step > 0`` when a static
        cache is in use; it starts at the cache's current sequence length
        and spans ``inputs_embeds.shape[1]`` positions.

        Returns:
            ``outputs.last_hidden_state[:, -1:, :]``.
        """
        call_kwargs = {
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "use_cache": True,
        }
        if step != 0 and use_static_cache:
            already_cached = int(past_key_values.get_seq_length())
            call_kwargs["cache_position"] = torch.arange(
                already_cached,
                already_cached + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        return self.llm(**call_kwargs).last_hidden_state[:, -1:, :]

    def cfm_sample_step(
        self,
        last_hidden_state: torch.Tensor,
        his_lat: torch.Tensor,
        *,
        cfg: float | None = None,
        sigma: float = 0.25,
        temperature: float = 0.0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample one latent patch and prepare the next LLM input.

        Returns ``(gen_lat, next_inputs_embeds, stop_out)``:
          * ``gen_lat`` -- output of ``self.cfm.sample``.
          * ``next_inputs_embeds`` -- ``self.aggregator(gen_lat)``.
          * ``stop_out`` -- softmax of ``self.stop_head`` applied to the
            last position of ``last_hidden_state``.
        """
        if cfg is None:
            cfg = self.cfg_strength

        # All helper tensors live on the hidden state's device / dtype.
        tensor_opts = {
            "device": last_hidden_state.device,
            "dtype": last_hidden_state.dtype,
        }
        batch, _, channels = his_lat.shape
        noise_shape = (batch, self.patch_size, channels)

        # Draw order (initial noise, then per-step noise) is significant for
        # reproducibility under a fixed seed.
        init_noise = torch.randn(noise_shape, **tensor_opts)
        timesteps = get_epss_timesteps(self.config.steps, **tensor_opts)
        per_step_noise = torch.randn((self.config.steps, *noise_shape), **tensor_opts)
        sde_params = torch.tensor([cfg, sigma, temperature], **tensor_opts)

        gen_lat = self.cfm.sample(
            last_hidden_state, his_lat, init_noise, timesteps, sde_params, per_step_noise,
        )
        next_embeds = self.aggregator(gen_lat)
        stop_out = self.stop_head(last_hidden_state[:, -1, :]).softmax(dim=-1)
        return gen_lat, next_embeds, stop_out

    # ------------------------------------------------------------------
    # AR generation loop
    # ------------------------------------------------------------------

    @torch.no_grad()
    def generate_latents(
        self,
        inputs_embeds: torch.Tensor,
        *,
        prompt_wav_lat: torch.Tensor | None = None,
        min_new_token: int = 10,
        max_steps: int = 1000,
        cfg: float | None = None,
        sigma: float = 0.25,
        temperature: float = 0.0,
        use_static_cache: bool = True,
    ) -> list[torch.Tensor]:
        """Autoregressive loop: (LLM step -> CFM sample -> stop check).

        The first iteration feeds the caller's ``inputs_embeds`` (prefill);
        later iterations feed the aggregator output of the previous step.
        The loop runs at most ``min(max_steps, max_cache_len - prefill_len)``
        iterations and breaks early once ``step > min_new_token`` and the
        stop probability of batch row 0 exceeds 0.5.

        Returns:
            The generated latent patches in emission order.
        """
        if cfg is None:
            cfg = self.cfg_strength
        first_param = next(self.llm.parameters())
        device, dtype = first_param.device, first_param.dtype

        his_lat = self._init_his_lat(prompt_wav_lat, device, dtype)
        past_key_values, max_cache_len = self._init_kv_cache(use_static_cache, device, dtype)
        prefill_len = inputs_embeds.shape[1]

        if max_cache_len:
            steps_budget = min(max_steps, max_cache_len - prefill_len)
        else:
            steps_budget = max_steps

        patches: list[torch.Tensor] = []
        for step in range(steps_budget):
            hidden = self.llm_step(
                inputs_embeds,
                step=step,
                past_key_values=past_key_values,
                use_static_cache=use_static_cache,
            )
            gen_lat, inputs_embeds, stop_out = self.cfm_sample_step(
                hidden, his_lat, cfg=cfg, sigma=sigma, temperature=temperature,
            )
            his_lat = self._update_his_lat(his_lat, gen_lat)
            patches.append(gen_lat)

            # Index [0, 1]: batch row 0, second class of the stop head.
            stop_prob = float(stop_out[0, 1].item())
            if step > min_new_token and stop_prob > 0.5:
                logger.debug("TalkerGenerator: stop at step=%d (prob=%.4f)", step, stop_prob)
                break

        return patches

    # ------------------------------------------------------------------
    # KV cache + history-latent bookkeeping
    # ------------------------------------------------------------------

    def _init_his_lat(
        self,
        prompt_wav_lat: torch.Tensor | None,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Create the ``(1, his_patch_size, latent_dim)`` history buffer.

        Starts as zeros. A supplied ``prompt_wav_lat`` is right-aligned in
        the window; if it is longer than the window only its last
        ``his_patch_size`` frames are kept.
        """
        window = torch.zeros(
            1, self.his_patch_size, self.latent_dim, device=device, dtype=dtype,
        )
        if prompt_wav_lat is None:
            return window

        offset = self.his_patch_size - prompt_wav_lat.size(1)
        if offset < 0:
            # Prompt overflows the window: keep its trailing frames.
            window[:] = prompt_wav_lat[:, -offset:, :]
        else:
            window[:, offset:, :] = prompt_wav_lat
        return window

    def _init_kv_cache(
        self,
        use_static_cache: bool,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[object | None, int]:
        """Return ``(cache_or_None, max_cache_len)``.

        ``max_cache_len`` is 2048 in both cases. With ``use_static_cache``
        a ``transformers.StaticCache`` (batch size 1) is built from a
        ``Qwen2Config`` populated from ``self.config.llm``.
        """
        max_cache_len = 2048
        if not use_static_cache:
            return None, max_cache_len

        from transformers import Qwen2Config, StaticCache

        llm_fields = self.config.llm
        # StaticCache reads layer / head dimensions from a config object.
        cache_cfg = Qwen2Config(
            hidden_size=llm_fields.hidden_size,
            num_hidden_layers=llm_fields.num_hidden_layers,
            num_attention_heads=llm_fields.num_attention_heads,
            num_key_value_heads=llm_fields.num_key_value_heads,
            vocab_size=llm_fields.vocab_size,
            max_position_embeddings=llm_fields.max_position_embeddings,
        )
        static_cache = StaticCache(
            config=cache_cfg,
            max_batch_size=1,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )
        return static_cache, max_cache_len

    def _update_his_lat(
        self, his_lat: torch.Tensor, gen_lat: torch.Tensor,
    ) -> torch.Tensor:
        """Advance the history window by one generated patch.

        Raises:
            NotImplementedError: if the history window is shorter than a
                patch.
        """
        if self.his_patch_size < self.patch_size:
            raise NotImplementedError(
                f"his_patch_size ({self.his_patch_size}) < patch_size ({self.patch_size})",
            )
        if self.his_patch_size == self.patch_size:
            return gen_lat
        # Negative start index keeps the last (his_patch_size - patch_size)
        # frames of the old window, then appends the new patch.
        kept = his_lat[:, self.patch_size - self.his_patch_size:]
        return torch.cat([kept, gen_lat], dim=1)

    # ------------------------------------------------------------------
    # Duration cap heuristic
    # ------------------------------------------------------------------

    def duration_capped_steps(
        self, text_len: int, requested_max_steps: int,
    ) -> int:
        """Limit ``requested_max_steps`` by an audio-duration budget.

        One step is worth ``patch_size * vae.config.patch_size *
        vae.decoder.hop_length / sample_rate`` seconds; the budget is
        ``max(2.0, text_len * 5818 / 16000)`` seconds. Without an audio VAE,
        or when a step has non-positive duration, the request is returned
        unchanged.
        """
        vae = self.audio_vae
        if vae is None:
            return requested_max_steps

        sample_rate = float(vae.config.sample_rate)
        vae_patch = float(vae.config.patch_size)
        hop = float(vae.decoder.hop_length)
        seconds_per_step = (self.patch_size * vae_patch * hop) / sample_rate
        if seconds_per_step <= 0:
            return requested_max_steps

        budget_s = max(2.0, float(text_len) * (5818.0 / 16000.0))
        steps_in_budget = max(1, int(budget_s / seconds_per_step))
        return min(requested_max_steps, steps_in_budget)

    # ------------------------------------------------------------------
    # Audio decode (one-shot + streaming)
    # ------------------------------------------------------------------

    def _empty_waveform(self) -> torch.Tensor:
        """Zero-length ``(1, 1, 0)`` waveform on the LLM's device / dtype."""
        first_param = next(self.llm.parameters())
        return torch.zeros((1, 1, 0), device=first_param.device, dtype=first_param.dtype)

    def decode_to_waveform(
        self, latents: list[torch.Tensor], stream_decode: bool = True,
    ) -> torch.Tensor:
        """Decode latent patches into a waveform with ``audio_vae.decode``.

        With ``stream_decode=True`` the patches are decoded chunk by chunk
        (`_stream_decode`); otherwise they are concatenated along dim 1 and
        decoded in one call. An empty ``latents`` list yields a ``(1, 1, 0)``
        tensor.

        Raises:
            RuntimeError: if no audio VAE was bound.
        """
        if self.audio_vae is None:
            raise RuntimeError("TalkerGenerator: audio_vae is None — cannot decode.")
        if not latents:
            return self._empty_waveform()
        if stream_decode:
            return self._stream_decode(latents)

        waveform, _, _ = self.audio_vae.decode(
            torch.cat(latents, dim=1),
            use_cache=False,
            stream_state=(None, None, None),
            last_chunk=True,
        )
        return waveform

    def _stream_decode(self, latents: list[torch.Tensor]) -> torch.Tensor:
        """Decode patch by patch with carried-over context and silence holding.

        Each chunk is decoded together with up to ``_vae_decode_pad_frames``
        trailing frames of the previous VAE input; the samples belonging to
        that carried context are discarded before the chunk is passed
        through `silence_holder`.
        """
        sr = int(self.audio_vae.config.sample_rate)
        carry: torch.Tensor | None = None
        sil_cache: dict | None = None
        pieces: list[torch.Tensor] = []
        final_index = len(latents) - 1

        for i, lat in enumerate(latents):
            if carry is None:
                vae_input, pad_frames = lat, 0
            else:
                vae_input = torch.cat([carry, lat], dim=1)
                pad_frames = carry.shape[1]

            speech, _, _ = self.audio_vae.decode(
                vae_input,
                use_cache=False,
                stream_state=(None, None, None),
                last_chunk=True,
            )
            # Output samples produced per latent frame (integer division).
            samples_per_frame = speech.shape[-1] // vae_input.shape[1]
            fresh = speech[:, :, pad_frames * samples_per_frame:][0].detach().float()
            fresh, sil_cache = silence_holder(
                fresh, sr, sil_cache=sil_cache, last_chunk=(i == final_index),
            )
            if fresh.numel() > 0:
                pieces.append(fresh)
            carry = vae_input[:, -self._vae_decode_pad_frames:, :].detach()

        if not pieces:
            return self._empty_waveform()
        return torch.cat(pieces, dim=-1).unsqueeze(0)

    def trim_trailing_silence(self, waveform: torch.Tensor) -> torch.Tensor:
        """Trim tail silence at the audio VAE's sample rate.

        Returns ``waveform`` untouched when no audio VAE is bound.
        """
        if self.audio_vae is None:
            return waveform
        # Resolves to the module-level function, not this method.
        return trim_trailing_silence(waveform, int(self.audio_vae.config.sample_rate))
+""" + +from __future__ import annotations + +import importlib +import logging +import sys +from pathlib import Path + +import torch +from torch import nn + +from mstar.model.ming_omni_flash.config import VisionEncoderConfig + +logger = logging.getLogger(__name__) + + +def _import_ming_vit(local_dir: str | None = None) -> type[nn.Module]: + """Resolve ``Qwen3MoeVisionTransformer`` from the staged Ming source. + + ``MingFlashOmniModel.__init__`` pushes the snapshot dir onto + ``sys.path`` and symlinks ``qwen3_moe_vit.py`` into it (see + ``_MING_CODE_FILES`` and ``_prepare_tokenizer_dir``). We import via + that path so all the other dynamic imports the file performs + (e.g. ``from configuration_bailingmm2 import ...``) keep resolving + against the same staged tree. + + Args: + local_dir: Optional snapshot dir to put on ``sys.path`` first. + Callers that bypass ``MingFlashOmniModel.__init__`` (tests, + standalone benchmarks) can pass this to avoid an + ``ImportError`` on a fresh interpreter. + """ + if local_dir is not None: + if str(local_dir) not in sys.path: + sys.path.insert(0, str(local_dir)) + # Also push the Ming source repo (if discoverable) so the dynamic + # imports inside qwen3_moe_vit.py resolve cross-file. The snapshot + # is the symlink staging dir; we discover any "real" source by + # following one of the staged symlinks back to its target. + candidate = Path(local_dir) / "qwen3_moe_vit.py" + if candidate.is_symlink(): + ming_root = Path(candidate).resolve().parent + if str(ming_root) not in sys.path: + sys.path.insert(0, str(ming_root)) + + try: + module = importlib.import_module("qwen3_moe_vit") + except ImportError as e: + raise ImportError( + "Could not import qwen3_moe_vit. Ensure MingFlashOmniModel " + "was constructed (which stages the Ming source files), or " + "pass local_dir= explicitly. See " + "PORTING_NOTES.md 'Ming source dependency' for setup." 
+ ) from e + + return module.Qwen3MoeVisionTransformer + + +def build_vision_encoder( + config: VisionEncoderConfig, + use_deepstack: bool = False, + dtype: torch.dtype = torch.bfloat16, + device: str | torch.device = "cpu", + attn_implementation: str = "flash_attention_2", + local_dir: str | None = None, +) -> nn.Module: + """Construct the Ming vision encoder. + + Args: + config: VisionEncoderConfig from MingFlashOmniModelConfig. + use_deepstack: Whether ``.forward()`` returns the per-checkpoint + deepstack feature lists. Off by default — the + LLM-side DeepStack splice lands with step 5 + (thinker graph walks for vision prefill). + dtype: Cast the encoder to this dtype after construction. + bf16 matches the released ckpt; fp16 also works. + device: Final device for the encoder weights. + attn_implementation: Maps to ``config._attn_implementation`` on the + internal Qwen3VLMoeVisionConfig. ``flash_attention_2`` + is mandatory for video performance — sdpa falls + into the per-segment Python loop (see qwen3_omni + model.py:1508-1519 for the same gotcha). + local_dir: Snapshot directory to add to sys.path if the Ming + source modules aren't already importable. + + Returns: + An ``nn.Module`` ready to consume ``(pixel_values, grid_thw)``. + Weight loading is the caller's job — Ming stores vision encoder + weights under the top-level ``vision.*`` prefix in the released + ckpt. + """ + Qwen3MoeVisionTransformer = _import_ming_vit(local_dir=local_dir) + + # Build the internal config the Ming module expects. 
+ module = sys.modules["qwen3_moe_vit"] + InternalConfig = module.Qwen3VLMoeVisionConfig + internal_config = InternalConfig( + depth=config.depth, + hidden_size=config.hidden_size, + hidden_act=config.hidden_act, + intermediate_size=config.intermediate_size, + num_heads=config.num_heads, + in_channels=config.in_channels, + patch_size=config.patch_size, + spatial_merge_size=config.spatial_merge_size, + temporal_patch_size=config.temporal_patch_size, + out_hidden_size=config.out_hidden_size, + num_position_embeddings=config.num_position_embeddings, + deepstack_visual_indexes=list(config.deepstack_visual_indexes), + ) + # The attention path branches on _attn_implementation. The Ming + # source hard-codes it to "flash_attention_2" inside __init__ of + # Qwen3VLMoeVisionAttention, but we set it on the config too for + # the rare debug path that wants to flip to "sdpa" or "eager". + internal_config._attn_implementation = attn_implementation + + encoder = Qwen3MoeVisionTransformer( + internal_config, + use_deepstack=use_deepstack, + ) + encoder = encoder.to(dtype=dtype, device=device) + encoder.eval() + return encoder + + +__all__ = ["build_vision_encoder"] diff --git a/mstar/model/ming_omni_flash/components/zimage_transformer.py b/mstar/model/ming_omni_flash/components/zimage_transformer.py new file mode 100644 index 00000000..1ada9a22 --- /dev/null +++ b/mstar/model/ming_omni_flash/components/zimage_transformer.py @@ -0,0 +1,654 @@ +"""ZImage DiT transformer for Ming-flash-omni-2.0 image generation (step 9b). + +Native mstar port of vllm-omni's ``z_image/z_image_transformer.py`` + +Ming's ``ming_zimage_transformer.py`` subclass. The upstream module is built +on vllm's tensor-parallel linears (``QKVParallelLinear`` / ``MergedColumn`` / +``RowParallel``), a custom fused ``Attention``, vllm's ``RotaryEmbedding``, +and ``CachedTransformer`` — none of which belong in the pure-torch mstar +modeling tree. 
This reimplementation: + + * uses plain ``nn.Linear`` with the **unfused** parameter names the released + checkpoint actually ships (``attention.to_q/to_k/to_v``, + ``feed_forward.w1/w3``), so the state dict loads with a direct ``copy_`` — + no stacked-param remap (same approach as the byt5 mapper port); + * reimplements the interleaved (GPT-J / ``is_neox_style=False``) RoPE that + vllm's ``RotaryEmbedding(is_neox_style=False)`` applies, the GLIDE/DiT + ``timestep_embedding``, and FP32 ``RMSNorm`` exactly; + * runs attention through ``F.scaled_dot_product_attention``. + +Architecture (released ckpt): dim=3840, 30 main layers + 2 noise-refiner + 2 +context-refiner blocks, 30 heads (head_dim=128), 16-channel latents, 3D axial +RoPE with axes_dims=(32,48,48) summing to the 128-wide head. Caption features +(byt5 + connector, 2560-dim) are embedded, refined, then concatenated with the +patch-embedded image tokens into one unified sequence for the main blocks. + +NOTE — attention masking divergence: vllm-omni *computes* the pad mask but +leaves it unapplied in attention ("we don't support multi prompts now"). This +port applies it (additive ``-inf`` on padded keys) so padded cap/image tokens +cannot leak into real positions. For the dominant batch-size-1 text-to-image +path with sequences already a multiple of ``SEQ_MULTI_OF`` the two are +numerically identical; they only diverge when caption padding is non-zero, +where applying the mask is the correct behavior. 
+""" + +from __future__ import annotations + +import math +from collections.abc import Iterable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +# ============================================================ +# Primitives (native equivalents of the vllm-omni helpers) +# ============================================================ + + +def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor: + """GLIDE/DiT sinusoidal timestep embedding (cos-then-sin, log-spaced). + + Mirrors ``vllm_omni...timestep_embedding`` byte-for-byte so the adaLN + conditioning matches the validated serving path. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class RMSNorm(nn.Module): + """FP32 RMSNorm with a learnable scale (matches vllm-omni's forward_native).""" + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + out = x * torch.rsqrt(variance + self.variance_epsilon) + out = self.weight.to(torch.float32) * out + return out.to(input_dtype) + + +def _rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor: + """GPT-J style rotate: (-x_odd, x_even) interleaved back together.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_emb_interleaved(x: torch.Tensor, cos: torch.Tensor, 
sin: torch.Tensor) -> torch.Tensor: + """Apply interleaved (is_neox_style=False) RoPE to ``[B, S, H, D]``. + + ``cos``/``sin`` are ``[B, S, D/2]`` (per-axis concatenated half-frequencies + from :class:`RopeEmbedder`); each entry is duplicated to the adjacent pair + to match the interleaved convention, broadcasting over the head axis. + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + # [B, S, D/2] -> [B, S, 1, D] with each freq duplicated to its pair. + cos_r = cos[..., None, :].repeat_interleave(2, dim=-1) + sin_r = sin[..., None, :].repeat_interleave(2, dim=-1) + x_rot = x[..., :ro_dim] + rotated = x_rot * cos_r + _rotate_half_interleaved(x_rot) * sin_r + if ro_dim < x.shape[-1]: + return torch.cat([rotated, x[..., ro_dim:]], dim=-1) + return rotated + + +class RopeEmbedder: + """Per-axis (3D axial) RoPE frequency table, matching vllm-omni's.""" + + def __init__( + self, + theta: float = 256.0, + axes_dims: tuple[int, ...] = (16, 56, 56), + axes_lens: tuple[int, ...] = (64, 128, 128), + ) -> None: + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must match" + self.cos_cached: list[torch.Tensor] | None = None + self.sin_cached: list[torch.Tensor] | None = None + + @staticmethod + def precompute_freqs(dim, end, theta: float = 256.0): + cos_list, sin_list = [], [] + for d, e in zip(dim, end, strict=True): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d)) + timestep = torch.arange(e, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + cos_list.append(torch.cos(freqs)) + sin_list.append(torch.sin(freqs)) + return cos_list, sin_list + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + if self.cos_cached is None: + self.cos_cached, self.sin_cached = self.precompute_freqs(self.axes_dims, self.axes_lens, theta=self.theta) + 
self.cos_cached = [c.to(device) for c in self.cos_cached] + self.sin_cached = [s.to(device) for s in self.sin_cached] + elif self.cos_cached[0].device != device: + self.cos_cached = [c.to(device) for c in self.cos_cached] + self.sin_cached = [s.to(device) for s in self.sin_cached] + + cos_result, sin_result = [], [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + cos_result.append(self.cos_cached[i][index]) + sin_result.append(self.sin_cached[i][index]) + return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1) + + +# ============================================================ +# Modules +# ============================================================ + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size: int, mid_size: int | None = None, frequency_embedding_size: int = 256) -> None: + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t: torch.Tensor) -> torch.Tensor: + t_freq = timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].bias.dtype + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + return self.mlp(t_freq) + + +class ZImageAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, eps: float = 1e-6) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + # Unfused projections — the checkpoint ships to_q/to_k/to_v separately. 
+ self.to_q = nn.Linear(dim, num_heads * self.head_dim, bias=False) + self.to_k = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False) + self.to_v = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False) + self.norm_q = RMSNorm(self.head_dim, eps=eps) + self.norm_k = RMSNorm(self.head_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)]) + self.scale = 1.0 / (self.head_dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + bsz, seqlen, _ = hidden_states.shape + query = self.to_q(hidden_states).unflatten(-1, (self.num_heads, self.head_dim)) + key = self.to_k(hidden_states).unflatten(-1, (self.num_kv_heads, self.head_dim)) + value = self.to_v(hidden_states).unflatten(-1, (self.num_kv_heads, self.head_dim)) + + query = self.norm_q(query) + key = self.norm_k(key) + + query = apply_rotary_emb_interleaved(query, cos, sin) + key = apply_rotary_emb_interleaved(key, cos, sin) + dtype = query.dtype + + # [B, S, H, D] -> [B, H, S, D] for SDPA. + q = query.transpose(1, 2) + k = key.transpose(1, 2).to(dtype) + v = value.transpose(1, 2).to(dtype) + + attn_bias = None + if attention_mask is not None: + # bool [B, S] keep-mask -> additive [B, 1, 1, S]. + if attention_mask.dtype == torch.bool: + attn_bias = torch.zeros(bsz, 1, 1, seqlen, dtype=dtype, device=q.device) + attn_bias = attn_bias.masked_fill(~attention_mask[:, None, None, :], float("-inf")) + else: + attn_bias = attention_mask + + enable_gqa = self.num_kv_heads != self.num_heads + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_bias, scale=self.scale, enable_gqa=enable_gqa + ) + out = out.transpose(1, 2).flatten(2, 3).to(dtype) + return self.to_out[0](out) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + # Unfused SwiGLU gate/up (checkpoint ships w1 + w3 separately). 
class ZImageTransformerBlock(nn.Module):
    """Attention + SwiGLU feed-forward block with optional adaLN modulation.

    Both sub-layers are normalised on the way in (``*_norm1``) and again on
    the way out (``*_norm2``) before being added to the residual stream. With
    ``modulation=True`` a single linear layer maps the conditioning vector to
    four ``dim``-wide chunks: two ``1 + scale`` factors applied to the
    sub-layer inputs and two ``tanh`` gates applied to the sub-layer outputs.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.layer_id = layer_id
        # The q/k norm epsilon is fixed at 1e-5 here, independent of ``norm_eps``.
        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, eps=1e-5)
        self.feed_forward = FeedForward(dim, hidden_dim=int(dim / 3 * 8))
        # Registration order is kept as-is so parameter ordering does not change.
        for norm_name in ("attention_norm1", "ffn_norm1", "attention_norm2", "ffn_norm2"):
            setattr(self, norm_name, RMSNorm(dim, eps=norm_eps))
        self.modulation = modulation
        if modulation:
            cond_width = min(dim, ADALN_EMBED_DIM)
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(cond_width, 4 * dim, bias=True),
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None,
        cos: torch.Tensor,
        sin: torch.Tensor,
        adaln_input: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block on ``x``; ``adaln_input`` is required when modulation is enabled."""
        if not self.modulation:
            # Unconditioned path: plain sandwich-normed residual updates.
            attended = self.attention(self.attention_norm1(x), attn_mask, cos, sin)
            x = x + self.attention_norm2(attended)
            return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        assert adaln_input is not None
        # unsqueeze(1) lets the per-sample modulation broadcast over the sequence axis.
        modulation = self.adaLN_modulation(adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        attended = self.attention(self.attention_norm1(x) * scale_msa, attn_mask, cos, sin)
        x = x + gate_msa * self.attention_norm2(attended)
        ff_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        return x + gate_mlp * self.ffn_norm2(ff_out)
    def __init__(
        self,
        all_patch_size: tuple[int, ...] = (2,),
        all_f_patch_size: tuple[int, ...] = (1,),
        in_channels: int = 16,
        dim: int = 3840,
        n_layers: int = 30,
        n_refiner_layers: int = 2,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        cap_feat_dim: int = 2560,
        rope_theta: float = 256.0,
        t_scale: float = 1000.0,
        axes_dims: tuple[int, ...] = (32, 48, 48),
        axes_lens: tuple[int, ...] = (1024, 512, 512),
    ) -> None:
        """Build the embedders, refiner stacks, main stack and RoPE tables.

        Args:
            all_patch_size: Spatial patch sizes to support; paired element-wise
                with ``all_f_patch_size`` (lengths must match).
            all_f_patch_size: Frame-axis patch sizes, one per spatial patch size.
            in_channels: Latent channels; also used as ``out_channels``.
            dim: Hidden width of every block.
            n_layers: Number of main (modulated) blocks.
            n_refiner_layers: Number of blocks in each of the noise and
                context refiner stacks.
            n_heads: Query heads per block.
            n_kv_heads: Key/value heads per block.
            norm_eps: Epsilon for the block RMSNorms and the caption RMSNorm.
            qk_norm: Accepted for signature compatibility; not read in this
                constructor.
            cap_feat_dim: Width of the incoming caption features.
            rope_theta: Base for the RoPE frequency tables.
            t_scale: Stored on the module as ``self.t_scale``.
            axes_dims: Per-axis rotary widths passed to ``RopeEmbedder``.
            axes_lens: Per-axis table lengths passed to ``RopeEmbedder``.
        """
        super().__init__()
        assert len(all_patch_size) == len(all_f_patch_size)
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.all_patch_size = tuple(all_patch_size)
        self.all_f_patch_size = tuple(all_f_patch_size)
        self.dim = dim
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.t_scale = t_scale

        # One patch embedder and one output head per (patch, frame-patch) pair,
        # keyed by the string "<patch_size>-<f_patch_size>".
        all_x_embedder = {}
        all_final_layer = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = nn.Linear(
                f_patch_size * patch_size * patch_size * in_channels, dim, bias=True
            )
            all_final_layer[f"{patch_size}-{f_patch_size}"] = FinalLayer(
                dim, patch_size * patch_size * f_patch_size * self.out_channels
            )
        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.all_final_layer = nn.ModuleDict(all_final_layer)

        # Noise refiner: modulated blocks; layer ids are offset by 1000.
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, modulation=True)
                for i in range(n_refiner_layers)
            ]
        )
        # Context refiner: same block type but without adaLN modulation.
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, modulation=False)
                for i in range(n_refiner_layers)
            ]
        )
        # Timestep conditioning width is capped at ADALN_EMBED_DIM.
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )
        # Allocated with torch.empty, i.e. uninitialised here.
        # NOTE(review): presumably always filled from the checkpoint — confirm callers
        # assert that load_weights covered both tokens.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.layers = nn.ModuleList(
            [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, modulation=True) for i in range(n_layers)]
        )
        self.axes_dims = tuple(axes_dims)
        self.axes_lens = tuple(axes_lens)
        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=self.axes_dims, axes_lens=self.axes_lens)
list[torch.Tensor], size: list[tuple], patch_size: int, f_patch_size: int): + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + Fr, H, W = size[i] + ori_len = (Fr // pF) * (H // pH) * (W // pW) + x[i] = ( + x[i][:ori_len] + .view(Fr // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, Fr, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + axes = [ + torch.arange(x0, x0 + span, dtype=torch.int32, device=device) + for x0, span in zip(start, size, strict=True) + ] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: list[torch.Tensor], + all_cap_feats: list[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], [] + all_cap_pos_ids, all_cap_pad_mask, all_cap_feats_out = [], [], [] + + for image, cap_feat in zip(all_image, all_cap_feats, strict=True): + # ---- Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), start=(1, 0, 0), device=device + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ] + ) + ) + all_cap_feats_out.append(torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)) + + # ---- Image + C, Fr, H, W = image.size() + all_image_size.append((Fr, H, W)) + F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, 
W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + all_image_pos_ids.append(torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)) + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ] + ) + ) + all_image_out.append(torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def _unified_prepare(self, x, x_cos, x_sin, cap_feats, cap_cos, cap_sin, x_item_seqlens, cap_item_seqlens): + bsz = x.shape[0] + device = x.device + unified, unified_cos, unified_sin = [], [], [] + for i in range(bsz): + x_len, cap_len = x_item_seqlens[i], cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_cos.append(torch.cat([x_cos[i][:x_len], cap_cos[i][:cap_len]])) + unified_sin.append(torch.cat([x_sin[i][:x_len], cap_sin[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=True)] + unified_max = max(unified_item_seqlens) + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_cos = pad_sequence(unified_cos, batch_first=True, padding_value=0.0) + unified_sin = pad_sequence(unified_sin, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max), 
dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + return unified, unified_cos, unified_sin, unified_attn_mask + + def forward( + self, + x: list[torch.Tensor], + t: torch.Tensor, + cap_feats: list[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # ---- x embed + noise refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max = max(x_item_seqlens) + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + adaln_input = t.type_as(x) + x_pad_mask = torch.cat(x_inner_pad_mask) + x = torch.where(x_pad_mask.unsqueeze(1).expand_as(x), self.x_pad_token.expand(x.shape[0], -1), x) + x = list(x.split(x_item_seqlens, dim=0)) + x_cos, x_sin = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) + x_cos = list(x_cos.split(x_item_seqlens, dim=0)) + x_sin = list(x_sin.split(x_item_seqlens, dim=0)) + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_cos = pad_sequence(x_cos, batch_first=True, padding_value=0.0) + x_sin = pad_sequence(x_sin, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_cos, x_sin, adaln_input) + + # ---- cap embed + context refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max = max(cap_item_seqlens) + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = 
self.cap_embedder(cap_feats) + cap_pad_mask = torch.cat(cap_inner_pad_mask) + cap_feats = torch.where( + cap_pad_mask.unsqueeze(1).expand_as(cap_feats), + self.cap_pad_token.expand(cap_feats.shape[0], -1), + cap_feats, + ) + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_cos, cap_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)) + cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0)) + cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0)) + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_cos = pad_sequence(cap_cos, batch_first=True, padding_value=0.0) + cap_sin = pad_sequence(cap_sin, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_cos, cap_sin) + + # ---- unify + main blocks + unified, unified_cos, unified_sin, unified_attn_mask = self._unified_prepare( + x, x_cos, x_sin, cap_feats, cap_cos, cap_sin, x_item_seqlens, cap_item_seqlens + ) + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_cos, unified_sin, adaln_input) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + return self.unpatchify(unified, x_size, patch_size, f_patch_size), {} + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Direct state-dict load — our unfused layout matches the checkpoint. + + Unlike vllm-omni (which fuses to_qkv / w13 and remaps), we keep + to_q/to_k/to_v + w1/w3 separate, so the released DiT weights copy in + verbatim. Returns the set of param names covered so callers can assert + completeness. 
class MingZImageTransformer2DModel(ZImageTransformer2DModel):
    """ZImage DiT with Ming's reference-latent (img2img) support.

    When ``ref_latent`` is given, the reference latent is appended to every
    input image along the frame axis (dim 1) before the base ``forward`` runs,
    and the reference frame is dropped again from the unpatchified output so
    callers only receive the first frame.

    ``ref_latent`` is an explicit ``forward`` argument; the upstream
    implementation reads it from a global forward-context instead.
    """

    # Class-level default so ``unpatchify`` is safe to call before any ``forward``.
    _dropping_ref: bool = False

    def forward(
        self,
        x: list[torch.Tensor],
        t: torch.Tensor,
        cap_feats: list[torch.Tensor],
        patch_size: int = 2,
        f_patch_size: int = 1,
        ref_latent: list[torch.Tensor] | None = None,
    ):
        """Run the DiT, optionally conditioning on a reference latent.

        Args:
            x: Per-item latents, each ``[C, F, H, W]`` (concatenation happens on dim 1).
            t: Timesteps, forwarded unchanged to the base model.
            cap_feats: Per-item caption features, forwarded unchanged.
            patch_size: Spatial patch size, forwarded unchanged.
            f_patch_size: Frame patch size, forwarded unchanged.
            ref_latent: Optional reference latent(s). ``ref_latent[0]`` is
                unsqueezed on dim 1 and cast to the dtype/device of ``x[0]``.

        Returns:
            Whatever the base ``forward`` returns; with ``ref_latent`` set, each
            output tensor keeps only its first frame.
        """
        if ref_latent is None:
            self._dropping_ref = False
            return super().forward(x, t, cap_feats, patch_size=patch_size, f_patch_size=f_patch_size)

        # NOTE(review): only ref_latent[0] is used and it is shared by every item in
        # ``x``. That looks intentional for batch size 1 — confirm whether batches
        # with per-item references are expected to pair ref_latent[i] with x[i].
        per_item = ref_latent[0].unsqueeze(1).to(dtype=x[0].dtype, device=x[0].device)  # [C, 1, H, W]
        x = [torch.cat([img, per_item], dim=1) for img in x]

        # The flag is only meaningful while the base forward is running (it calls
        # ``self.unpatchify`` at the end). Clear it afterwards — also on error — so
        # it cannot leak into a later direct ``unpatchify`` call.
        self._dropping_ref = True
        try:
            return super().forward(x, t, cap_feats, patch_size=patch_size, f_patch_size=f_patch_size)
        finally:
            self._dropping_ref = False

    def unpatchify(self, x, size, patch_size, f_patch_size):
        """Unpatchify via the base class, then drop the reference frame if one was added."""
        out = super().unpatchify(x, size, patch_size, f_patch_size)
        if self._dropping_ref:
            # Keep the first frame only; the appended reference frame is discarded.
            return [t[:, :1, :, :] for t in out]
        return out
+""" + +from __future__ import annotations + +import json +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Thinker LLM (Ling-2.0 sparse MoE — model_type "bailing_moe_v2") +# --------------------------------------------------------------------------- + +@dataclass +class ThinkerLLMConfig: + """Ling-2.0 sparse-MoE thinker (BailingMoeV2). + + Field set is the union of what upstream + ``vllm_omni/transformers_utils/configs/ming_flash_omni.py:BailingMoeV2Config`` + declares and what the released ``llm_config`` actually populates. + Defaults reflect the released ckpt, not the upstream class defaults + (which were trained for a smaller config). + """ + + # Dims + vocab_size: int = 157184 + hidden_size: int = 4096 + intermediate_size: int = 9216 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int = 4 + head_dim: int | None = None # computed in __post_init__ + + # Norm / activation + hidden_act: str = "silu" + rms_norm_eps: float = 1e-6 + use_qk_norm: bool = True + use_qkv_bias: bool = False + use_bias: bool = False + tie_word_embeddings: bool = False + + # Position / RoPE + max_position_embeddings: int = 32768 + rope_theta: float = 2_400_000.0 + rope_scaling: dict[str, Any] | None = None + partial_rotary_factor: float = 0.5 + + # MoE + num_experts: int = 256 + num_shared_experts: int = 1 + num_experts_per_tok: int = 8 + moe_intermediate_size: int = 1024 + first_k_dense_replace: int = 1 + router_type: str = "MultiRouter" + n_group: int = 8 + topk_group: int = 4 + moe_router_topk_scaling_factor: float = 2.5 + norm_topk_prob: bool = True + use_expert_bias: bool = True + output_router_logits: bool = False + + # Misc + pad_token_id: int = 156892 + eos_token_id: int = 156895 + use_interleaved_frame_timestamp: bool = True + + # Multimodal token IDs 
(used by the prefill processor / chat template). + # Defaults mirror the actual tokenizer (`tokenizer.json` added_tokens at + # the released ckpt; cross-checked against Jonathan1909's patched config + # and vllm-omni's BailingMoeV2Config defaults). Two gotchas the on-disk + # `config.json` of `inclusionAI/Ming-flash-omni-2.0` introduces: + # * `video_start_token` is mislabeled as 157159 (= ) in the + # ckpt config; the real `