diff --git a/benchmark/base.py b/benchmark/base.py index 9e35badd..bba02904 100644 --- a/benchmark/base.py +++ b/benchmark/base.py @@ -214,6 +214,78 @@ def get_supported_modalities(self): } +class MingFlashOmni(Model): + """Ming-flash-omni-2.0 (inclusionAI), the Ling-2.0 sparse-MoE omni model + (100B total / 6B active params) released 2026-02-11. + + Reachable today via the vllm-omni server using + ``vllm_omni/deploy/ming_flash_omni.yaml`` (thinker+talker) or + ``ming_flash_omni_thinker_only.yaml`` (text-only). The native ``ours`` / + ``ours_openai`` backends will work once the mminf-side port under + ``mminf/model/ming_omni_flash/`` is finished — until then, point the + benchmark at a vllm-omni instance with ``--inference-system vllm_omni``. + + Wire shape mirrors :class:`Qwen3Omni`: standard OpenAI + ``/v1/chat/completions`` with multimodal content parts. The role remap + from OpenAI's ``user``/``assistant``/``system`` to Ming's internal + ``HUMAN``/``ASSISTANT``/``SYSTEM`` happens inside the jinja chat_template + shipped in ``tokenizer_config.json`` — vllm-omni renders prompts via + ``tokenizer.apply_chat_template`` which uses that jinja, so the benchmark + sends the standard OpenAI shape unchanged. + + Caveat: Ming ALSO ships a Python-side ``BailingMM2Processor.apply_chat_template`` + (in the Ming source repo) that is strict about uppercase roles and would + AssertionError on ``user``/``assistant``. mminf's native port uses that + processor for full multimodal preprocessing (vision/audio feature + extraction) and remaps roles in ``process_prompt`` accordingly — see + ``mminf/model/ming_omni_flash/`` and its tokenizer tests. + """ + + def get_hf_url(self): + return "inclusionAI/Ming-flash-omni-2.0" + + def get_openai_system_message(self) -> Optional[dict]: + # Ming-flash-omni-2.0's cookbook uses ``sys_prompt_exp=None`` and + # ``use_cot_system_prompt=False`` by default — there's no required + # "You are Ming…"-style preamble equivalent to Qwen3-Omni's. 
The HF + # processor's chat_template fills in any internal system text on its + # own, and vllm-omni's serving layer goes through that template via + # ``trust_remote_code``. Sending an explicit system message here only + # risks overriding the model's own defaults, so default to None. + return None + + def get_model_kwargs(self, request_type: RequestType): + # Cap thinker output at 256 tokens for cross-system fairness — same + # rationale as Qwen3Omni: comparable runs need a fixed decode budget. + # vllm-omni's released stage default is ``max_tokens: 2048`` (see + # ``vllm_omni/deploy/ming_flash_omni.yaml`` stage 0); we lower it for + # benchmark parity. Send both ``max_tokens`` (OpenAI convention) and + # ``max_output_tokens`` (mminf's native kwarg) so the cap survives + # whichever ``--inference-system`` is in use. + # + # Force greedy on the thinker (``temperature=0.0`` at payload top-level + # in VLLMOmni.send_request) for deterministic text. The talker's + # sampling defaults live server-side in the deploy yaml + # (``stage_id: 1`` → ``temperature: 0.0`` per the released config) — + # we don't override them here. + return { + "max_tokens": 256, + "max_output_tokens": 256, + } + + def get_supported_modalities(self): + return { + RequestType.T2T, + RequestType.T2S, + RequestType.I2T, + RequestType.I2S, + RequestType.A2T, + RequestType.A2S, + RequestType.V2T, + RequestType.V2S, + } + + class Pi05(Model): """Physical Intelligence Pi0.5 VLA model. 
@@ -268,6 +340,7 @@ class ModelType(Enum): BAGEL = "bagel" ORPHEUS = "orpheus" QWEN3OMNI = "qwen3omni" + MING_FLASH_OMNI = "ming_flash_omni" PI05 = "pi05" VJEPA2AC = "vjepa2ac" @@ -278,6 +351,8 @@ def inst(self, **kwargs) -> Model: return Orpheus(**kwargs) if self == ModelType.QWEN3OMNI: return Qwen3Omni(**kwargs) + if self == ModelType.MING_FLASH_OMNI: + return MingFlashOmni(**kwargs) if self == ModelType.PI05: return Pi05(**kwargs) if self == ModelType.VJEPA2AC: diff --git a/benchmark/vllm_omni_instructions.md b/benchmark/vllm_omni_instructions.md index 2934c6c9..3e534544 100644 --- a/benchmark/vllm_omni_instructions.md +++ b/benchmark/vllm_omni_instructions.md @@ -21,4 +21,93 @@ CUDA_VISIBLE_DEVICES=3 vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8000 ### for qwen3-omni: ``` vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml -``` \ No newline at end of file +``` + +### for ming-flash-omni-2.0: + +The released `inclusionAI/Ming-flash-omni-2.0` ckpt (~238 GB / 42 shards) +does NOT load cleanly into vllm-omni's `MingFlashOmniForConditionalGeneration` +class as-is. Two patches are needed (one-time setup): + +1. **Replace metadata files.** vllm-omni's model class uses + `Qwen2VLImageProcessor` + `MingWhisperFeatureExtractor` (its own + registered classes), while the inclusionAI snapshot declares the + `BailingMM2*` processor variants via `auto_map` and `trust_remote_code`. + Use `Jonathan1909/Ming-flash-omni-2.0`'s `preprocessor_config.json`, + `config.json` (auto_map stripped), and `tokenizer*.json` instead. + +2. **Replace the talker weights.** vllm-omni's `MingFlashOmniTalker` expects + weights under `audio_vae.*` but the inclusionAI talker safetensors uses + `audio.*` prefix. Jonathan1909 reshipped the talker with renamed weights + (~1.5 GB). + +Building a hybrid snapshot avoids re-downloading the 200+ GB thinker weights: + +```bash +# 1. 
Make sure the inclusionAI thinker shards are cached +huggingface-cli download inclusionAI/Ming-flash-omni-2.0 \ + --include="model-*.safetensors" --include="model.safetensors.index.json" + +# 2. Pull only Jonathan1909's metadata + talker (no thinker weights) +huggingface-cli download Jonathan1909/Ming-flash-omni-2.0 \ + --include="*.json" --include="*.py" --include="*.txt" --include="*.mvn" \ + --include="talker/**" \ + --cache-dir /dev/shm/hf-cache # or any path with ~3 GB free + +# 3. Stitch the two together +INCL=$(huggingface-cli scan-cache | grep inclusionAI/Ming-flash-omni-2.0 \ + | awk '{print $NF}')/snapshots/$(ls ~/.cache/huggingface/hub/models--inclusionAI--Ming-flash-omni-2.0/snapshots | head -1) +JONA=/dev/shm/hf-cache/models--Jonathan1909--Ming-flash-omni-2.0/snapshots/* +HYBRID=/dev/shm/ming-hybrid +mkdir -p $HYBRID +for f in $INCL/model-*.safetensors; do ln -s "$f" "$HYBRID/$(basename $f)"; done +for f in $JONA/*; do + base=$(basename "$f") + [ -L "$HYBRID/$base" ] && rm "$HYBRID/$base" + ln -s "$f" "$HYBRID/$base" +done +``` + +Then serve and benchmark: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve /dev/shm/ming-hybrid \ + --omni --port 8091 --host 0.0.0.0 --trust-remote-code \ + --stage-configs-path /tmp/vllm-omni/vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml + +# Wait for "Application startup complete" then: +MODEL=ming_flash_omni INF_SYS=vllm_omni TASK=text_to_text \ + URL=http://0.0.0.0:8091 ./benchmark/run_benchmark.sh +``` + +NOTE: vllm-omni's `/v1/chat/completions` rejects unknown model ids, so the +client must send `"model": "/dev/shm/ming-hybrid"` (the served path), not +`"inclusionAI/Ming-flash-omni-2.0"`. 
Easiest is to monkey-patch +`MingFlashOmni.get_hf_url` before calling the benchmark runner: + +```python +from benchmark.base import MingFlashOmni +MingFlashOmni.get_hf_url = lambda self: "/dev/shm/ming-hybrid" +``` + +Or pass `--served-model-name inclusionAI/Ming-flash-omni-2.0` to `vllm serve` +(untested; would also work in principle). + +#### Modalities exercised on a local 4×H100 run (2026-06-06) + +| Task | Status | Notes | +|---|---|---| +| T2T (text → text) | ✅ | offline B=1: 110 tok/s, closed-loop C=32: **1060 tok/s** (full scaling sweep in [`results/ming_t2t_sweep/SUMMARY.md`](../results/ming_t2t_sweep/SUMMARY.md)) | +| I2T (image → text) | ✅ | TTFT 87 ms, ~100 tok/s on Food101 | +| A2T (audio → text) | ✅ | English transcription + Chinese audio QA both work | +| T2S (text → speech) | ✅ | RTF 0.14, 24 kHz mono PCM via harness; 44.1 kHz via direct OpenAI path | +| V2T (video → text) | ✅ | Local Ming demo mp4s; coherent descriptions (`yoga.mp4` → yoga pose narration, `cup_change.mp4` → "shell game") | +| V2S (video → speech) | ✅ | Local Ming demo mp4s; 2-3 MB WAV/clip @ 44.1 kHz | +| I2S (image → speech) | ✅ | Food101 in, ~7 s/req for ~48 s of audio | +| A2S (audio → speech) | ✅ | Ming sample wavs; 0.5-3 MB WAV/clip @ 44.1 kHz | +| T2I / I2I (image gen) | not wired | requires `ming_flash_omni_image.yaml` + a benchmark wrapper similar to BAGEL's `/v1/images/generations` path | + +The V2T/V2S/A2S runs sidestep the bench harness's `UCF101Dataset` and +`LibriSpeechDataset` (both want fresh HF-Hub downloads) by hitting +`/v1/chat/completions` directly with base64-inlined media from local files +(Ming repo's `figures/cases/*.mp4` and `data/wavs/*.wav`). \ No newline at end of file diff --git a/configs/ming_flash_omni.yaml b/configs/ming_flash_omni.yaml new file mode 100644 index 00000000..d3b2fe8c --- /dev/null +++ b/configs/ming_flash_omni.yaml @@ -0,0 +1,31 @@ +# Ming-flash-omni-2.0 — thinker + talker + audio VAE. 
+# +# WIP: the native mminf model port at mminf/model/ming_omni_flash/ is a +# scaffold (every abstractmethod raises NotImplementedError), so +# `mminf-serve --config configs/ming_flash_omni.yaml` will fail at startup +# until that port lands. Until then, benchmark Ming-flash-omni-2.0 via the +# vllm-omni server (see benchmark/vllm_omni_instructions.md). +# +# Target topology mirrors vllm-omni/deploy/ming_flash_omni.yaml: +# * Thinker (Ling-2.0 sparse MoE LLM, the multimodal understanding core) +# wants TP=4 across GPUs 0-3. +# * Talker (CFM-based audio generator) colocates on GPU 3. +# * Audio VAE (codec -> waveform) and stateless encoders (vision / audio) +# can ride on rank 0. +# +# Node names below are the placeholders the scaffold will reference; rename +# in lockstep with mminf/model/ming_omni_flash/ming_omni_flash_model.py once +# the graph walks are implemented. + +model: "ming_flash_omni" +max_seq_len: 32768 +node_groups: + - node_names: [audio_encoder, vision_encoder, AudioVAE] + ranks: [0] + + - node_names: [Thinker] + ranks: [0, 1, 2, 3] + tp_size: 4 + + - node_names: [Talker] + ranks: [3] diff --git a/configs/ming_flash_omni_thinker_only.yaml b/configs/ming_flash_omni_thinker_only.yaml new file mode 100644 index 00000000..8036af8e --- /dev/null +++ b/configs/ming_flash_omni_thinker_only.yaml @@ -0,0 +1,21 @@ +# Ming-flash-omni-2.0 — thinker-only deploy (text out, no talker). +# +# TP=8 across 8 H100s. Per-rank shard_inter = 1024/8 = 128; +# experts.gate_up_proj is (256, 2*128, 4096) per rank, ~33 GB across +# 31 MoE layers. With embed + lm_head + attention + dense layer 0 + +# KV cache, ~40 GB per rank fits the 80 GB H100s comfortably. +# +# TP=4 OOMs at ~78.5 / 80 GB per rank even with +# PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (re-verified +# 2026-06-08; loader streaming overhead pushes past the 80 GB limit). +# TP=8 halves the model footprint with plenty of headroom. 
+# +# Audio / vision / talker / image-gen are step 4+; this config is for +# text-only T2T benchmarking and the first mminf-served Ming forward. + +model: "ming_flash_omni" +max_seq_len: 32768 +node_groups: + - node_names: [Thinker] + ranks: [0, 1, 2, 3, 4, 5, 6, 7] + tp_size: 8 diff --git a/mminf/model/base.py b/mminf/model/base.py index a127f68f..71088183 100644 --- a/mminf/model/base.py +++ b/mminf/model/base.py @@ -253,19 +253,29 @@ def get_worker_graphs(self, config_path: str) -> list[WorkerGraph]: if node_groups is None: raise KeyError("Config must define `node_groups`.") + # Nodes this deploy actually provides. A graph walk referencing a + # node absent from node_groups (e.g. the encoder / talker walks in + # a thinker-only deploy) is skipped rather than KeyError'ing during + # worker-graph division — that deploy simply can't serve the walk. + available_nodes: set[str] = set() + for group in node_groups: + available_nodes.update(group["node_names"]) + # TODO: merge identical worker graphs from different graph walks - return sum( - [ + worker_graphs: list[WorkerGraph] = [] + for graph_walk, graph in self.get_graph_walk_graphs().items(): + required = set(graph.get_nodes().keys()) + if not required <= available_nodes: + continue + worker_graphs.extend( self._get_worker_graphs_for_graph_walk(graph_walk, graph, node_groups) - for graph_walk, graph in self.get_graph_walk_graphs().items() - ], - start=[], - ) - + ) + return worker_graphs + def get_sharding_config(self, config_path: str) -> ShardingConfig: with open(config_path, "r") as f: config = yaml.safe_load(f) - + sharding_config = self.get_default_sharding_config() # Derive sharding groups from node_groups with tp_size > 1. 
The diff --git a/mminf/model/ming_omni_flash/PORTING_NOTES.md b/mminf/model/ming_omni_flash/PORTING_NOTES.md new file mode 100644 index 00000000..0e6a7d29 --- /dev/null +++ b/mminf/model/ming_omni_flash/PORTING_NOTES.md @@ -0,0 +1,538 @@ +# Ming-flash-omni-2.0 — porting notes + +Native mminf port of `inclusionAI/Ming-flash-omni-2.0`. This directory is a +scaffold today; everything below is the punch list to make it real. + +## Status + +- `benchmark/base.py` has `MingFlashOmni` + `ModelType.MING_FLASH_OMNI`. + Benchmarking against a vllm-omni server **works today** with + `--inference-system vllm_omni` (see `benchmark/vllm_omni_instructions.md`). +- Step 1 (config port) — DONE. `mminf/model/ming_omni_flash/config.py` + loads the released ckpt; 10 tests in `test/modular/test_ming_flash_omni_config.py`. +- Step 2 (tokenizer + processor wiring) — DONE. + `MingFlashOmniModel.__init__` resolves the snapshot, stages Ming source + files (see "Ming source dependency" below), and loads + `BailingTokenizer` + `BailingMM2Processor` with graceful fallback; + 11 tests in `test/modular/test_ming_flash_omni_tokenizer.py`. +- Everything else in `MingFlashOmniModel` still raises `NotImplementedError` + — `mminf-serve --config configs/ming_flash_omni.yaml` will fail at + startup until step 3+ lands. + +## Ming source dependency (loading the tokenizer/processor) + +The released HF checkpoint `inclusionAI/Ming-flash-omni-2.0` ships +**only weights and sub-dir configs**. The tokenizer/processor Python +modules (`configuration_bailingmm2.py`, `tokenization_bailing.py`, +`processing_bailingmm2.py`, etc.) live in the source repo at +https://github.com/inclusionAI/Ming . To load the tokenizer/processor: + +```bash +# 1. Clone the source repo +git clone https://github.com/inclusionAI/Ming.git /path/to/Ming + +# 2. Install extra Python deps Ming's modules depend on +pip install opencv-python-headless openai-whisper + +# 3. 
Tell mminf where to find the source repo +export MING_CODE_DIR=/path/to/Ming +# (or pass ming_code_dir="/path/to/Ming" to MingFlashOmniModel) +``` + +`MingFlashOmniModel.__init__` (via `_prepare_tokenizer_dir`) symlinks +the required .py and .json files from `$MING_CODE_DIR` alongside the +snapshot's `config.json` so transformers' `trust_remote_code` machinery +can resolve them. The snapshot dir is also pushed onto `sys.path` so +the dynamic-module loader's sibling imports resolve. + +## Role-handling nuance (chat templates) + +Ming-flash-omni-2.0 ships **two** chat-template implementations with +**different role conventions**: + +- `tokenizer.apply_chat_template(messages)` — uses the **jinja template + in `tokenizer_config.json`**. Accepts standard OpenAI roles + (`user` / `assistant` / `system`) and remaps them to Ming's uppercase + `HUMAN` / `ASSISTANT` / `SYSTEM` inside the template. This is the path + vllm-omni's serving layer uses → the benchmark side works unchanged. + +- `processor.apply_chat_template(messages, sys_prompt_exp=..., use_cot_system_prompt=...)` + — uses the **Python implementation in `BailingMM2Processor`** (Ming + source repo). **Strict**: asserts `role in [HUMAN, ASSISTANT]` and + raises `AssertionError` on lowercase OpenAI roles. The native mminf + `process_prompt` (step 7) does not use this path for templating: it applies the jinja `tokenizer.apply_chat_template` to text-only messages and runs the image / video / audio + sub-processors separately for the multimodal preprocessing (vision feature extraction, audio padding, etc.). Any caller that does go through `processor.apply_chat_template` + must explicitly remap roles before calling. + +## Upstream reference + +Treat the vllm-omni port as the source of truth for architecture.
Files to +read (totals ~6.5 KLOC): + +| Concern | vllm-omni file | +|---|---| +| Pipeline glue | `vllm_omni/model_executor/models/ming_flash_omni/pipeline.py` (141 LOC) | +| Top-level model | `ming_flash_omni.py` (255 LOC) | +| Thinker (Ling-2.0 MoE + multimodal) | `ming_flash_omni_thinker.py` (1,164 LOC) | +| Talker (CFM + LLM) | `ming_flash_omni_talker.py` (586) + `talker_module.py` (1,145) | +| Audio VAE | `audio_vae.py` (392) | +| Audio encoder | `audio_encoder.py` (246) | +| Vision encoder | `vision_encoder.py` (125) + `projectors.py` (184) | +| Ling MoE backbone | `modeling_bailing_moe_v2.py` (892) | +| Prompt utils | `prompt_utils.py` (134) — `IMAGE_PATCH_TOKEN`, `DEFAULT_NUM_QUERY_TOKENS=256`, TTS caption template | +| Text processing | `text_processing.py` (535) | +| Speaker presets | `spk_embedding.py` (44) + `voice_presets.py` (289) | +| Config | `vllm_omni/transformers_utils/configs/ming_flash_omni.py` (420) | +| Stage input processor | `vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py` | +| ImageGen pipeline | `vllm_omni/diffusion/models/ming_flash_omni/` | +| Deploy yamls | `vllm_omni/deploy/ming_flash_omni{,_image,_thinker_only,_tts}.yaml` | + +## mminf parallels + +Mirror the structure of `mminf/model/qwen3_omni/` end-to-end. That model is +the closest analog (multimodal thinker + speech talker + vocoder), and the +graph-walk / partition / streaming patterns transfer 1:1. 
+ +| mminf surface | Qwen3-Omni reference | Ming-flash-omni equivalent | +|---|---|---| +| Model class | `qwen3_omni_model.py` (1,529) | `ming_omni_flash_model.py` | +| Submodules | `submodules.py` (2,016) | `submodules.py` (TODO) | +| Config | `config.py` (544) | `config.py` | +| Talker | `components/talker.py` (549) + `code2wav.py` (534) | `components/talker.py` + `audio_vae.py` (TODO) | +| Thinker | `components/thinker.py` (259) | `components/thinker.py` (TODO) | +| Attention / RoPE | `components/attention.py` + `rope.py` | likely shareable; check Ling-2.0 attention shape | + +## Punch list (in order) + +1. **Config port — DONE.** `mminf/model/ming_omni_flash/config.py` + loads `config.json` + sibling subdir configs (talker / image-gen) into + a dataclass tree. Verified via 10 tests in + `test/modular/test_ming_flash_omni_config.py`. + +2. **Tokenizer + processor — DONE.** `MingFlashOmniModel.__init__` + resolves the snapshot, stages Ming source files alongside it (see + "Ming source dependency" above), and loads `BailingTokenizer` + + `BailingMM2Processor` with graceful fallback. The chat-template role + handling has two paths (see "Role-handling nuance" above); the native + `process_prompt` (step 7) will use the strict processor path and must + remap roles. Verified via 11 tests in + `test/modular/test_ming_flash_omni_tokenizer.py`. + +3. **Ling-2.0 thinker LLM port — IN PROGRESS.** + - **3a — DONE** (`components/router.py`, `rope.py`, `attention.py`): + architecture-novel pieces (MultiRouter group-limited top-k, partial + 3D `video_rope`, QK-norm attention). 12 tests in + `test/modular/test_ming_flash_omni_components.py`. + - **3b — DONE** (`components/moe.py`, `decoder_layer.py`, `model.py`): + `LingMoeBlock` (3-router text/image/audio with `torch.where` + per-token swap), `LingDecoderLayer` (hybrid dense/MoE per + `first_k_dense_replace`), full `LingMoeModel` (embed + N layers + + RMSNorm + lm_head). 9 tests in `test_ming_flash_omni_model.py`. 
+ - **3c — DONE** (`loader.py`): weight loader that maps the released + ckpt's `model.model.*` namespace to `LingMoeModel`'s state_dict, + with per-expert gate/up/down fusion into the packed + `experts.gate_up_proj` tensor via mminf's existing + `WeightConverter` machinery. Real-ckpt smoke test loads embed + + dense layer 0 + lm_head from the released shards and runs a + forward — output is finite bf16 logits at the expected + `(T, vocab_size)` shape. 6 tests in + `test_ming_flash_omni_loader.py` (4 pure-Python + 2 CUDA+snapshot). + - **3e — DONE** (TP-aware variants): `LingAttention` uses + `QKVParallelLinear` + `RowParallelLinear` (per-rank heads + dense + row-parallel); `LingMoeBlock` shards fused experts by + `shard_inter = moe_intermediate_size / tp_size` and uses mminf's + existing `_gate_up_weight_loader` / `_down_proj_weight_loader` + for per-rank weight slicing; dense layer-0 MLP uses + `ParallelGatedMLP`; `LingMoeModel` threads `comm_group` through + every decoder layer. Weight loader refactored onto mminf's + `load_hf_weights` + 770 `StackedParamRule`s (3 per expert × + num_experts + dense MLP + synthetic QKV). The packed + `attention.query_key_value.weight` from the checkpoint is split + into synthetic `q_proj` / `k_proj` / `v_proj` keys by + `_split_packed_qkv` so `QKVParallelLinear`'s standard weight + loader handles per-rank head slicing. + + **Verified via TP=8 mminf-serve smoke** (8 H100s): server starts, + all 8 workers load 507 thinker params each (one per packed + parameter; per-rank ~40 GB), KVCacheEngine warmup_and_capture + completes, torch.compile applies, dedicated GPU threads spin up, + port 8092 listens. Per-rank model + KV cache is well under 80 GB. + TP=4 was tried first and OOMed at 78.58 GB / 80 GB; TP=8 has + plenty of headroom. + + **Known gap (resolved in 3f)**: see step 3f. 
+ + - **3d — DONE** (cache wiring + submodule + engine integration): + `LingAttention` now uses `cache_handle.run_attention` for paged + KV-cache attention (keeps the custom partial-3D rope inline); + `BailingMoeV2ThinkerSubmodule` in `submodules.py` implements + `prepare_inputs` / `preprocess` / `forward` / `check_stop` for + the prefill + decode walks; `MingFlashOmniModel.__init__` no + longer raises NotImplementedError and all Model ABC methods + (`get_kv_cache_config`, `get_graph_walk_graphs`, `get_partitions`, + `process_prompt`, `postprocess`, `get_submodule`, etc.) are + implemented for the text-only path. 12 tests in + `test_ming_flash_omni_model.py` + the existing 30+ Ming tests + still pass. + + **Verified via `mminf-serve` smoke**: the engine instantiates the + model class, calls `get_submodule("Thinker")`, and reaches + `load_thinker_weights` — failing with OOM on a single GPU + (loaded ~75 GB before exhausting the 80 GB H100). The engine + plumbing itself works; **single-GPU OOM is the expected blocker + until step 3e brings TP-aware variants**. To actually serve the + full 100B model the experts + attention must be sharded across GPUs; TP=4 on 4 H100s was the plan at this step, but step 3e found that TP=4 still OOMs (78.58 GB / 80 GB per rank), so + TP=8 across 8 H100s is the setting that fits. + + - **3f — DONE** (graph wiring for the text-only generate loop): + two model-side bugs blocked the first end-to-end `/generate` + response on top of step 3e. + + (a) `BailingMoeV2ThinkerSubmodule` had no `postprocess` hook. + The decode loop's output edge is named `text_inputs` so the + loop feeds the previous sampled token back into the next + iteration. `submodule.forward` returns `{"logits": [...]}`; + the KV-cache engine samples into `{"new_token": [...]}`; but + the graph router needs a `text_inputs` key under that name. + Added `postprocess` that rebinds `new_token → text_inputs`, + mirroring :meth:`OrpheusLLMSubmodule.postprocess`. Without + this, every decode iteration hit `IndexError` at + `prepare_inputs` (`text_inputs` list arrived empty), which + is the same symptom the 3e notes called out.
+ + (b) The prefill / decode output edges used `EMPTY_DESTINATION` + + `conductor_new_token=True` rather than + `EMIT_TO_CLIENT` + `output_modality="text"`. With (a) fixed + the loop produced tokens, but the API server received + `{"outputs": {}}` because no edge routed `new_token` to the + client. Switched to Qwen3-Omni's pattern: prefill emits its + first token to the client and the decode-loop section emits + each subsequent sampled token via a parallel + `EMIT_TO_CLIENT, name="new_token", output_modality="text"` + edge alongside the `text_inputs` loopback. + + **Environment / dependency patches collected along the way** + (not Ming code, but required on this box to reach a working + forward): + + * `BailingTokenizer` doesn't load under transformers >= 5.0: + (i) accessor properties reference `self.verbose`, removed + in 5.x — set a class-level `verbose = False`; (ii) + `__init__` sets `self.add_bos_token` before + `super().__init__()` and the 5.x setter calls + `update_post_processor()` which dereferences the not-yet- + built `self._tokenizer`. Both patches live in + `_patch_bailing_tokenizer_for_transformers5` in + `ming_omni_flash_model.py`, applied once after the first + `AutoTokenizer.from_pretrained` raises an `AttributeError` + matching either signature. + + * `LingMoeBlock._dispatch_tp` always called + `mminf.utils.fused_moe.fused_experts`, which hard-requires + `sgl_kernel`. On boxes where the installed `sgl_kernel.so` + has an ABI mismatch against the running torch (the + importlib-level error doesn't propagate as a normal + `ImportError` until you actually call into the .so), this + crashes mid-forward. Added a naive fallback that calls + `dispatch_experts_fused` on each rank's expert shard then + all-reduces; math is equivalent because sum-over-TP and + sum-over-top-k commute. + + * `flashinfer-python` 0.6.6 ships a Python wrapper that + passes 10 args to the bundled `top_p_sampling_from_probs` + op while `flashinfer-jit-cache` 0.6.2 expects 8. 
Pin + `flashinfer-python==0.6.2` (via `pip install --no-deps`) + to match the jit-cache; the alternative would be rebuilding + the cache against 0.6.6. + + **Verified via `mminf-serve` smoke (TP=8 on 8 H100s)**: + /generate returns real model text.
+ + Note: the expert layout cannot be shared with Qwen3-Omni's MoE block — + `MultiRouter` (3 gates + modality masks) is Ling-specific, and + the per-expert fused weight tensor has its own shape constraints. + +4. **Vision + audio encoders.** Stateless graph nodes. Port + `vision_encoder.py` + `projectors.py` and `audio_encoder.py`. Wire into + the prefill graph walks. + + - **4a — DONE** (`components/projectors.py`, + `components/vision_encoder.py`, `components/audio_encoder.py`): + pure-port encoder + projector modules with weight-key parity + against the released ckpt's top-level prefixes + (`vision.*`, `audio.*`, `linear_proj.*`, `linear_proj_audio.*`). + + * `MingVisionProjector` / `MingAudioProjector` mirror the + `nn.Sequential` chains built inline in + `modeling_bailingmm2.py` (Linear→GELU→Linear for vision, + Conv1d→Transpose→GELU→Linear→Transpose for audio). Layer + indices match the on-disk keys (`linear_proj.{0,2}` vision, + `linear_proj_audio.{0,3}` audio). + + * `build_vision_encoder` constructs Ming's + `Qwen3MoeVisionTransformer` via dynamic import from the staged + Ming source dir (same path used by the tokenizer + processor). + Reused as-is rather than forked — no vLLM dep, ~1 GB at bf16, + runs on a single GPU. + + * `MingAudioEncoder` is a self-contained port of vllm-omni's + packed-sequence Whisper encoder (~250 LOC) — no + `openai-whisper` runtime dep, optional flash-attn varlen fast + path with a manual fallback. Param names match upstream + Whisper (`query` / `key` / `value` / `out`, + `mlp.{0,2}.{weight,bias}`) so the released ckpt's + `audio.blocks.N.*` keys load by state-dict equality.
+ + * 17 tests in `test/modular/test_ming_flash_omni_encoders.py`: + 12 pure-Python (projector shape / layer indices / forward / + audio encoder weight-key parity / packed-attention fallback + shape) + 1 snapshot-gated (vision encoder builds from the + real `VisionEncoderConfig`) + 1 CUDA-gated (forward smoke + under eager attention — currently skipped on this box for + missing libnvrtc-builtins, not a code bug). + + - **4b — DONE** (encoder weight loading): `loader.py` now exposes + `load_vision_encoder_weights`, `load_audio_encoder_weights`, + `load_vision_projector_weights`, `load_audio_projector_weights` + on top of a shared `_load_prefixed_state_dict` helper. None of + these are TP-aware — vision + audio encoders colocate on rank 0 + in the typical topology (see `configs/ming_flash_omni.yaml`) so + a plain prefix-strip + `load_state_dict` path suffices. The + projector loaders also prepend `proj.` to the stripped key so + the on-disk `linear_proj.{0,2}.*` / `linear_proj_audio.{0,3}.*` + keys hit the `nn.Sequential` slot by integer index. + + Verified by 4 snapshot-gated tests in + `test_ming_flash_omni_encoders.py` against the real + `/dev/shm/ming-hybrid` ckpt — all four prefixes load strictly + (no missing / unexpected). The audio encoder's + `positional_embedding` is loaded as a buffer (overrides the + sinusoidal init); the vision encoder loads all 27 blocks + + merger + deepstack_merger_list cleanly. + +5. **Thinker graph walks.** `prefill_text`, `prefill_audio`, `prefill_vision`, + `prefill_video`, `thinker_decode`. Follow Qwen3-Omni's pattern for + conditional walks based on `input_modalities`. + + - **5a — DONE** (`submodules.py`, `ming_omni_flash_model.py`): the two + encoder NodeSubmodules and their construction paths. + + * `VisionEncoderSubmodule` wraps Ming's `Qwen3MoeVisionTransformer` + + `MingVisionProjector`, mirrors + `modeling_bailingmm2.extract_image_feature` (encoder → projector + → L2 norm). 
`prepare_inputs` raises clearly on missing + `pixel_values` / `image_grid_thw` and promotes 1-D + `[T, H, W]` grid_thw to `(1, 3)`. + + * `AudioEncoderSubmodule` wraps `MingAudioEncoder` + + `MingAudioProjector`. Accepts either a single `(n_mels, T)` clip + or a `(B, n_mels, T)` batched tensor and optionally trims the + padded tail using `audio_seqlens`. Per-clip embeddings are + concatenated along time; L2-norm is applied when + `audio_config.norm_query_embeds` is set (true on the released + ckpt — matches `modeling_bailingmm2.extract_audio_feature`). + + * `get_node_engine_types` now registers + `vision_encoder` / `audio_encoder` as `EngineType.STATELESS` + alongside the KV-cache Thinker. Construction routes through + new `_create_vision_encoder_submodule` / + `_create_audio_encoder_submodule` helpers that build, dtype-cast, + and weight-load via the loaders from step 4b. + + * 12 tests in `test/modular/test_ming_flash_omni_submodules.py`: + 10 pure-Python (input-validation, output shape, L2 norm, + audio batched-vs-single equivalence, audio_seqlens trim, + grid_thw promotion, node-type registration, friendly error on + unknown node) + 2 snapshot-gated (full + `_create_audio_encoder_submodule` on the real ckpt — verifies + Conv1 + projector params are non-zero post-load). + + - **5b — DONE** (Thinker prefill dispatch + position helpers): + `BailingMoeV2ThinkerSubmodule.prepare_inputs` now dispatches on + `graph_walk` and emits either `input_ids` (text-only walks) or + `input_embeds` + `custom_pos_ids` (multimodal walks). `preprocess` + and `forward` route both shapes through to `LingMoeModel`'s + existing dual input_ids/input_embeds + 1D/3D position_ids + handling — no new model.py path needed. + + Three new position-id helpers live in `components/positions.py`, + each producing `(3, T)` long tensors compatible with + `LingPartialMRotaryEmbedding`'s `video_rope` branch: + + * `get_rope_index_text(seq_len, start_pos)` — three identical + sequential rows. 
Matches `modeling_bailing_moe_v2.get_rope_index`'s + pure-text branch (`:658-675`). + * `get_rope_index_audio` — alias to the text helper (Ming + does not special-case audio in `get_rope_index`). + * `get_rope_index_vision(grid_thw, start_pos, spatial_merge_size, + second_per_grid_t=None, tokens_per_second=2)` — per-image + 3D grid math from `:625-647`. Optional video timestamp + scaling via `second_per_grid_t * tokens_per_second`. + + The Thinker dispatch: + + * `prefill` / `prefill_text` — backward-compat text path + (unchanged from step 3f). + * `prefill_audio` — wraps `audio_embeds` with `audio_start` + / `audio_end` sentinel embeddings, builds text-like positions + for the span. + * `prefill_vision` / `prefill_video` — wraps `vision_embeds` + with `image_start`/`image_end` (or `video_start`/`video_end`), + builds grid-aware 3D positions; `eos` sentinel sits at + `global_max(vision_pos) + 1` so the next walk's text positions + can resume without collision (matches Ming source's + `llm_pos_ids_list[-1].max() + 1` accounting). + * `decode` / `thinker_decode` — single-token AR step (unchanged). + + Sentinel embeds are lazily computed per device on first use. + The model.py construction now passes `config=self.config` to the + submodule so it can read `vision.spatial_merge_size`, + `thinker_llm.tokens_per_second`, and the `*_start_token` / + `*_end_token` ids. + + Step 5b restricts to single-image / single-clip requests + (multi-image splice via `Sequential` graph wiring lands in 5c). + + 21 new tests across `test_ming_flash_omni_positions.py` (11) and + `test_ming_flash_omni_submodules.py` (10): position-id shape / + offset / abs-time math, missing-input error paths, + multi-image rejection, sentinel embed correctness for audio / + image / video walks, start_pos advancement, legacy `prefill` + walk name compat. All green. 
+ + - **5c — DONE** (graph wiring + multimodal scheduling): + `get_graph_walk_graphs` now returns five walks instead of the + step 3f text-only `prefill` / `decode` pair: + + * `prefill_text` — bare `Thinker` node. + * `prefill_audio` — `Sequential([audio_encoder, Thinker])` + where the encoder emits `audio_embeds` into the Thinker. + * `prefill_vision` — `Sequential([vision_encoder, Thinker])`; + `image_grid_thw` routes to BOTH the encoder (for spatial + positions on the patches) AND the Thinker (for 3D MRoPE math + around the vision span). + * `prefill_video` — same shape as `prefill_vision` plus + `video_second_per_grid` routed into the Thinker. + * `thinker_decode` — AR loop, renamed from step 3f's `decode`. + + `get_partitions` lists all five walks under the single `Thinker` + partition with `initial_walk="prefill_text"`. Two new helpers + drive the scheduling: + + * `_build_thinker_prefill_schedule(input_modalities, input_signals)` + — one schedule step per modality, in `input_modalities` order; + each step is `(walk_name, {input_name: TensorPointerInfo})`. + Modalities listed without matching tensors in `input_signals` + are silently skipped (parity with qwen3_omni). + * `_get_thinker_prefill_inputs(metadata, input_signals)` — emits + one `GraphEdge` per input for the current step, routing each + to the right node (encoder vs Thinker), including the dual + `image_grid_thw` edge for vision walks. + + `get_initial_forward_pass_args` builds the schedule, picks the + first walk, and stashes the schedule + step counter on the + metadata. `get_partition_forward_pass_args` is the Thinker state + machine: advance schedule → transition to `thinker_decode` → + return `request_done=True` after the decode loop unwinds. Mirrors + `mminf/model/qwen3_omni/qwen3_omni_model.py:765+` minus the + Talker / Code2Wav partitions (which land in step 6+). + + Empty-schedule edge case (no usable modalities) short-circuits + to `request_done=True` so the conductor doesn't hang. 
+ + 21 tests in `test/modular/test_ming_flash_omni_graph.py`: + graph-walk structure (5 walks, encoder→Thinker chaining, dual + grid_thw edge, loop feedback edge), partition listing, prefill + schedule construction for text-only / text+audio+image / video / + unknown-modality / no-inputs cases, edge routing for each walk + type, full state-machine drive across a text+audio request + (init → audio prefill → decode → done). + +6. **Talker + Audio VAE.** Port `ming_flash_omni_talker.py` + `talker_module.py` + + `audio_vae.py`. The talker is CFM-based (continuous flow matching) rather + than discrete-codec-AR like Qwen3-Omni's — the streaming topology will + differ. Re-read `mminf/streaming/topology.py` before wiring connections. + +7. **Process_prompt — DONE.** `MingFlashOmniModel.process_prompt` now + produces the full `NameToTensorList` consumed by step 5c's prefill + scheduler. Strategy mirrors `qwen3_omni`'s `process_prompt`: apply + the chat template to TEXT-ONLY messages (so the tokenizer doesn't + insert placeholder tokens we'd later have to strip), then run the + image / video / audio sub-processors separately for each modality. + The Ming chat template path uses `tokenizer.apply_chat_template` + (jinja, accepts OpenAI roles `user`/`assistant`/`system`) rather + than `processor.apply_chat_template` (Python implementation in + `BailingMM2Processor`, asserts on lowercase OpenAI roles — see + "Role-handling nuance" above). + + Input convention (`tensors: NameToTensorList`): + * `image_inputs` — list of CHW float [0,1] tensors per image. + Internal `_image_to_processor_input` converts to HWC uint8 to + avoid the upstream's double-rescale near-zero bug + (`qwen3_omni_model.py:1033-1038` documents the same gotcha). + Single-channel inputs auto-broadcast to 3 channels. + * `audio_inputs` — list of either raw 1-D float tensors (sample + rate inferred from processor default 16 kHz) or + `(waveform, sample_rate)` tuples. 
+ * `video_inputs` — list of (T, C, H, W) float tensors. Per-frame + `second_per_grid` defaults to 1.0; override via + `kwargs["input_metadata"]["video"][i]["second_per_grid"]`. + + Output keys consumed by `_build_thinker_prefill_schedule`: + * `text_inputs` — list of 1-D long tensors (one per text turn). + * `pixel_values`, `image_grid_thw` — one entry per image. + * `pixel_values_videos`, `video_grid_thw`, + `video_second_per_grid` — one entry per video clip. + * `audio_features` (n_mels, T) + `audio_seqlens` (length-1 long) + — one entry per audio clip. Note: upstream returns audio_feats + as (B, T, n_mels); we transpose to (n_mels, T) per clip so + `AudioEncoderSubmodule.prepare_inputs` can splice without a + reshape. + + 17 tests in `test/modular/test_ming_flash_omni_process_prompt.py`: + text-only happy path, no-prompt audio-only path, image conversion + correctness (CHW float [0,1] → HWC uint8, grayscale broadcast, + uint8 pass-through), per-modality dispatch, missing-processor + error paths, multi-image / mixed-modality combinations, video + metadata override, snapshot-gated text+image E2E with the real + `BailingMM2Processor`. 16 green + 1 env-skip on this box. + + Image-gen-specific `*256` block (the + query-token expansion for the imagegen DiT path) is deferred to + step 9 (ImageGen partition), since today's prefill schedule only + covers text-out generation. + +8. **TTS caption template (optional, talker-only deploy).** Port + `prompt_utils.BASE_CAPTION_TEMPLATE` + `create_instruction` so the + `ming_flash_omni_tts` deploy variant accepts the same JSON caption shape + that vllm-omni speaks. + +9. **ImageGen partition (deferred).** Separate from the omni pipeline; lives + under vllm-omni's diffusion tree. Wire as a fourth partition with its own + graph walk once #1–8 are landed. Needs `FlowEngine`-style integration. + +10. **Configs.** Update `configs/ming_flash_omni*.yaml` to match the final + node names emerging from #5 and #6. 
Add an image-gen variant when #9 + lands. + +11. **Benchmark `OursOpenAI` parity.** Once `mminf-serve` boots the model, + extend `benchmark/request.py:OursOpenAI` to route Ming TTS through the + correct endpoint (likely `/v1/chat/completions` with `modalities=["audio"]`, + matching the Qwen3-Omni path — `MingFlashOmni` declares no Orpheus-style + speech-only fallback). + +12. **Tests.** Add `test/modular/test_ming_flash_omni_*.py` covering config + load, submodule weight load on a tiny shard, and a smoke graph walk on + a single GPU. Mirror `test/modular/test_qwen3_omni_*.py` if present. + +## Things to verify against the released checkpoint (not in vllm-omni) + +- Exact `max_position_embeddings` and `rope_theta` for thinker vs talker + (read from `config.json`, not the deploy yaml). +- Whether `default_sampling_params.repetition_penalty=1.05` from the deploy + yaml is a serving default or a hard requirement — affects + `benchmark/base.py:MingFlashOmni.get_model_kwargs`. +- The output sample rate for the talker (Qwen3-Omni is 24 kHz; check + `audio_vae.py` for Ming's). Override + `Model.get_output_sample_rate` if it differs. 
diff --git a/mminf/model/ming_omni_flash/__init__.py b/mminf/model/ming_omni_flash/__init__.py new file mode 100644 index 00000000..72bb12a7 --- /dev/null +++ b/mminf/model/ming_omni_flash/__init__.py @@ -0,0 +1,21 @@ +from mminf.model.ming_omni_flash.components.model import ( + LingMoeModel as LingMoeModel, +) +from mminf.model.ming_omni_flash.loader import ( + load_audio_encoder_weights as load_audio_encoder_weights, +) +from mminf.model.ming_omni_flash.loader import ( + load_audio_projector_weights as load_audio_projector_weights, +) +from mminf.model.ming_omni_flash.loader import ( + load_thinker_weights as load_thinker_weights, +) +from mminf.model.ming_omni_flash.loader import ( + load_vision_encoder_weights as load_vision_encoder_weights, +) +from mminf.model.ming_omni_flash.loader import ( + load_vision_projector_weights as load_vision_projector_weights, +) +from mminf.model.ming_omni_flash.ming_omni_flash_model import ( + MingFlashOmniModel as MingFlashOmniModel, +) diff --git a/mminf/model/ming_omni_flash/components/__init__.py b/mminf/model/ming_omni_flash/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mminf/model/ming_omni_flash/components/attention.py b/mminf/model/ming_omni_flash/components/attention.py new file mode 100644 index 00000000..042d2a1c --- /dev/null +++ b/mminf/model/ming_omni_flash/components/attention.py @@ -0,0 +1,171 @@ +"""Ling-2.0 attention (TP-aware, packed-tokens, cache-handle-aware). + +Uses mminf's :class:`QKVParallelLinear` + :class:`RowParallelLinear` for +TP-sharded projections. Per-rank head counts come from the QKV proj — +when ``tp_size > 1``, attention runs on this rank's slice of heads and +the output `dense` projection all-reduces across ranks. + +The architecture-specific bits (per-head QK-norm, partial 3D +``video_rope`` rotation) stay inline — they only operate on this rank's +heads, no cross-rank comm. 
class LingAttention(nn.Module):
    """Tensor-parallel Ling-2.0 self-attention over packed tokens.

    The constructor receives TOTAL head counts. The per-rank counts used by
    :meth:`forward` are read back from ``qkv_proj.num_heads`` /
    ``qkv_proj.num_kv_heads`` after :class:`QKVParallelLinear` has been
    built for ``comm_group``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float,
        rotary: LingPartialMRotaryEmbedding,
        use_qkv_bias: bool = False,
        use_bias: bool = False,
        comm_group: TPCommGroup | None = None,
    ) -> None:
        """Build the projections, QK-norms and bookkeeping attributes.

        Raises:
            ValueError: if ``num_heads`` is not a multiple of
                ``num_kv_heads``, or ``rotary.head_dim`` differs from
                ``head_dim``.
        """
        super().__init__()
        if num_heads % num_kv_heads:
            raise ValueError(
                f"num_heads={num_heads} must be divisible by num_kv_heads={num_kv_heads} for GQA"
            )
        if rotary.head_dim != head_dim:
            raise ValueError(
                f"rotary.head_dim={rotary.head_dim} must equal head_dim={head_dim}"
            )
        # Fall back to the single-rank group when the caller passes none.
        self.comm_group = TPCommGroup.trivial() if comm_group is None else comm_group

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads

        # Fused Q/K/V projection, sharded over heads. Rows are laid out
        # [Q, K, V] along dim 0 (Q: total_num_heads * head_dim rows; K and V:
        # total_num_kv_heads * head_dim rows each) — the same packing as the
        # checkpoint's ``query_key_value.weight``, so the loader's manual
        # q/k/v split lands in the right slots.
        self.qkv_proj = QKVParallelLinear(
            comm_group=self.comm_group,
            hidden_size=hidden_size,
            head_size=head_dim,
            total_num_heads=num_heads,
            total_num_kv_heads=num_kv_heads,
            bias=use_qkv_bias,
        )

        # From here on only the per-rank head counts are used.
        local_heads = self.qkv_proj.num_heads
        local_kv_heads = self.qkv_proj.num_kv_heads
        self.num_heads = local_heads
        self.num_kv_heads = local_kv_heads
        self.kv_groups = local_heads // local_kv_heads
        self.q_size = local_heads * head_dim
        self.kv_size = local_kv_heads * head_dim
        self.scaling = head_dim ** -0.5

        # Row-parallel output projection: ``input_size`` is the full
        # (pre-shard) width, the output is the full hidden_size and is
        # all-reduced across ranks.
        self.dense = RowParallelLinear(
            comm_group=self.comm_group,
            input_size=num_heads * head_dim,
            output_size=hidden_size,
            bias=use_bias,
            input_is_parallel=True,
            reduce_results=True,
        )

        # RMSNorm over the head_dim axis of q and k, applied before rope.
        # It only touches this rank's heads, so no communication is needed.
        self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps)

        self.rotary = rotary

    def _per_head_norm(self, norm, packed: torch.Tensor, n_heads: int) -> torch.Tensor:
        """Apply ``norm`` per head and return ``(num_tokens, n_heads, head_dim)``."""
        n_tok = packed.shape[0]
        per_head = packed.view(n_tok, n_heads, self.head_dim)
        flat = per_head.reshape(-1, self.head_dim)
        return norm(flat).view(n_tok, n_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_handle: BatchedCacheManager,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for a packed batch of tokens.

        Args:
            hidden_states: ``(num_tokens, hidden_size)``; NOT pre-sharded,
                the fused QKV projection consumes the full hidden dim.
            cache_handle: cache manager whose
                ``run_attention(q=..., k=..., v=...)`` is called with this
                rank's heads.
            position_ids: passed unchanged to ``self.rotary``; the expected
                layout is defined by :class:`LingPartialMRotaryEmbedding`
                (not visible here).

        Returns:
            ``(num_tokens, hidden_size)`` — the output of the row-parallel
            ``dense`` projection.
        """
        n_tok = hidden_states.shape[0]

        # This rank's slice of the fused projection:
        # (n_tok, q_size + 2 * kv_size).
        fused = self.qkv_proj(hidden_states)
        query, key, value = torch.split(
            fused, [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        # QK-norm on the head_dim axis; V is only reshaped.
        query = self._per_head_norm(self.q_norm, query, self.num_heads)
        key = self._per_head_norm(self.k_norm, key, self.num_kv_heads)
        value = value.view(n_tok, self.num_kv_heads, self.head_dim)

        # The rotary module is called with heads on the leading axis.
        query, key = self.rotary(
            query.transpose(0, 1), key.transpose(0, 1), position_ids
        )
        query = query.transpose(0, 1).contiguous()
        key = key.transpose(0, 1).contiguous()

        # Attention over per-rank heads, then flatten heads back into q_size.
        context = cache_handle.run_attention(q=query, k=key, v=value)
        context = context.reshape(n_tok, self.q_size)
        return self.dense(context)

    @staticmethod
    def head_norm_check(q_after_norm: torch.Tensor) -> float:
        """Diagnostic: largest absolute deviation of the last-axis RMS from 1."""
        rms = torch.sqrt(torch.mean(q_after_norm.float() ** 2, dim=-1))
        return torch.max(torch.abs(rms - 1.0)).item()
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Whisper primitives (auto-dtype-casting layers + sinusoidal embedding)
# ---------------------------------------------------------------------------


def _sinusoids(length: int, channels: int, max_timescale: int = 10000) -> torch.Tensor:
    """Build the Whisper sinusoidal positional table.

    Args:
        length: number of positions (rows).
        channels: embedding width; must be even.
        max_timescale: largest timescale (Whisper's default is 10_000).

    Returns:
        ``(length, channels)`` tensor laid out as ``[sin | cos]``.

    Raises:
        ValueError: if ``channels`` is odd.
    """
    if channels % 2 != 0:
        raise ValueError(f"channels must be even, got {channels}")
    half = channels // 2
    # With channels == 2 the original divisor (half - 1) is zero, which turns
    # the whole table into NaN. Clamp it so the single timescale is 1.0;
    # channels >= 4 is unaffected.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class _AutoCastConv1d(nn.Conv1d):
    """Conv1d that casts its weight/bias to the input dtype on every forward.

    Lets the encoder keep bf16 weights while taking fp32 mel inputs
    without an explicit ``.to(bf16)`` at the call site.
    """

    def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype),
        )


class _AutoCastLinear(nn.Linear):
    """Linear that casts its weight/bias to the input dtype on every forward."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(
            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype),
        )


# ---------------------------------------------------------------------------
# Multi-head attention (packed sequence with optional FA2 fast path)
# ---------------------------------------------------------------------------


def _try_import_flash_attn():
    """Return ``flash_attn_varlen_func`` if importable, else ``None``.

    Wrapped so machines without flash-attn still work through the manual
    PyTorch fallback.
    """
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore
        return flash_attn_varlen_func
    except ImportError:
        return None


_FLASH_ATTN_VARLEN = _try_import_flash_attn()

# Set once the "flash-attn missing" warning has been logged, so an N-layer
# encoder does not log it N times.
_FLASH_ATTN_WARNING_EMITTED = False


def _warn_flash_attn_unavailable() -> None:
    """Log the flash-attn fallback warning at most once per process."""
    global _FLASH_ATTN_WARNING_EMITTED
    if not _FLASH_ATTN_WARNING_EMITTED:
        logger.warning("flash-attn not available — falling back to manual attention.")
        _FLASH_ATTN_WARNING_EMITTED = True


class _PackedMultiHeadAttention(nn.Module):
    """Whisper-style MHA over variable-length packed sequences.

    Parameter names match OpenAI Whisper (``query`` / ``key`` / ``value`` /
    ``out``) so checkpoint keys load directly. ``key`` has no bias.
    """

    def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True) -> None:
        """Create the four projections.

        Raises:
            ValueError: if ``n_state`` is not divisible by ``n_head``.
        """
        super().__init__()
        if n_state % n_head != 0:
            raise ValueError(f"n_state={n_state} not divisible by n_head={n_head}")
        self.n_head = n_head
        self.query = _AutoCastLinear(n_state, n_state)
        self.key = _AutoCastLinear(n_state, n_state, bias=False)
        self.value = _AutoCastLinear(n_state, n_state)
        self.out = _AutoCastLinear(n_state, n_state)

        if use_flash_attn and _FLASH_ATTN_VARLEN is None:
            _warn_flash_attn_unavailable()
        self.use_flash_attn = use_flash_attn and _FLASH_ATTN_VARLEN is not None

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Packed-sequence attention.

        Args:
            x: ``(total_tokens, n_state)`` packed tensor.
            cu_seqlens: ``(num_seqs + 1,)`` cumulative sequence lengths,
                e.g. ``[0, len1, len1 + len2, ...]``.

        Returns:
            ``(total_tokens, n_state)`` tensor after the output projection.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        n_tokens, n_state = q.shape
        head_dim = n_state // self.n_head
        q = q.view(n_tokens, self.n_head, head_dim)
        k = k.view(n_tokens, self.n_head, head_dim)
        v = v.view(n_tokens, self.n_head, head_dim)

        # The flash-attn kernel is only used for half-precision inputs.
        if self.use_flash_attn and q.dtype in (torch.float16, torch.bfloat16):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output = _FLASH_ATTN_VARLEN(
                q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
            )
        else:
            attn_output = self._manual_packed_attention(q, k, v, cu_seqlens)

        attn_output = attn_output.contiguous().view(n_tokens, n_state)
        return self.out(attn_output)

    @staticmethod
    def _manual_packed_attention(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Pad → attend → unpack fallback for the packed format.

        Args:
            q, k, v: ``(total_tokens, n_head, head_dim)``.
            cu_seqlens: ``(num_seqs + 1,)`` cumulative sequence lengths.

        Returns:
            ``(total_tokens, n_head, head_dim)`` attention output.
        """
        _, n_head, head_dim = q.shape
        scale = head_dim ** -0.5

        # One host transfer for all boundaries instead of two .item() calls
        # per sequence inside the loop.
        bounds = cu_seqlens.tolist()
        spans = list(zip(bounds[:-1], bounds[1:]))
        seqlens = [hi - lo for lo, hi in spans]
        if not seqlens:
            return q.new_empty((0, n_head, head_dim))
        batch = len(seqlens)
        max_len = max(seqlens)

        # Pad each sequence to max_len so we can run a single batched matmul.
        q_pad = torch.zeros(batch, max_len, n_head, head_dim, dtype=q.dtype, device=q.device)
        k_pad = torch.zeros_like(q_pad)
        v_pad = torch.zeros_like(q_pad)
        for i, (lo, hi) in enumerate(spans):
            q_pad[i, : hi - lo] = q[lo:hi]
            k_pad[i, : hi - lo] = k[lo:hi]
            v_pad[i, : hi - lo] = v[lo:hi]

        # (B, H, T, D)
        q_pad = q_pad.transpose(1, 2)
        k_pad = k_pad.transpose(1, 2)
        v_pad = v_pad.transpose(1, 2)

        # Additive mask that removes padded key columns from the softmax.
        padding_mask = (
            torch.arange(max_len, device=q.device)[None, :]
            >= torch.tensor(seqlens, device=q.device)[:, None]
        )
        attn_mask = torch.zeros(batch, 1, 1, max_len, dtype=q.dtype, device=q.device)
        attn_mask = attn_mask.masked_fill(
            padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max,
        )

        scores = torch.matmul(q_pad, k_pad.transpose(-2, -1)) * scale + attn_mask
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v_pad)  # (B, H, T, D)
        context = context.transpose(1, 2).contiguous()  # (B, T, H, D)

        # Drop the padded rows and return to the packed layout.
        return torch.cat([context[i, :ln] for i, ln in enumerate(seqlens)], dim=0)


# ---------------------------------------------------------------------------
# Residual block (Whisper attn + FFN)
# ---------------------------------------------------------------------------


class _ResidualAttentionBlock(nn.Module):
    """Whisper-style attn + FFN residual block (param names match upstream)."""

    def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True) -> None:
        super().__init__()
        self.attn = _PackedMultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn)
        self.attn_ln = nn.LayerNorm(n_state)

        n_mlp = n_state * 4
        # Sequential layout (Linear, GELU, Linear) so checkpoint keys
        # blocks.{N}.mlp.0.* / .2.* hit the right module by integer index.
        self.mlp = nn.Sequential(
            _AutoCastLinear(n_state, n_mlp),
            nn.GELU(),
            _AutoCastLinear(n_mlp, n_state),
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Apply attention then the MLP, each with a pre-LayerNorm residual."""
        x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + self.mlp(self.mlp_ln(x))
        return x


# ---------------------------------------------------------------------------
# Encoder — public API
# ---------------------------------------------------------------------------


class MingAudioEncoder(nn.Module):
    """Whisper audio encoder with packed-sequence support.

    Loadable from the released Ming-flash-omni-2.0 checkpoint's ``audio.*``
    weight subtree (the caller strips the prefix).

    Unlike a trainable positional parameter, ``positional_embedding`` is a
    *buffer* holding a fixed sinusoidal table with ``n_ctx`` rows. It is
    recomputed at construction rather than loaded, and a clip may encode to
    at most ``n_ctx`` frames.
    """

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 15000,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        use_flash_attn: bool = True,
    ) -> None:
        super().__init__()
        self.n_layer = n_layer
        self.n_mels = n_mels
        self.use_flash_attn = use_flash_attn
        self.audio_emb_dim = n_state

        self.conv1 = _AutoCastConv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = _AutoCastConv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Buffer (not Parameter) — the checkpoint doesn't ship this; we
        # recompute it.
        self.register_buffer("positional_embedding", _sinusoids(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [_ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x_list: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder on a list of variable-length mel spectrograms.

        Args:
            x_list: non-empty list of ``(n_mels, T_i)`` mel features, one
                per audio clip.

        Returns:
            ``(packed, cu_seqlens)``:
              - packed: ``(total_T', n_state)``, all clips concatenated
                along time.
              - cu_seqlens: ``(len(x_list) + 1,)`` int32 cumulative encoded
                lengths.

        Raises:
            ValueError: if ``x_list`` is empty, or a clip encodes to more
                frames than the positional table holds.
        """
        if not x_list:
            raise ValueError("x_list must contain at least one mel spectrogram")

        target_dtype = self.conv1.weight.dtype
        max_positions = self.positional_embedding.shape[0]

        encoded = []
        encoded_lens: list[int] = []
        for mel in x_list:
            x = mel.to(target_dtype).unsqueeze(0)  # (1, n_mels, T)
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.squeeze(0).transpose(0, 1)  # (T', n_state)

            seq_len = x.shape[0]
            if seq_len > max_positions:
                # Without this check the addition below fails with an
                # opaque broadcasting error.
                raise ValueError(
                    f"audio clip encodes to {seq_len} frames, exceeding n_ctx={max_positions}"
                )
            x = (x + self.positional_embedding[:seq_len, :]).to(x.dtype)
            encoded.append(x)
            encoded_lens.append(seq_len)

        packed = torch.cat(encoded, dim=0)  # (sum T', n_state)
        cu_seqlens = torch.tensor(
            list(accumulate(encoded_lens, func=operator.add, initial=0)),
            device=packed.device, dtype=torch.int32,
        )
        for block in self.blocks:
            packed = block(packed, cu_seqlens=cu_seqlens)
        packed = self.ln_post(packed)
        return packed, cu_seqlens


def build_audio_encoder(
    audio_config,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device = "cpu",
    use_flash_attn: bool = True,
) -> MingAudioEncoder:
    """Construct a :class:`MingAudioEncoder` from an audio config object.

    Args:
        audio_config: object whose ``whisper_encoder_config`` mapping holds
            ``n_mels``, ``n_ctx``, ``n_state``, ``n_head`` and ``n_layer``.
        dtype: dtype the module is cast to.
        device: device the module is moved to.
        use_flash_attn: forwarded to the encoder.

    Returns:
        The encoder, moved/cast and switched to eval mode.
    """
    whisper_cfg = audio_config.whisper_encoder_config
    encoder = MingAudioEncoder(
        n_mels=int(whisper_cfg["n_mels"]),
        n_ctx=int(whisper_cfg["n_ctx"]),
        n_state=int(whisper_cfg["n_state"]),
        n_head=int(whisper_cfg["n_head"]),
        n_layer=int(whisper_cfg["n_layer"]),
        use_flash_attn=use_flash_attn,
    )
    encoder = encoder.to(dtype=dtype, device=device)
    encoder.eval()
    return encoder
class LingDecoderLayer(nn.Module):
    """Pre-norm Ling-2.0 decoder layer with a dense or MoE feed-forward.

    Layers whose ``layer_idx`` is below ``first_k_dense_replace`` get a dense
    :class:`ParallelGatedMLP`; every other layer gets a
    :class:`LingMoeBlock`. ``comm_group`` is passed to every TP-aware
    sub-module, and a trivial group is substituted when it is omitted.
    """

    def __init__(
        self,
        layer_idx: int,
        first_k_dense_replace: int,
        hidden_size: int,
        intermediate_size: int,
        moe_intermediate_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float,
        num_experts: int,
        num_experts_per_tok: int,
        num_shared_experts: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        rotary: LingPartialMRotaryEmbedding,
        use_qkv_bias: bool = False,
        use_bias: bool = False,
        comm_group: TPCommGroup | None = None,
    ) -> None:
        """Build the two norms, the attention block and the feed-forward."""
        super().__init__()
        tp_group = comm_group if comm_group is not None else TPCommGroup.trivial()

        self.layer_idx = layer_idx
        self.is_moe = layer_idx >= first_k_dense_replace

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

        self.self_attn = LingAttention(
            hidden_size=hidden_size,
            num_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            rotary=rotary,
            use_qkv_bias=use_qkv_bias,
            use_bias=use_bias,
            comm_group=tp_group,
        )

        if not self.is_moe:
            # Dense feed-forward; ParallelGatedMLP is handed the comm group
            # and deals with TP sharding itself.
            self.mlp: nn.Module = ParallelGatedMLP(
                comm_group=tp_group,
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                bias=False,
            )
        else:
            self.mlp = LingMoeBlock(
                hidden_size=hidden_size,
                num_experts=num_experts,
                num_experts_per_tok=num_experts_per_tok,
                moe_intermediate_size=moe_intermediate_size,
                num_shared_experts=num_shared_experts,
                n_group=n_group,
                topk_group=topk_group,
                routed_scaling_factor=routed_scaling_factor,
                comm_group=tp_group,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_handle: BatchedCacheManager,
        position_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        audio_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply attention and feed-forward, each wrapped in a residual.

        ``image_mask`` / ``audio_mask`` are forwarded only to the MoE
        feed-forward; the dense MLP is called with the hidden states alone.
        """
        # Attention sub-layer: norm -> attention -> residual add.
        attn_in = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attn(attn_in, cache_handle, position_ids)

        # Feed-forward sub-layer: norm -> (dense | MoE) -> residual add.
        ffn_in = self.post_attention_layernorm(hidden_states)
        if not self.is_moe:
            return hidden_states + self.mlp(ffn_in)
        return hidden_states + self.mlp(
            ffn_in, image_mask=image_mask, audio_mask=audio_mask
        )
+ + All shape-relevant config flattens into the constructor so callers + don't need a :class:`MingFlashOmniModelConfig` instance — useful for + small-dim unit tests. The eventual mminf submodule (step 3c) builds + one of these from the real config. + + Args (all required, but small-dim test configs only need plausible + values; nothing here is hard-coded to Ming-specific dims): + vocab_size: e.g. 157184 on released ckpt. + hidden_size: e.g. 4096. + intermediate_size: dense layer-0 MLP intermediate; e.g. 9216. + moe_intermediate_size: per-expert intermediate; e.g. 1024. + num_hidden_layers: e.g. 32. + num_attention_heads, num_kv_heads, head_dim: e.g. 32 / 4 / 128. + rms_norm_eps: 1e-6. + rope_theta: 2_400_000. + max_position_embeddings: 32768. + partial_rotary_factor: 0.5. + mrope_section: [8, 12, 12]. + num_experts: 256. + num_experts_per_tok: 8. + num_shared_experts: 1. + n_group: 8. + topk_group: 4. + routed_scaling_factor: 2.5. + first_k_dense_replace: 1. + tie_word_embeddings: False on released ckpt — lm_head is a + separate matrix from embed_tokens. 
+ """ + + def __init__( + self, + *, + vocab_size: int, + hidden_size: int, + intermediate_size: int, + moe_intermediate_size: int, + num_hidden_layers: int, + num_attention_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float, + rope_theta: float, + max_position_embeddings: int, + partial_rotary_factor: float, + mrope_section: list[int], + num_experts: int, + num_experts_per_tok: int, + num_shared_experts: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float, + first_k_dense_replace: int, + tie_word_embeddings: bool = False, + use_qkv_bias: bool = False, + use_bias: bool = False, + comm_group: TPCommGroup | None = None, + ) -> None: + super().__init__() + if comm_group is None: + comm_group = TPCommGroup.trivial() + self.comm_group = comm_group + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + + # embed_tokens + lm_head stay replicated. At hidden_size=4096 + # they're 1.3 GB each — cheap compared to the layers. + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + + # Single rotary instance shared across every layer — inv_freq is + # config-only, no per-layer state. 
def forward(
    self,
    cache_handle,
    input_ids: torch.Tensor | None = None,
    input_embeds: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    image_mask: torch.Tensor | None = None,
    audio_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embed (if needed), run every decoder layer, normalise, project.

    Args:
        cache_handle: object exposing ``set_layer_idx(int)``; it is told
            the index of each layer right before that layer runs and is
            then handed to the layer itself.
        input_ids: token ids, looked up through ``self.embed_tokens``.
            Mutually exclusive with ``input_embeds``.
        input_embeds: precomputed 2-D ``(T, hidden)`` embeddings, used
            as-is. Mutually exclusive with ``input_ids``.
        position_ids: passed to every layer; defaults to
            ``torch.arange(T)`` on the activations' device.
        image_mask: passed through to every layer unchanged.
        audio_mask: passed through to every layer unchanged.

    Returns:
        ``self.lm_head`` applied to the normalised final activations.

    Raises:
        ValueError: if both or neither of ``input_ids`` /
            ``input_embeds`` are given, or if the activations entering
            the first layer are not 2-D.
    """
    has_ids = input_ids is not None
    has_embeds = input_embeds is not None
    if has_ids == has_embeds:
        raise ValueError(
            "Exactly one of input_ids / input_embeds must be provided"
        )

    hidden = input_embeds if has_embeds else self.embed_tokens(input_ids)

    if hidden.dim() != 2:
        raise ValueError(
            f"LingMoeModel expects packed (T, hidden) input; got "
            f"shape {tuple(hidden.shape)}."
        )

    if position_ids is None:
        num_tokens = hidden.shape[0]
        position_ids = torch.arange(num_tokens, device=hidden.device)

    # The cache must know which layer is about to run before the layer's
    # attention touches it, so the index is set first on every iteration.
    for index, decoder_layer in enumerate(self.layers):
        cache_handle.set_layer_idx(index)
        hidden = decoder_layer(
            hidden,
            cache_handle,
            position_ids,
            image_mask=image_mask,
            audio_mask=audio_mask,
        )

    return self.lm_head(self.norm(hidden))
+ * Shared expert is a ``ParallelGatedMLP`` so its ``down_proj`` + all-reduces internally. + * Forward TP path mirrors :class:`ParallelSparseMoeBlock._dispatch_tp`: + `fused_experts(..., reduce_results=False)` → ``all_reduce`` → + ``moe_sum_reduce_triton``. + +Routers (``LingMoeRouter``) stay replicated across ranks — gates must +make identical decisions so every rank dispatches tokens to the same +experts. + +Reference: vllm-omni's ``BailingMoeV2SparseMoeBlock`` (lines 304-433) ++ mminf's :class:`ParallelSparseMoeBlock` +(`mminf/model/components/moe.py:318-414`). +""" + +from __future__ import annotations + +from functools import partial + +import torch +from torch import nn + +from mminf.distributed.communication import TPCommGroup +from mminf.distributed.utils import divide +from mminf.model.components.distributed.mlp import ParallelGatedMLP +from mminf.model.components.mlp import GatedMLP +from mminf.model.components.moe import ( + _dispatch, + _down_proj_weight_loader, + _gate_up_weight_loader, + dispatch_experts_fused, +) +from mminf.model.ming_omni_flash.components.router import LingMoeRouter + + +def _normalize_modality_mask( + mask: torch.Tensor | None, num_tokens: int, name: str, +) -> torch.Tensor | None: + """Reshape a modality mask to ``(num_tokens, 1)`` bool, or pass through None.""" + if mask is None: + return None + if mask.dim() == 1: + if mask.shape[0] != num_tokens: + raise ValueError( + f"{name} length {mask.shape[0]} != num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 1).bool() + if mask.dim() == 2: + if mask.numel() != num_tokens: + raise ValueError( + f"{name} shape {tuple(mask.shape)} has {mask.numel()} elements; " + f"expected num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 1).bool() + if mask.dim() == 3: + if mask.shape[-1] != 1 or mask.numel() != num_tokens: + raise ValueError( + f"{name} shape {tuple(mask.shape)} not compatible with " + f"num_tokens={num_tokens}" + ) + return mask.reshape(num_tokens, 
1).bool() + raise ValueError( + f"{name} must be 1D, 2D, or 3D; got shape {tuple(mask.shape)}" + ) + + +class LingMoeBlock(nn.Module): + """Ling-2.0 MoE FFN with text/image/audio gate selection per token. + + Constructor takes the FULL ``moe_intermediate_size``; the per-rank + ``shard_inter`` is computed from ``comm_group.world_size``. + + Args: + hidden_size: model hidden dim. + num_experts: total routed experts. + num_experts_per_tok: top-k experts per token. + moe_intermediate_size: per-expert intermediate dim (FULL — + sharding handled internally). + num_shared_experts: number of shared experts (1 on the released + ckpt). The shared expert is a ``ParallelGatedMLP`` of width + ``moe_intermediate_size * num_shared_experts``. + n_group, topk_group, routed_scaling_factor: passed to the + :class:`LingMoeRouter`s. + comm_group: TP comm group; defaults to single-rank trivial. + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + num_experts_per_tok: int, + moe_intermediate_size: int, + num_shared_experts: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float = 1.0, + comm_group: TPCommGroup | None = None, + ) -> None: + super().__init__() + if comm_group is None: + comm_group = TPCommGroup.trivial() + self.comm_group = comm_group + tp_size = comm_group.world_size + tp_rank = comm_group.rank + + self.hidden_size = hidden_size + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + + router_kwargs = dict( + hidden_size=hidden_size, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + n_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, + ) + # Routers — replicated. All ranks must agree on which experts a + # given token routes to, so gate weights are loaded identically + # per rank (default weight_loader, no shard_id). 
+ self.gate = LingMoeRouter(**router_kwargs) + self.image_gate = LingMoeRouter(**router_kwargs) + self.audio_gate = LingMoeRouter(**router_kwargs) + + # Fused expert tensors with per-rank intermediate shard. + shard_inter = divide(moe_intermediate_size, tp_size) + self.experts = nn.Module() + self.experts.gate_up_proj = nn.Parameter( + torch.empty(num_experts, 2 * shard_inter, hidden_size) + ) + self.experts.down_proj = nn.Parameter( + torch.empty(num_experts, hidden_size, shard_inter) + ) + + # Shared expert: ParallelGatedMLP. Its down_proj all-reduces, so + # the shared output already lives on the full hidden state at + # every rank. + if num_shared_experts <= 0: + raise ValueError( + "LingMoeBlock requires num_shared_experts >= 1; released " + "Ming-flash-omni-2.0 has 1." + ) + self.shared_expert = ParallelGatedMLP( + comm_group=comm_group, + hidden_size=hidden_size, + intermediate_size=moe_intermediate_size * num_shared_experts, + bias=False, + ) + + self._attach_weight_loaders(tp_rank, tp_size, moe_intermediate_size) + + # ------------------------------------------------------------------ + # Weight loader plumbing — mirrors ParallelSparseMoeBlock + # ------------------------------------------------------------------ + + def _attach_weight_loaders( + self, tp_rank: int, tp_size: int, full_inter: int, + ) -> None: + """Attach mminf's per-rank fused-expert weight loaders. + + The loaders accept shard ids ``"gate:N"``, ``"up:N"``, ``"down:N"`` + and slice along the intermediate dim per rank, then write into + the right expert slot. ``load_hf_weights`` dispatches based on + the ``StackedParamRule.shard_id`` we configure in the loader. 
def forward(
    self,
    hidden_states: torch.Tensor,
    image_mask: torch.Tensor | None = None,
    audio_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route each token, dispatch to experts, and add the shared expert.

    The text gate is always evaluated and provides the baseline routing.
    Where ``image_mask`` is set, the image gate's routing replaces it;
    where ``audio_mask`` is set, the audio gate's routing replaces
    whatever was chosen so far (so audio wins if both masks flag the
    same token).

    With a single-rank comm group the module-level ``_dispatch`` helper
    runs the experts; otherwise ``self._dispatch_tp`` does.

    Args:
        hidden_states: activations whose last dim is the hidden size;
            any leading shape is flattened to tokens and restored on
            return. Non-contiguous inputs are accepted.
        image_mask: optional per-token mask, normalised by
            ``_normalize_modality_mask``.
        audio_mask: optional per-token mask, normalised the same way.

    Returns:
        Routed-expert output plus shared-expert output, reshaped to
        ``hidden_states.shape``.
    """
    input_shape = hidden_states.shape
    # ``reshape`` rather than ``view``: ``view`` raises on a non-contiguous
    # input before the trailing ``.contiguous()`` could ever help, whereas
    # ``reshape`` copies when it has to. ``.contiguous()`` is kept because
    # ``reshape`` may still hand back a strided view.
    flat = hidden_states.reshape(-1, input_shape[-1]).contiguous()
    num_tokens = flat.shape[0]

    # Text-gate baseline routing (always computed).
    _, topk_weight, topk_idx = self.gate(flat)

    image_mask = _normalize_modality_mask(image_mask, num_tokens, "image_mask")
    audio_mask = _normalize_modality_mask(audio_mask, num_tokens, "audio_mask")

    # The (num_tokens, 1) masks broadcast across the top-k dimension.
    if image_mask is not None:
        _, img_w, img_idx = self.image_gate(flat)
        topk_idx = torch.where(image_mask, img_idx, topk_idx)
        topk_weight = torch.where(image_mask, img_w, topk_weight)
    if audio_mask is not None:
        _, aud_w, aud_idx = self.audio_gate(flat)
        topk_idx = torch.where(audio_mask, aud_idx, topk_idx)
        topk_weight = torch.where(audio_mask, aud_w, topk_weight)

    if self.comm_group.world_size == 1:
        routed = _dispatch(
            flat,
            self.experts.gate_up_proj,
            self.experts.down_proj,
            self.num_experts,
            topk_idx,
            topk_weight,
        )
    else:
        routed = self._dispatch_tp(flat, topk_weight, topk_idx)

    shared = self.shared_expert(flat)
    # Routed and shared outputs are summed with no extra gate.
    return (routed + shared).reshape(input_shape)
+ """ + from mminf.utils.fused_moe.align import has_sgl_kernel + + if has_sgl_kernel(): + from mminf.utils.fused_moe import fused_experts, moe_sum_reduce_triton + + cache3 = fused_experts( + flat, + self.experts.gate_up_proj, + self.experts.down_proj, + routing_weights, + selected_experts, + reduce_results=False, + ) + self.comm_group.all_reduce(cache3) + output = torch.empty_like(flat) + moe_sum_reduce_triton(cache3, output, routed_scaling_factor=1.0) + return output + + partial = dispatch_experts_fused( + flat, + self.experts.gate_up_proj, + self.experts.down_proj, + self.experts.gate_up_proj.shape[0], + selected_experts, + routing_weights, + ) + self.comm_group.all_reduce(partial) + return partial + + +__all__ = ["LingMoeBlock", "GatedMLP"] # GatedMLP re-export for back-compat diff --git a/mminf/model/ming_omni_flash/components/positions.py b/mminf/model/ming_omni_flash/components/positions.py new file mode 100644 index 00000000..b5652413 --- /dev/null +++ b/mminf/model/ming_omni_flash/components/positions.py @@ -0,0 +1,209 @@ +"""3D MRoPE position-id helpers for Ming-flash-omni-2.0. + +Ming-flash-omni-2.0 uses partial 3D MRoPE +(`mrope_section=[8, 12, 12]`, `partial_rotary_factor=0.5`) in the +``video_rope`` layout. The cos/sin remap lives in +:class:`mminf.model.ming_omni_flash.components.rope.LingPartialMRotaryEmbedding`; +this module produces the *position-id* tensors that feed into it. + +Three helpers cover the modality-specific position layouts used by the +Thinker prefill walks: + + * :func:`get_rope_index_text` — pure-text span (sentinels included). + * :func:`get_rope_index_audio` — audio embeddings (treated as text + positions per ``modeling_bailing_moe_v2.get_rope_index``, which + only special-cases ``image_*`` / ``video_*`` tokens). + * :func:`get_rope_index_vision` — image (or video) embeddings with + grid-aware T/H/W position ids per + ``modeling_bailing_moe_v2.get_rope_index:592-647``. 
def get_rope_index_text(
    seq_len: int,
    start_pos: int | float,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    """Build 3D MRoPE position ids for a text-only span.

    Text has no spatial extent, so the three rows of the result are
    identical: ``start_pos, start_pos + 1, ..., start_pos + seq_len - 1``.

    Args:
        seq_len: number of tokens in the span.
        start_pos: position of the first token; truncated with ``int()``.
        device: device for the result.
        dtype: dtype of the position ids.

    Returns:
        A freshly allocated ``(3, seq_len)`` tensor.
    """
    offset = int(start_pos)
    row = torch.arange(seq_len, dtype=dtype, device=device) + offset
    # ``repeat`` materialises three independent copies of the row, so the
    # result is contiguous.
    return row.repeat(3, 1)
+ + Returns: + ``(3, num_audio_tokens)`` tensor, identical rows. + """ + return get_rope_index_text(num_audio_tokens, start_pos, device=device, dtype=dtype) + + +def get_rope_index_vision( + grid_thw: torch.Tensor, + start_pos: int | float, + spatial_merge_size: int, + device: torch.device | str | None = None, + second_per_grid_t: float | None = None, + tokens_per_second: int = 2, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """3D MRoPE positions for a vision span (single image or video). + + Mirrors `modeling_bailing_moe_v2.get_rope_index:625-647` for one + image: + + * Temporal: ``arange(grid_t)`` expanded across ``H*W``, optionally + scaled by ``second_per_grid_t * tokens_per_second`` + for absolute video timestamps. + * Height: ``arange(llm_grid_h)`` expanded across ``T * W``. + * Width: ``arange(llm_grid_w)`` expanded across ``T * H``. + + ``llm_grid_h = grid_h // spatial_merge_size`` (same for W). All + three components are offset by ``start_pos`` so the span fits into + the global position-id counter the caller is tracking. + + Multi-image / video frames concatenate across images by calling + this helper per image and stitching the results — see + :func:`stitch_vision_positions` (or the dispatch in + `BailingMoeV2ThinkerSubmodule.prepare_inputs`). + + Args: + grid_thw: ``(3,)`` long tensor of (T, H, W) grid sizes. + start_pos: position offset for this image's first token. + spatial_merge_size: from `VisionEncoderConfig.spatial_merge_size` + (= 2 on the released ckpt). + device: target device. + second_per_grid_t: when set, multiply the temporal component by + ``second_per_grid_t * tokens_per_second`` for absolute video + timestamps. None ⇒ raw frame index. Image inputs always pass + None; video inputs pass the per-clip frame interval. + tokens_per_second: temporal-resolution multiplier + (= 2 on the released ckpt; mirrors ``config.tokens_per_second``). + dtype: integer dtype for position ids. 
def vision_span_max_position(
    grid_thw: torch.Tensor,
    start_pos: int | float,
    spatial_merge_size: int,
    second_per_grid_t: float | None = None,
    tokens_per_second: int = 2,
) -> int:
    """Return one past the largest position id a vision span produces.

    Use it to advance the global position counter past a vision span so
    the next span knows where its positions resume.

    The temporal maximum is computed with the same float32 multiply and
    truncation that ``get_rope_index_vision`` applies to its position
    tensor. Doing the scaling in Python float64 can land on the other
    side of an integer boundary (e.g. ``100 * 0.29 * 2`` truncates to 57
    in float64 but to 58 in float32) and would under-count the span.

    Args:
        grid_thw: 1-D tensor of length 3 holding the (T, H, W) grid sizes.
        start_pos: position of the span's first token; truncated with
            ``int()``.
        spatial_merge_size: divisor applied to H and W to get the
            per-token grid.
        second_per_grid_t: when set, the temporal index is scaled by
            ``second_per_grid_t * tokens_per_second`` before truncation.
        tokens_per_second: temporal-resolution multiplier.

    Returns:
        ``int(start_pos) + max(max_t, max_h, max_w) + 1``.

    Raises:
        ValueError: if ``grid_thw`` is not a length-3 1-D tensor, or H/W
            are not divisible by ``spatial_merge_size`` (the same inputs
            ``get_rope_index_vision`` rejects).
    """
    if grid_thw.dim() != 1 or grid_thw.numel() != 3:
        raise ValueError(
            f"grid_thw must be a 1-D tensor of length 3 (T, H, W); "
            f"got shape {tuple(grid_thw.shape)}"
        )
    # One host transfer instead of three separate ``.item()`` calls.
    grid_t, grid_h, grid_w = (int(v) for v in grid_thw.tolist())
    if grid_h % spatial_merge_size != 0 or grid_w % spatial_merge_size != 0:
        raise ValueError(
            f"grid_h={grid_h} / grid_w={grid_w} not divisible by "
            f"spatial_merge_size={spatial_merge_size}."
        )

    max_h = grid_h // spatial_merge_size - 1
    max_w = grid_w // spatial_merge_size - 1

    if second_per_grid_t is not None:
        # Mirror the tensor path exactly: float32 index, two successive
        # scalar multiplies, then truncation to an integer dtype.
        scaled = (
            torch.tensor([grid_t - 1], dtype=torch.float32)
            * float(second_per_grid_t)
            * float(tokens_per_second)
        )
        max_t = int(scaled.to(torch.long).item())
    else:
        max_t = grid_t - 1

    return int(start_pos) + max(max_t, max_h, max_w) + 1
The released checkpoint stores the +weights under the top-level prefixes ``linear_proj.*`` (vision) and +``linear_proj_audio.*`` (audio): + + * Vision (mlp_depth=2): + linear_proj.0.{weight,bias} -> Linear(vision_out_hidden, llm_hidden) + [GELU at index 1, no params] + linear_proj.2.{weight,bias} -> Linear(llm_hidden, llm_hidden) + + * Audio (mlp_depth=2): + linear_proj_audio.0.{weight,bias} -> Conv1d(audio_d_model, llm_hidden, ds_kernel_size, ds_stride) + [Transpose at index 1, GELU at index 2, no params] + linear_proj_audio.3.{weight,bias} -> Linear(llm_hidden, llm_hidden) + [Transpose at index 4, no params] + +We mirror the upstream layer ordering exactly so the +``linear_proj.*`` / ``linear_proj_audio.*`` keys from the checkpoint land +on the right ``nn.Module`` slot via plain index-based lookup. +""" + +from __future__ import annotations + +import torch +from torch import nn + + +class _Transpose(nn.Module): + """Used inside ``nn.Sequential`` chains (modeling_utils.py:Transpose).""" + + def __init__(self, dim0: int, dim1: int) -> None: + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.transpose(self.dim0, self.dim1) + + +class MingVisionProjector(nn.Module): + """MLP projector: vision encoder output -> LLM hidden space. + + Args: + vision_dim: ``VisionEncoderConfig.out_hidden_size`` (4096 on the + released ckpt — the vision encoder already projects internally + via its ``PatchMerger``). + llm_dim: ``ThinkerLLMConfig.hidden_size`` (4096). + mlp_depth: ``MingFlashOmniModelConfig.mlp_depth`` (2 on the + released ckpt). depth=1 yields a single Linear; depth=N adds + (N-1) GELU+Linear pairs after it. 
class MingAudioProjector(nn.Module):
    """Strided-conv down-sampler followed by an MLP, for audio features.

    ``self.proj`` is an ``nn.Sequential`` laid out as::

        [0] Conv1d(audio_dim -> llm_dim, ds_kernel_size, ds_stride)
        [1] transpose(-1, -2)
        then (mlp_depth - 1) x [GELU, Linear(llm_dim, llm_dim)]
        [last] transpose(-1, -2)

    The slot order is significant: checkpoint keys address the conv and
    linear layers by integer index, so with ``mlp_depth=2`` the
    parametrised layers must sit at indices 0 and 3.

    Args:
        audio_dim: channel count of the incoming audio features.
        llm_dim: output channel count (LLM hidden size).
        ds_kernel_size: kernel size of the down-sampling conv.
        ds_stride: stride of the down-sampling conv.
        mlp_depth: must be >= 1; each unit above 1 adds one GELU+Linear
            pair after the conv.

    Raises:
        ValueError: if ``mlp_depth < 1``.
    """

    def __init__(
        self,
        audio_dim: int,
        llm_dim: int,
        ds_kernel_size: int = 3,
        ds_stride: int = 2,
        mlp_depth: int = 2,
    ) -> None:
        super().__init__()
        if mlp_depth < 1:
            raise ValueError(f"mlp_depth must be >= 1, got {mlp_depth}")
        self.ds_kernel_size = ds_kernel_size
        self.ds_stride = ds_stride
        self.audio_dim = audio_dim
        self.llm_dim = llm_dim

        downsample = nn.Conv1d(
            audio_dim,
            llm_dim,
            kernel_size=ds_kernel_size,
            stride=ds_stride,
            padding=ds_kernel_size // 2,
        )
        # The conv emits channels-first (..., llm_dim, T'); the Linear
        # layers act on the last dim, so flip to channels-last first.
        stages: list[nn.Module] = [downsample, _Transpose(-1, -2)]
        for _ in range(mlp_depth - 1):
            stages.extend((nn.GELU(), nn.Linear(llm_dim, llm_dim)))
        # Flip back so the module's output is channels-first again.
        stages.append(_Transpose(-1, -2))
        self.proj = nn.Sequential(*stages)

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Project channels-last audio features.

        Args:
            audio_embeds: ``(B, T, audio_dim)`` features.

        Returns:
            ``(B, llm_dim, T')`` tensor, where
            ``T' = (T - ds_kernel_size + 2*(ds_kernel_size//2)) // ds_stride + 1``.
        """
        # Conv1d consumes channels-first input.
        channels_first = audio_embeds.transpose(-1, -2)
        return self.proj(channels_first)

    def compute_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Map an input length to the sequence length after this projector.

        Applies the conv length formula
        ``(L - kernel + 2 * pad) // stride + 1`` twice: first with fixed
        ``kernel=3, pad=1, stride=2`` (described in the original comments
        as the Whisper encoder's strided stem — confirm against the
        encoder in use), then with this projector's own
        ``ds_kernel_size`` / ``ds_kernel_size // 2`` / ``ds_stride``.
        """
        conv_specs = (
            (3, 1, 2),
            (self.ds_kernel_size, self.ds_kernel_size // 2, self.ds_stride),
        )
        length = input_length
        for kernel, pad, stride in conv_specs:
            length = (length - kernel + 2 * pad) // stride + 1
        return length
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return the neox-style rotary half-rotation ``[-x2, x1]`` of ``x``.

    ``x1`` and ``x2`` are the first and second halves of the last dimension.
    """
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def _build_inv_freq(rotary_dim: int, theta: float) -> torch.Tensor:
    """Return the float32 rotary inverse-frequency table.

    Entry ``i`` is ``theta ** (-2i / rotary_dim)`` for ``i`` in
    ``[0, rotary_dim / 2)``.
    """
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    return 1.0 / (theta ** exponents)


class LingPartialMRotaryEmbedding(nn.Module):
    """Partial rotary + ``video_rope`` 3D MRoPE.

    Args:
        head_dim: full head dim of the attention layer.
        partial_rotary_factor: fraction of head_dim that's actually rotated
            (the rest is concatenated pass-through). The model uses 0.5;
            head_dim=128 → rotary_dim=64.
        mrope_section: per-axis cos/sin section sizes ``[Nt, Nh, Nw]``
            (time, height, width). Released ckpt: ``[8, 12, 12]``. The
            three sizes must sum to ``rotary_dim // 2``.
        rope_theta: rotary base frequency. Released ckpt: ``2_400_000``.
        max_position_embeddings: max sequence length. Only stored on the
            module; no cos/sin table is precomputed from it.

    Raises:
        ValueError: if ``rotary_dim`` is odd, ``mrope_section`` is not
            length 3, or its entries do not sum to ``rotary_dim // 2``.

    The forward expects ``position_ids`` of shape ``(3, num_tokens)`` for
    3D positions or ``(num_tokens,)`` for plain 1D rope (degenerates to
    standard rotary).
    """

    def __init__(
        self,
        head_dim: int,
        partial_rotary_factor: float,
        mrope_section: list[int],
        rope_theta: float,
        max_position_embeddings: int,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.rotary_dim = int(head_dim * partial_rotary_factor)
        if self.rotary_dim % 2 != 0:
            raise ValueError(
                f"rotary_dim must be even (got {self.rotary_dim}); check "
                f"partial_rotary_factor."
            )
        self.mrope_section = list(mrope_section)
        # Check the length first so a malformed section reports the real
        # problem instead of a (possibly coincidental) sum mismatch.
        if len(self.mrope_section) != 3:
            raise ValueError(
                f"mrope_section must be length-3 [Nt, Nh, Nw]; got {self.mrope_section}"
            )
        if sum(self.mrope_section) != self.rotary_dim // 2:
            raise ValueError(
                f"sum(mrope_section)={sum(self.mrope_section)} must equal "
                f"rotary_dim//2={self.rotary_dim // 2}"
            )
        # Number of frequency pairs that carry spatial (H/W) positions.
        self.hw_size = self.mrope_section[1] + self.mrope_section[2]

        self.rope_theta = float(rope_theta)
        self.max_position_embeddings = int(max_position_embeddings)

        # Only inv_freq is stored. cos/sin are derived from position_ids on
        # every call, so no max_position_embeddings * rotary_dim table is
        # ever materialised.
        self.register_buffer(
            "inv_freq",
            _build_inv_freq(self.rotary_dim, self.rope_theta),
            persistent=False,
        )

    # ------------------------------------------------------------------
    # cos / sin computation
    # ------------------------------------------------------------------

    def _compute_cos_sin(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin for ``position_ids``.

        ``position_ids`` is ``(num_tokens,)`` or ``(3, num_tokens)``.
        Returns ``cos, sin`` of shape ``(num_tokens, rotary_dim)``; for 3D
        input they are in the video_rope layout (H/W interleaved spatial
        + T tail).

        Raises:
            ValueError: for any other ``position_ids`` shape.
        """
        if position_ids.dim() == 1:
            return self._cos_sin_1d(position_ids)
        if position_ids.dim() != 2 or position_ids.shape[0] != 3:
            raise ValueError(
                f"position_ids must be (num_tokens,) or (3, num_tokens); "
                f"got shape {tuple(position_ids.shape)}"
            )
        return self._cos_sin_3d_video_rope(position_ids)

    def _cos_sin_1d(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Standard 1D rotary cos/sin — used for pure-text positions."""
        # (num_tokens, rotary_dim/2)
        freqs = position_ids.float().unsqueeze(-1) * self.inv_freq.unsqueeze(0)
        # (num_tokens, rotary_dim) — neox style: cat freqs with themselves
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def _cos_sin_3d_video_rope(
        self, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """3D positions → video_rope layout.

        position_ids: ``(3, num_tokens)`` — row 0 = time, row 1 = height,
        row 2 = width.

        Steps:
          1. Compute per-axis freqs: ``(3, num_tokens, rotary_dim/2)``.
          2. Form (cos, sin) of shape ``(3, num_tokens, rotary_dim)`` neox-style.
          3. Remap each rotary_dim/2 frequency-pair index ``i`` into:
               - i < hw_size  → H if i even, W if i odd
               - i ≥ hw_size  → T
             Pairs ``(cos[i], cos[i + rotary_dim/2])`` correspond to the
             same frequency, so the same row assignment applies to both
             halves.
        """
        # (3, num_tokens, rotary_dim/2)
        freqs = position_ids.float().unsqueeze(-1) * self.inv_freq.view(1, 1, -1)
        # (3, num_tokens, rotary_dim) — neox cat, built once and shared by
        # cos and sin (same as the 1D path).
        emb = torch.cat((freqs, freqs), dim=-1)
        return self._remap_video_rope(emb.cos(), emb.sin())

    def _remap_video_rope(
        self, cos_3d: torch.Tensor, sin_3d: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Remap per-axis cos/sin into the video_rope 2D layout.

        cos_3d, sin_3d: ``(3, num_tokens, rotary_dim)`` with axis order
        (time, height, width) on dim 0.
        Returns: ``(cos, sin)``, each ``(num_tokens, rotary_dim)``.

        Operates on the *full* rotary_dim tables: both halves of the neox
        cat hold the same frequencies, so the same axis assignment is
        applied at offset 0 and at offset ``rotary_dim // 2``.
        """
        half = self.rotary_dim // 2

        # Every column is written below because hw_size + Nt == half.
        result_cos = torch.empty_like(cos_3d[0])
        result_sin = torch.empty_like(sin_3d[0])

        for offset in (0, half):
            spatial_end = offset + self.hw_size
            h_cols = slice(offset, spatial_end, 2)        # even pairs → height
            w_cols = slice(offset + 1, spatial_end, 2)    # odd pairs → width
            t_cols = slice(spatial_end, offset + half)    # tail → time
            for dst, src in ((result_cos, cos_3d), (result_sin, sin_3d)):
                dst[:, h_cols] = src[1, :, h_cols]
                dst[:, w_cols] = src[2, :, w_cols]
                dst[:, t_cols] = src[0, :, t_cols]
        return result_cos, result_sin

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate the first ``rotary_dim`` dims of q and k.

        The inputs are not modified; new tensors are returned.

        Args:
            q, k: ``(..., num_tokens, head_dim)``. The token axis must be
                the second-to-last axis (e.g. ``(heads, num_tokens,
                head_dim)``), because cos/sin of shape ``(num_tokens,
                rotary_dim)`` are broadcast against the trailing two
                axes. A ``(num_tokens, num_heads, head_dim)`` tensor must
                be transposed by the caller first.
            position_ids: ``(num_tokens,)`` for 1D rope or
                ``(3, num_tokens)`` for video_rope.

        Returns:
            ``(q, k)`` with rotation applied to the leading ``rotary_dim``
            channels and the remaining channels passed through unchanged.

        Raises:
            ValueError: if the last dim of q or k is not ``head_dim``, or
                ``position_ids`` has an unsupported shape.
        """
        if q.shape[-1] != self.head_dim or k.shape[-1] != self.head_dim:
            raise ValueError(
                f"q/k last dim {q.shape[-1]}/{k.shape[-1]} != "
                f"head_dim {self.head_dim}"
            )

        cos, sin = self._compute_cos_sin(position_ids)
        # cos/sin start as (T, rotary_dim); prepend singleton axes until
        # they line up with q's trailing (T, rotary_dim) axes.
        while cos.dim() < q.dim():
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)

        q_rot, q_pass = q[..., : self.rotary_dim], q[..., self.rotary_dim :]
        k_rot, k_pass = k[..., : self.rotary_dim], k[..., self.rotary_dim :]
        cos_q = cos.to(q.dtype)
        sin_q = sin.to(q.dtype)
        cos_k = cos.to(k.dtype)
        sin_k = sin.to(k.dtype)

        q_rot = (q_rot * cos_q) + (_rotate_half(q_rot) * sin_q)
        k_rot = (k_rot * cos_k) + (_rotate_half(k_rot) * sin_k)
        return (
            torch.cat([q_rot, q_pass], dim=-1),
            torch.cat([k_rot, k_pass], dim=-1),
        )
class LingMoeRouter(nn.Module):
    """Ling-2.0 ``MultiRouter`` (group-limited top-k with sigmoid + bias).

    Args:
        hidden_size: input hidden dimension.
        num_experts: total routed experts. Must divide evenly by ``n_group``.
        num_experts_per_tok: top-k experts selected per token. Must not
            exceed ``topk_group * (num_experts // n_group)``, i.e. the
            number of experts reachable inside the selected groups.
        n_group: expert groups; the experts are split contiguously by
            ``num_experts // n_group``.
        topk_group: how many groups a single token may route into.
        routed_scaling_factor: post-renormalisation scale applied to the
            top-k weights.

    Raises:
        ValueError: if ``num_experts`` is not divisible by ``n_group``,
            ``topk_group`` exceeds ``n_group``, or ``num_experts_per_tok``
            exceeds the number of experts in ``topk_group`` groups.

    The gate is exposed as a plain bias-free ``nn.Linear``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        super().__init__()
        if num_experts % n_group != 0:
            raise ValueError(
                f"num_experts={num_experts} must be divisible by n_group={n_group}"
            )
        if topk_group > n_group:
            raise ValueError(
                f"topk_group={topk_group} cannot exceed n_group={n_group}"
            )
        experts_per_group = num_experts // n_group
        # Without this guard the final top-k would be forced to pick
        # -inf-masked experts from excluded groups, silently violating the
        # group limit.
        if num_experts_per_tok > topk_group * experts_per_group:
            raise ValueError(
                f"num_experts_per_tok={num_experts_per_tok} cannot exceed "
                f"topk_group * experts_per_group="
                f"{topk_group * experts_per_group}"
            )
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.experts_per_group = experts_per_group
        self.routed_scaling_factor = routed_scaling_factor

        # Gate projection (no bias).
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

        # Expert bias — not gradient-trained, but stored as a parameter so
        # state_dict loaders see it.
        self.expert_bias = nn.Parameter(
            torch.zeros(num_experts), requires_grad=False,
        )

    def _group_limited_topk(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick the top-k experts under the ``topk_group``-best-groups constraint.

        Args:
            scores: ``(num_tokens, num_experts)`` routing scores (the
                caller passes sigmoid + bias).

        Returns:
            ``(num_tokens, top_k)`` int64 expert indices (unsorted).

        Per-group score = sum of that group's top-2 expert scores (or the
        single expert score when a group holds only one expert). The
        ``topk_group`` groups with the highest per-group scores are kept;
        the rest are masked out before the final top-k.
        """
        num_tokens = scores.size(0)
        # (N, n_group, experts_per_group)
        grouped = scores.view(num_tokens, self.n_group, self.experts_per_group)
        # topk(2) would raise on a size-1 axis, so clamp k to the group size.
        per_group_k = min(2, self.experts_per_group)
        group_scores = grouped.topk(per_group_k, dim=-1)[0].sum(dim=-1)
        # Pick the topk_group best groups.
        group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1.0)
        # Broadcast group mask back across experts_per_group.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.experts_per_group)
            .reshape(num_tokens, -1)
        )
        # Mask un-selected groups' experts to -inf so they can't be picked.
        masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
        return torch.topk(masked, k=self.top_k, dim=-1, sorted=False)[1]

    def forward(
        self, hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Args:
            hidden_states: ``(..., hidden_size)``. Flattened internally to
                ``(N, hidden_size)``.

        Returns:
            A 3-tuple:
              - ``router_logits``: ``(N, num_experts)`` raw gate logits
                (pre-sigmoid), cast to float32.
              - ``routing_weights``: ``(N, top_k)`` weights for the chosen
                experts, taken from the un-biased sigmoid scores,
                renormalised to sum to 1 when ``top_k > 1`` and then
                multiplied by ``routed_scaling_factor``.
              - ``selected_experts``: ``(N, top_k)`` int64 expert indices.
        """
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        # float() cast keeps the sigmoid / top-k math in float32.
        logits = F.linear(hidden_states, self.gate.weight).float()
        # Per-expert sigmoid (NOT softmax). The bias only influences which
        # experts are selected; the weights below come from the un-biased
        # sigmoid scores.
        sigmoid_scores = torch.sigmoid(logits)
        scored_for_routing = sigmoid_scores + self.expert_bias

        selected_experts = self._group_limited_topk(scored_for_routing)
        chosen_scores = torch.gather(
            sigmoid_scores, dim=1, index=selected_experts,
        ).to(logits.dtype)
        if self.top_k > 1:
            # Tiny epsilon guards against a zero denominator.
            chosen_scores = chosen_scores / (
                chosen_scores.sum(dim=-1, keepdim=True) + 1e-20
            )
        routing_weights = chosen_scores * self.routed_scaling_factor

        return logits, routing_weights, selected_experts
def _prepend_sys_path(entry: str) -> None:
    """Put ``entry`` at the front of ``sys.path`` unless it is already listed."""
    if entry not in sys.path:
        sys.path.insert(0, entry)


def _import_ming_vit(local_dir: str | None = None) -> type[nn.Module]:
    """Import ``qwen3_moe_vit`` and return its ``Qwen3MoeVisionTransformer``.

    Args:
        local_dir: Optional directory to place on ``sys.path`` before the
            import. If ``local_dir/qwen3_moe_vit.py`` is a symlink, the
            directory holding the symlink target is added as well, so
            sibling modules next to the real file can be imported too.

    Returns:
        The ``Qwen3MoeVisionTransformer`` attribute of the imported module.

    Raises:
        ImportError: if ``qwen3_moe_vit`` cannot be imported; the original
            error is chained as the cause.
    """
    if local_dir is not None:
        _prepend_sys_path(str(local_dir))
        staged_file = Path(local_dir) / "qwen3_moe_vit.py"
        if staged_file.is_symlink():
            # Follow the staged symlink back to the real source tree.
            _prepend_sys_path(str(staged_file.resolve().parent))

    try:
        vit_module = importlib.import_module("qwen3_moe_vit")
    except ImportError as e:
        raise ImportError(
            "Could not import qwen3_moe_vit. Ensure MingFlashOmniModel "
            "was constructed (which stages the Ming source files), or "
            "pass local_dir= explicitly. See "
            "PORTING_NOTES.md 'Ming source dependency' for setup."
        ) from e

    return vit_module.Qwen3MoeVisionTransformer


def build_vision_encoder(
    config: VisionEncoderConfig,
    use_deepstack: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device = "cpu",
    attn_implementation: str = "flash_attention_2",
    local_dir: str | None = None,
) -> nn.Module:
    """Construct the Ming vision encoder in eval mode.

    Args:
        config: Source of the encoder hyper-parameters; its fields are
            copied onto the ``Qwen3VLMoeVisionConfig`` defined by the
            imported ``qwen3_moe_vit`` module.
        use_deepstack: Forwarded to the encoder constructor.
        dtype: Dtype the encoder is cast to after construction.
        device: Device the encoder is moved to after construction.
        attn_implementation: Stored on the internal config as
            ``_attn_implementation``.
        local_dir: Forwarded to ``_import_ming_vit`` so the Ming source
            modules can be found on ``sys.path``.

    Returns:
        The constructed encoder, cast/moved and switched to ``eval()``.
        No weights are loaded here; that is left to the caller.
    """
    encoder_cls = _import_ming_vit(local_dir=local_dir)

    # The internal config class lives in the module that was just imported.
    vit_module = sys.modules["qwen3_moe_vit"]
    config_cls = vit_module.Qwen3VLMoeVisionConfig
    config_kwargs = dict(
        depth=config.depth,
        hidden_size=config.hidden_size,
        hidden_act=config.hidden_act,
        intermediate_size=config.intermediate_size,
        num_heads=config.num_heads,
        in_channels=config.in_channels,
        patch_size=config.patch_size,
        spatial_merge_size=config.spatial_merge_size,
        temporal_patch_size=config.temporal_patch_size,
        out_hidden_size=config.out_hidden_size,
        num_position_embeddings=config.num_position_embeddings,
        deepstack_visual_indexes=list(config.deepstack_visual_indexes),
    )
    internal_config = config_cls(**config_kwargs)
    # Set on the config as well, for debug runs that want a different
    # attention backend.
    internal_config._attn_implementation = attn_implementation

    encoder = encoder_cls(internal_config, use_deepstack=use_deepstack)
    encoder = encoder.to(dtype=dtype, device=device)
    encoder.eval()
    return encoder


__all__ = ["build_vision_encoder"]
On disk only the +``BailingMM2Config`` shape lives at ``config.json``:: + + config.json # thinker: audio_config + llm_config + vision_config + scalars + talker/config.json # talker top-level (BailingTalker2) + talker/llm/config.json # talker LLM backbone (Qwen2) + talker/vae/config.json # talker AudioVAE + transformer/config.json # image-gen DiT (ZImageTransformer2DModel) + vae/config.json # image-gen VAE + scheduler/scheduler_config.json # image-gen diffusion scheduler + byt5/google__byt5-smal/config.json # image-gen text encoder + connector/config.json # image-gen connector + mlp/config.json # image-gen projector + +This loader follows the on-disk layout: it parses ``config.json`` for the +thinker path and lazy-loads talker / image-gen from sibling subdirs when +those exist. Talker and image-gen are SKELETON dataclasses today — exhaustive +field semantics land with the talker port (step 6 of PORTING_NOTES.md) and +the image-gen port (step 9). +""" + +from __future__ import annotations + +import json +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Thinker LLM (Ling-2.0 sparse MoE — model_type "bailing_moe_v2") +# --------------------------------------------------------------------------- + +@dataclass +class ThinkerLLMConfig: + """Ling-2.0 sparse-MoE thinker (BailingMoeV2). + + Field set is the union of what upstream + ``vllm_omni/transformers_utils/configs/ming_flash_omni.py:BailingMoeV2Config`` + declares and what the released ``llm_config`` actually populates. + Defaults reflect the released ckpt, not the upstream class defaults + (which were trained for a smaller config). 
+ """ + + # Dims + vocab_size: int = 157184 + hidden_size: int = 4096 + intermediate_size: int = 9216 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int = 4 + head_dim: int | None = None # computed in __post_init__ + + # Norm / activation + hidden_act: str = "silu" + rms_norm_eps: float = 1e-6 + use_qk_norm: bool = True + use_qkv_bias: bool = False + use_bias: bool = False + tie_word_embeddings: bool = False + + # Position / RoPE + max_position_embeddings: int = 32768 + rope_theta: float = 2_400_000.0 + rope_scaling: dict[str, Any] | None = None + partial_rotary_factor: float = 0.5 + + # MoE + num_experts: int = 256 + num_shared_experts: int = 1 + num_experts_per_tok: int = 8 + moe_intermediate_size: int = 1024 + first_k_dense_replace: int = 1 + router_type: str = "MultiRouter" + n_group: int = 8 + topk_group: int = 4 + moe_router_topk_scaling_factor: float = 2.5 + norm_topk_prob: bool = True + use_expert_bias: bool = True + output_router_logits: bool = False + + # Misc + pad_token_id: int = 156892 + eos_token_id: int = 156895 + use_interleaved_frame_timestamp: bool = True + + # Multimodal token IDs (used by the prefill processor / chat template). + # Defaults mirror the actual tokenizer (`tokenizer.json` added_tokens at + # the released ckpt; cross-checked against Jonathan1909's patched config + # and vllm-omni's BailingMoeV2Config defaults). Two gotchas the on-disk + # `config.json` of `inclusionAI/Ming-flash-omni-2.0` introduces: + # * `video_start_token` is mislabeled as 157159 (= ) in the + # ckpt config; the real `