From 5dea98c15a6b1bc2d2dee3ca68fb9e31c408fbb2 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 1 Jun 2026 22:59:19 -0700 Subject: [PATCH] [gemma4_31b] Branch 1/4: vision tower + vision quant + Python eager vision (g4-vision-quant) First branch of the g4-vision decomposition. Adds the vision tower, vision quantization, and the Python eager (CUDA) vision inference path, plus tests. The ExecuTorch export + C++ runtime stay TEXT-ONLY with the exact same exported contract as `main` (token-input `prefill` + `decode`), so `main.cpp` and `CMakeLists.txt` are byte-identical to `main` and the existing runner keeps working unchanged. CUDA/MLX/GGUF vision are added in the following branches. Scope (built on top of `main`) ------------------------------ Vision model + quant + Python eager: - vision_tower.py: full Gemma 4 vision tower (patch embedder, 2D-RoPE encoder, pooler, multimodal embedder, Gemma4_31BVisionTower wrapper, HF key maps). - model.py: multimodal-by-default model (vision_tower / embed_vision attached), embeds-based `forward(inputs_embeds, ...)`, `embed_text`, `decode_forward`, `_run_blocks`, vision RoPE buffer materialization, vision-aware HF loading. - pack_vision.py: backend-agnostic INT8 vision position-table quant + patch embedder packer; quant/pack.py buffer-registration tweak for the PE int8 buffers; quantize_and_save.py made vision-aware. - inference.py: Python eager vision path (`generate_with_image`, `_build_vision_encoder`, `--image-path`, `--max-vision-soft-tokens`); text loop uses `decode_forward` (forward now takes embeds). Text-only ET export (main-compatible contract via a fake-prefill wrapper): - export.py: exports the SAME two methods as `main` -- token-input `prefill` (T>=2, dynamic) and `decode` (T=1). The token-input `prefill` is realized with a temporary bound-method that fuses `embed_text` + the embeds-`forward` into one method (mirroring the MLX `_mlx_model_forward(tokens) = prefill(embed_text(tokens))` pattern); `decode` maps to `decode_forward`. The vision head is dropped (`_drop_vision_head`) before lowering, so the exported .pte method names + signatures are identical to `main`. Loading stays vision-aware (the prequant checkpoint carries vision weights). No `embed_text` / `vision_encoder` method, no `--max-vision-soft-tokens`. - main.cpp / CMakeLists.txt: unchanged from `main` (verified byte-identical). Refactors landed at the model/quant layer: - R1: vision_tower.py reuses the shared `Gemma4MLP` and `rotate_half` from examples/models/gemma4/text_decoder (drops the duplicate `Gemma4VisionMLP` and local `_rotate_half`; submodule names gate/up/down_proj unchanged so HF key maps stay valid). - R3: drops the `RMSNorm` / `RMSNormNoWeight` stopgap from text_decoder/__init__.py; vision_tower.py uses `nn.RMSNorm(...)` and `nn.RMSNorm(..., elementwise_affine=False)` inline. text_decoder/__init__.py is now byte-identical to `main`. - R2 (Python half): the 5 chat-template special-token IDs + `build_vision_input_ids` are extracted into a shared module examples/models/gemma4/chat_template.py; inference.py imports from it. The C++ header half lands in branch 2. Docs (review finding #3): model.py `forward` and vision_tower.py state the scaling convention explicitly -- `embed_text` scales text rows by sqrt(hidden_size); the vision-tower output is NOT pre-scaled (matches HF). Test fixes: - test_cuda_pipeline.py: updated for the new model contract (review finding #2) -- T=1 calls use `decode_forward`; multi-token prefill uses `forward(embed_text(tokens), ...)`. Fixes the previously-failing TestInt4Inference::test_inference_produces_valid_output. - test_pipeline.py: added vision-aware quant/save/load coverage + TestVisionConfigRequired. Also fixed a pre-existing fixture bug in build_hf_checkpoint -- vision encoder-layer attn/mlp projections are now written with the HF `.linear.` segment (Gemma4ClippableLinear) so they match hf_vision_per_layer_key_map(); without this the --model-dir export test silently skipped 7 vision projection weights. - New: tests/test_vision_tower.py, tests/test_vision_quant_roundtrip.py. Verification ------------ - CPU: pytest test_vision_tower.py test_vision_quant_roundtrip.py test_pipeline.py -> 20 passed. - CUDA (2x A100): pytest test_cuda_pipeline.py -> 8 passed, incl. both export tests (text-only prefill+decode .pte/.ptd produced and serialized) and the chunked-prefill / int4 inference contract tests. - flake8 + ufmt clean on all changed/new files. - MLX export+runtime e2e and the Python eager --image-path smoke are to be run by the user (no MLX hardware here; image checkpoint user-provided). --- examples/models/gemma4/chat_template.py | 82 ++ examples/models/gemma4_31b/export.py | 154 +++- examples/models/gemma4_31b/inference.py | 323 +++++++- examples/models/gemma4_31b/model.py | 205 ++++- examples/models/gemma4_31b/pack_vision.py | 318 ++++++++ examples/models/gemma4_31b/quant/pack.py | 17 +- .../models/gemma4_31b/quantize_and_save.py | 45 +- .../gemma4_31b/tests/test_cuda_pipeline.py | 31 +- .../models/gemma4_31b/tests/test_pipeline.py | 87 +- .../tests/test_vision_quant_roundtrip.py | 298 +++++++ .../gemma4_31b/tests/test_vision_tower.py | 233 ++++++ examples/models/gemma4_31b/vision_tower.py | 759 ++++++++++++++++++ 12 files changed, 2448 insertions(+), 104 deletions(-) create mode 100644 examples/models/gemma4/chat_template.py create mode 100644 examples/models/gemma4_31b/pack_vision.py create mode 100644 examples/models/gemma4_31b/tests/test_vision_quant_roundtrip.py create mode 100644 examples/models/gemma4_31b/tests/test_vision_tower.py create mode 100644 examples/models/gemma4_31b/vision_tower.py diff --git a/examples/models/gemma4/chat_template.py b/examples/models/gemma4/chat_template.py new file mode 100644 index 00000000000..f8d944c4937 --- /dev/null +++ b/examples/models/gemma4/chat_template.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared Gemma 4 chat-template special-token IDs and the image+text input builder. + +Single source of truth for the multimodal chat-template token layout. The +Python eager runner (``gemma4_31b/inference.py``) imports from here, and the C++ +runner mirrors these exact values in ``gemma4/runner/chat_template.h`` so the two +implementations can never drift. + +Image+text turn layout (matches the Gemma 4 HF chat template): + + user\n*N{prompt}\n + model\n +""" + +# Gemma 4 special token IDs (match the tokenizer + the C++ runner constants). +BOS_ID = 2 +TURN_START_ID = 105 # +TURN_END_ID = 106 # +BOI_TOKEN_ID = 255999 # +IMAGE_TOKEN_ID = 258880 # soft-token placeholder +EOI_TOKEN_ID = 258882 # + + +def build_vision_input_ids( + tokenizer, + prompt: str, + num_vision_tokens: int, + bos_id: int = BOS_ID, +) -> list[int]: + """Build the chat-template token sequence for an image+text turn. + + Produces the same layout the C++ runner builds in + ``gemma4/runner/chat_template.h::build_vision_input_ids``: + + user\\n*N{prompt}\\n + model\\n + + Args: + tokenizer: a ``tokenizers.Tokenizer``-like object exposing + ``encode(str).ids``. + prompt: the user text prompt. + num_vision_tokens: number of ```` soft-token placeholders to + insert (one per valid vision soft token). + bos_id: beginning-of-sequence id (defaults to the Gemma 4 BOS). + + Returns: + The flat list of token IDs for the turn. + """ + user_tokens = tokenizer.encode("user\n").ids + prompt_tokens = tokenizer.encode(prompt).ids + newline_tokens = tokenizer.encode("\n").ids + model_tokens = tokenizer.encode("model\n").ids + + ids: list[int] = [] + ids.append(bos_id) + ids.append(TURN_START_ID) + ids.extend(user_tokens) + ids.append(BOI_TOKEN_ID) + ids.extend([IMAGE_TOKEN_ID] * num_vision_tokens) + ids.append(EOI_TOKEN_ID) + ids.extend(prompt_tokens) + ids.append(TURN_END_ID) + ids.extend(newline_tokens) + ids.append(TURN_START_ID) + ids.extend(model_tokens) + return ids + + +__all__ = [ + "BOS_ID", + "TURN_START_ID", + "TURN_END_ID", + "BOI_TOKEN_ID", + "IMAGE_TOKEN_ID", + "EOI_TOKEN_ID", + "build_vision_input_ids", +] diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index ed3dcdba9c3..0b7912a233b 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -4,15 +4,41 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd). +"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd) — text-only contract. -Two methods are exported and lowered together so they share KV-cache buffers: - - "decode": T=1, static shape, returns the next sampled token. - - "prefill": T>=2, dynamic shape, returns the next sampled token. +Two methods are exported and lowered together so they share KV-cache buffers, +EXACTLY matching the pre-vision ``main`` contract (so the C++ runner and +CMake build need no changes): + + - "decode": (tokens [1,1] i64, input_pos [1] i64, temperature [1] f32) + -> sampled [1,1] f32. T=1, static shape. + - "prefill": (tokens [1,T] i64, input_pos [T] i64, temperature [1] f32) + -> sampled [1,1] f32. T>=2, dynamic shape. + +Vision is loaded but NOT exported here +===================================== + +Gemma 4 31B is multimodal-by-default: ``model.py`` always constructs the +``vision_tower`` / ``embed_vision`` submodules and the quantized checkpoint +carries vision weights, so loading + quantization + packing are all +vision-aware (INT8 vision position table, vision patch-embedder packer). + +The exported ``.pte`` however is **text-only**: the vision head is dropped just +before lowering (``_drop_vision_head``) and the two exported methods use +TOKEN inputs, identical in name and signature to ``main``. The token-input +``prefill`` is realized with a temporary bound-method that fuses +``embed_text`` + the embeddings-``forward`` into one method — mirroring the MLX +``_mlx_model_forward(tokens) = prefill(embed_text(tokens))`` pattern. ``decode`` +maps to the existing token-input ``decode_forward``. + +The embeddings-based 4-method vision contract (``embed_text``, +``vision_encoder``, embeds-``prefill``, ``decode``) is introduced on top of this +in the ``g4-vision-cuda`` branch. Three input paths: --prequantized Load a quantized checkpoint (from quantize_and_save.py) and pack for the target backend. No re-quantization. + This is the primary path (checkpoint includes vision). --gguf Load a GGUF file (e.g., Q4_K_M from the community). --model-dir Load bf16 checkpoint, quantize, pack, and export in one shot. @@ -34,6 +60,11 @@ Gemma4_31BConfig, materialize_runtime_buffers, ) +from executorch.examples.models.gemma4_31b.pack_vision import ( + pack_vision_patch_embedder, + quantize_vision_position_table, +) +from executorch.examples.models.gemma4_31b.vision_tower import Gemma4VisionPatchEmbedder # --------------------------------------------------------------------------- @@ -45,7 +76,12 @@ def load_prequantized_model( max_seq_len: int = 4096, backend: str = "cuda", ) -> tuple[Gemma4_31B, Gemma4_31BConfig]: - """Load a quantized checkpoint and pack for the target backend.""" + """Load a quantized checkpoint and pack for the target backend. + + The checkpoint contains vision keys (Gemma 4 31B is multimodal-by-default) + so packing installs the vision patch-embedder packer too. The vision head + is dropped at export time — see ``_drop_vision_head``. + """ config = Gemma4_31BConfig.from_hf_config( os.path.join(prequantized_dir, "config.json") ) @@ -84,6 +120,9 @@ def load_and_quantize( model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) print(f"Quantizing with recipe '{recipe_name}'...") + # Pre-quantize the vision position-embedding table to INT8 before the INT4 + # text recipe runs (the PE table is a large [2, N, D] buffer, not a Linear). + quantize_vision_position_table(model.vision_tower) state_dict = quantize_model(model, recipe, verbose=True) print(f"Packing for {backend}...") @@ -107,31 +146,84 @@ def _get_packers(backend: str) -> dict: if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS - return DEFAULT_CUDA_PACKERS + return { + **DEFAULT_CUDA_PACKERS, + Gemma4VisionPatchEmbedder: pack_vision_patch_embedder, + } if backend == "mlx": from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS - return DEFAULT_MLX_PACKERS + return { + **DEFAULT_MLX_PACKERS, + Gemma4VisionPatchEmbedder: pack_vision_patch_embedder, + } raise ValueError( f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." ) def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: + packers = _get_packers(backend) if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda - load_and_pack_for_cuda(path, model) + load_and_pack_for_cuda(path, model, packers=packers) elif backend == "mlx": from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_mlx - load_and_pack_for_mlx(path, model) + load_and_pack_for_mlx(path, model, packers=packers) else: raise ValueError( f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." ) +# --------------------------------------------------------------------------- +# Text-only contract helpers +# +# The 2 exported methods live on the SAME Gemma4_31B instance. torch.export +# only takes an nn.Module (not a bound method), so we temporarily shadow +# ``model.forward`` with the per-method bound method for the duration of the +# export call. Critically the model's CLASS identity does NOT change, so all +# ExportedPrograms share identical mutable-buffer FQNs +# (``layers.X.self_attn.kv_cache.k_cache`` ...). This is what lets +# ``to_executorch(share_mutable_buffers=True)`` unify the prefill / decode +# KV-cache buffers under ONE runtime tensor. + + +class _BoundMethodForward: + """Context manager: temporarily set ``model.forward`` to a bound method.""" + + def __init__(self, model: nn.Module, bound_method) -> None: + self._model = model + self._bound = bound_method + + def __enter__(self): + self._model.forward = self._bound # instance-attr; shadows the class method + return self._model + + def __exit__(self, *exc): + # ``del`` restores the class method via the descriptor protocol. + try: + del self._model.forward + except AttributeError: + pass + return False + + +def _drop_vision_head(model: nn.Module) -> None: + """Detach the multimodal head before text-only lowering. + + Gemma 4 31B is multimodal-by-default, so ``vision_tower`` / ``embed_vision`` + are always constructed and loaded. The text-only ``.pte`` does not ship + them, so we delete them here BEFORE ``materialize_runtime_buffers`` (which + only rebuilds the vision RoPE table when ``vision_tower`` is still attached). + """ + for name in ("vision_tower", "embed_vision"): + if hasattr(model, name): + delattr(model, name) + + # --------------------------------------------------------------------------- # Export + lower @@ -162,6 +254,7 @@ def export_and_lower( def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: import gc + import types import torch._inductor.config as inductor_config @@ -182,17 +275,30 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - # Register Int4Tensor dispatch → executorch_cuda::int4_plain_mm shim import executorch.backends.cuda.int4_dispatch # noqa: F401 + # Text-only contract: drop the multimodal head, then materialize buffers + # (materialize skips the vision RoPE table once vision_tower is gone). + _drop_vision_head(model) materialize_runtime_buffers(model, dtype=torch.bfloat16) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). # Both decode and prefill share the same nibble-packed weights. - # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M). max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + + # ---------- prefill (tokens, T>=2, dynamic) ---------- + # ``main``'s contract is a TOKEN-input prefill, but the model's ``forward`` + # takes inputs_embeds. Realize the token-input prefill with a temporary + # bound-method that fuses embed_text + the embeds-forward into one method + # (mirroring MLX ``_mlx_model_forward(tokens) = prefill(embed_text(tokens))``). + def _text_prefill(self, tokens, input_pos, temperature): + return self._run_blocks(self.embed_text(tokens), input_pos, temperature) + seq_dim = Dim("seq_len", min=5, max=max_prefill) - print(f"Exporting prefill (T in [2, {max_prefill}])...") - with torch.no_grad(): + print(f"Exporting prefill (T in [2, {max_prefill}], tokens input)...") + with _BoundMethodForward( + model, types.MethodType(_text_prefill, model) + ), torch.no_grad(): prefill_ep = export( model, ( @@ -204,9 +310,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - strict=True, ) - # Decode (T=1): same Int4Tensor weights, same format. No transform needed. - print("Exporting decode (T=1)...") - with torch.no_grad(): + # ---------- decode (tokens, T=1, static) ---------- + # Bound-method swap (NOT __class__ swap) so the model instance's class — + # and therefore its mutable-buffer FQN identity — is preserved across both + # ExportedPrograms. This is what lets share_mutable_buffers unify the + # prefill / decode KV-cache buffers under one runtime tensor. + print("Exporting decode (T=1, tokens input)...") + with _BoundMethodForward(model, model.decode_forward), torch.no_grad(): decode_ep = export( model, ( @@ -292,13 +402,12 @@ def _export_mlx( output_dir: str, use_turboquant: bool = False, ) -> None: - """Export to .pte via torch.export + MLX backend. + """Export to .pte via torch.export + MLX backend (text-only contract). - Unlike CUDA (which exports separate decode/prefill methods with an - Int4Tensor dispatch override), MLX uses a single method with dynamic - sequence length. No int4_dispatch import — IntxUnpackedToInt8Tensor's - default dispatch produces the ``dequantize_affine → linear`` pattern - that MLX's QuantizedLinearHandler matches. + Single token-input method with dynamic sequence length; MLX samples on the + host so there is no temperature input. ``mlx_source_transformations`` + installs the token-input ``forward`` used here. The vision head is dropped + first (text-only contract). When ``use_turboquant=True``, full-attention layers swap to ``MLXTurboQuantKVCache`` for ~3.8× KV cache memory savings. Sliding @@ -320,6 +429,9 @@ def _export_mlx( from executorch.exir.passes import MemoryPlanningPass from torch.export import Dim, export + # Text-only contract: drop the vision head before transforming / exporting. + _drop_vision_head(model) + mlx_source_transformations( model, dtype=torch.bfloat16, use_turboquant=use_turboquant ) diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index 92654fca5f2..d91c93ad425 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -6,9 +6,13 @@ """Eager inference on Gemma 4 31B-IT (CUDA + torch.compile). -Three input paths: +Input paths: --prequantized Load a quantized checkpoint (from quantize_and_save.py). + This is the path that supports the image (vision) flow. --gguf Load a GGUF file (e.g., Q4_K_M from the community). + Text-only (community GGUFs carry no vision tensors); + in-band GGUF vision is added in the g4-vision-llama-cpp + branch. --bf16 Load the bf16 HF safetensors checkpoint via from_hf_checkpoint. Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting. @@ -16,6 +20,15 @@ (``<|turn>user\\n{prompt}\\n<|turn>model\\n<|channel>thought\\n``; BOS is prepended separately). Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input). +When ``--image-path`` is supplied the runner mirrors the C++ runner in +``main.cpp``: it patchifies the image, runs the vision tower (built as a +``Gemma4_31BVisionTower`` wrapper around ``model.vision_tower`` / +``model.embed_vision``), runs ``embed_text`` on the chat-template token +sequence, splices the vision embeddings into the rows where the +```` placeholder lives, and then prefills on the spliced embeds via +``model.forward(inputs_embeds, input_pos, temperature)``. Decode then +proceeds one token at a time through ``model.decode_forward``. + Usage: python inference.py \\ --prequantized ./gemma4_31b_int4 \\ @@ -23,6 +36,11 @@ --max-new-tokens 128 \\ --temperature 0.8 + python inference.py \\ + --prequantized ./gemma4_31b_int4 \\ + --image-path ./some_image.png \\ + --prompt "Describe this image." + python inference.py \\ --gguf ./gemma-4-31B-it-Q4_K_M.gguf \\ --tokenizer-path ./tokenizer.json \\ @@ -35,6 +53,11 @@ import torch +from executorch.examples.models.gemma4.chat_template import ( + BOS_ID, + build_vision_input_ids, + IMAGE_TOKEN_ID, +) from executorch.examples.models.gemma4_31b.export import load_prequantized_model from executorch.examples.models.gemma4_31b.model import ( Gemma4_31B, @@ -52,6 +75,14 @@ def _move_to_cuda(model, config) -> None: ``materialize_runtime_buffers``. """ for name, p in model.named_parameters(): + if p.device.type == "meta": + # All checkpoints (prequant / GGUF / bf16) produce a fully + # materialized text + vision model, so this branch is unreachable + # in normal use. Surface it loudly if it ever trips. + raise AssertionError( + f"_move_to_cuda: parameter {name!r} is still on meta — " + "checkpoint loader did not populate it." + ) parts = name.rsplit(".", 1) parent = model.get_submodule(parts[0]) if len(parts) > 1 else model setattr( @@ -110,13 +141,22 @@ def generate( temp_val = max(temperature, 1e-6) # avoid div-by-zero in the on-device sampler temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + # The 4-method export contract changed `model.forward` to take pre-computed + # embeddings (used by the unified prefill). For the per-token text-only + # eager loop we use `decode_forward(tokens, pos, temperature)` instead, + # which takes token inputs and internally runs `embed_text` -> `_run_blocks`. + # `decode_forward` is NOT one of the methods that `torch.compile` wraps + # (only `forward` is); we go through `_orig_mod` so attribute access lands + # on the original module rather than the compile wrapper's proxy path. + underlying = getattr(model, "_orig_mod", model) + sampled = None with torch.no_grad(): # Prefill, one token at a time. for i, tok_id in enumerate(input_ids): tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") pos = torch.tensor([i], dtype=torch.long, device="cuda") - sampled = model(tok, pos, temp_tensor) + sampled = underlying.decode_forward(tok, pos, temp_tensor) # First generated token from the last prefill step. next_id = int(sampled.item()) @@ -127,7 +167,7 @@ def generate( for i in range(max_new_tokens - 1): tok = torch.tensor([[next_id]], dtype=torch.long, device="cuda") pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") - sampled = model(tok, pos, temp_tensor) + sampled = underlying.decode_forward(tok, pos, temp_tensor) next_id = int(sampled.item()) generated.append(next_id) if next_id in eos_token_ids: @@ -136,7 +176,149 @@ def generate( return tokenizer.decode(generated) -def main() -> None: +# --------------------------------------------------------------------------- +# Vision helpers +# --------------------------------------------------------------------------- + + +def _build_vision_encoder(model, config): + """Build a ``Gemma4_31BVisionTower`` wrapper that reuses the model's already- + loaded ``vision_tower`` and ``embed_vision`` submodules. + + Mirrors ``export.py::_build_vision_encoder_wrapper``: construct the wrapper + on the meta device (so its freshly-built children take no real allocation), + then swap in the loaded modules so parameter identity is preserved. + """ + from executorch.examples.models.gemma4_31b.vision_tower import Gemma4_31BVisionTower + + # When ``model`` has been wrapped by ``torch.compile`` we still want the + # raw underlying modules — torch.compile proxies attribute access to + # ``_orig_mod``, so ``model.vision_tower`` already gives us the originals. + underlying = getattr(model, "_orig_mod", model) + + with torch.device("meta"): + wrapper = Gemma4_31BVisionTower(config.vision_config, config.hidden_size) + wrapper.vision_tower = underlying.vision_tower + wrapper.embed_vision = underlying.embed_vision + wrapper.eval() + return wrapper + + +def generate_with_image( + model, + vision_encoder, + tokenizer, + prompt: str, + image_path: str, + max_vision_soft_tokens: int = 280, + max_new_tokens: int = 128, + temperature: float = 0.0, + eos_token_ids=None, + bos_token_id: int = BOS_ID, +) -> str: + """Image+text generation. Mirrors the C++ runner flow in main.cpp: + + 1. Patchify image -> (pixel_values, pixel_position_ids). + 2. vision_encoder(pixels, position_ids) -> (image_embeds, pooler_mask). + 3. Build chat-template input_ids with ``num_soft_tokens`` image + placeholders. + 4. embed_text(input_ids) -> text_embeds. + 5. Splice image_embeds into text_embeds at ```` rows. + 6. Single-shot prefill via model.forward(spliced, input_pos, temp). + 7. Decode loop via model.decode_forward(token, input_pos, temp). + """ + from executorch.examples.models.gemma4.image_utils import preprocess_image + + if eos_token_ids is None: + eos_token_ids = set() + + # 1. Patchify. + pixel_values, pixel_position_ids, num_soft_tokens = preprocess_image( + image_path, max_soft_tokens=max_vision_soft_tokens + ) + pixel_values = pixel_values.to("cuda") + pixel_position_ids = pixel_position_ids.to("cuda") + print( + f"Image: patchified to {pixel_values.shape[1]} patches; " + f"{num_soft_tokens} soft tokens (max={max_vision_soft_tokens})." + ) + + underlying = getattr(model, "_orig_mod", model) + + temp_val = max(temperature, 1e-6) + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + + with torch.no_grad(): + # 2. Vision tower. + image_embeds, pooler_mask = vision_encoder(pixel_values, pixel_position_ids) + # image_embeds: [1, output_length, hidden_size] bf16 + # pooler_mask: [1, output_length] bool, True = valid soft token + + # 3. Token sequence. + input_ids = build_vision_input_ids( + tokenizer, prompt, num_soft_tokens, bos_id=bos_token_id + ) + T = len(input_ids) + tokens = torch.tensor([input_ids], dtype=torch.long, device="cuda") + print(f"Prompt tokens (image+text): {T}") + + # 4. embed_text. + text_embeds = underlying.embed_text(tokens) # [1, T, hidden] bf16 + + # 5. Splice image rows into text_embeds at IMAGE_TOKEN_ID positions, + # skipping any image-embed rows whose pooler_mask is False (padded + # soft tokens). + inputs_embeds = text_embeds.clone() + valid_mask_row = pooler_mask[0] # [output_length] + n_image_rows = int(image_embeds.shape[1]) + image_idx = 0 + spliced = 0 + for i, tok_id in enumerate(input_ids): + if tok_id != IMAGE_TOKEN_ID: + continue + # Advance to next valid soft-token row. + while image_idx < n_image_rows and not bool(valid_mask_row[image_idx]): + image_idx += 1 + if image_idx >= n_image_rows: + raise RuntimeError( + f"Ran out of valid vision soft tokens at text position {i} " + f"(used {spliced}, needed {num_soft_tokens})." + ) + inputs_embeds[0, i] = image_embeds[0, image_idx] + image_idx += 1 + spliced += 1 + if spliced != num_soft_tokens: + raise RuntimeError( + f"Spliced {spliced} image rows but expected {num_soft_tokens}." + ) + + # 6. Single-shot prefill on spliced embeddings. We bypass the + # torch.compile wrapper here: calling ``forward`` on the underlying + # module executes uncompiled, but for prefill (one call) the cost + # is negligible and avoids re-compiling for the variable T. + input_pos = torch.arange(T, dtype=torch.long, device="cuda") + sampled = underlying.forward(inputs_embeds, input_pos, temp_tensor) + next_id = int(sampled.item()) + generated = [next_id] + if next_id in eos_token_ids: + return tokenizer.decode(generated) + + # 7. Decode loop. We use ``decode_forward`` (token-input single-step) + # which the runner also uses for decode steps. Like prefill, this + # bypasses the compile wrapper — fine for label generation. + for i in range(max_new_tokens - 1): + tok = torch.tensor([[next_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([T + i], dtype=torch.long, device="cuda") + sampled = underlying.decode_forward(tok, pos, temp_tensor) + next_id = int(sampled.item()) + generated.append(next_id) + if next_id in eos_token_ids: + break + + return tokenizer.decode(generated) + + +def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Eager inference on Gemma 4 31B-IT.") src = parser.add_mutually_exclusive_group(required=True) src.add_argument( @@ -194,58 +376,101 @@ def main() -> None: choices=["cuda"], help="Target backend.", ) - args = parser.parse_args() + parser.add_argument( + "--image-path", + default="", + help=( + "Optional: path to an image file (JPEG/PNG). When set, the runner " + "uses the multimodal flow: vision_tower + embed_text + spliced " + "prefill (mirrors examples/models/gemma4_31b/main.cpp)." + ), + ) + parser.add_argument( + "--max-vision-soft-tokens", + type=int, + default=280, + help=( + "Maximum number of vision soft tokens emitted by the vision " + "tower. Must be one of {70,140,280,560,1120}. Default 280 matches " + "the Gemma 4 vision tower default." + ), + ) + return parser - if args.backend == "cuda" and not torch.cuda.is_available(): - parser.error("CUDA is required for the cuda backend.") - # ---- Tokenizer ---- +def _resolve_tokenizer_path(args, parser: argparse.ArgumentParser) -> str: if args.tokenizer_path: - tokenizer_path = args.tokenizer_path - elif args.prequantized: - tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") - elif args.bf16: - tokenizer_path = os.path.join(args.bf16, "tokenizer.json") - else: - parser.error("--tokenizer-path is required with --gguf.") - from tokenizers import Tokenizer - - tokenizer = Tokenizer.from_file(tokenizer_path) - - prompt_str = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) + return args.tokenizer_path + if args.prequantized: + return os.path.join(args.prequantized, "tokenizer.json") + if args.bf16: + return os.path.join(args.bf16, "tokenizer.json") + parser.error("--tokenizer-path is required with --gguf.") - # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). - eos_token_ids = {1, 50, 106} +def _load_model_from_args(args, parser: argparse.ArgumentParser): if args.gguf: from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model - model, config = load_gguf_model( - args.gguf, args.max_seq_len, backend=args.backend - ) - elif args.bf16: - model, config = Gemma4_31B.from_hf_checkpoint( - args.bf16, max_seq_len=args.max_seq_len - ) - else: - print(f"Loading prequantized model from {args.prequantized}...") - model, config = load_prequantized_model( - args.prequantized, max_seq_len=args.max_seq_len, backend=args.backend + return load_gguf_model( + args.gguf, + max_seq_len=args.max_seq_len, + backend=args.backend, ) + if args.bf16: + return Gemma4_31B.from_hf_checkpoint(args.bf16, max_seq_len=args.max_seq_len) + + print(f"Loading prequantized model from {args.prequantized}...") + return load_prequantized_model( + args.prequantized, max_seq_len=args.max_seq_len, backend=args.backend + ) + + +def _prepare_model_for_inference(model, config, args, parser: argparse.ArgumentParser): _move_to_cuda(model, config) model.eval() import executorch.backends.cuda.int4_dispatch # noqa: F401 + # Build the vision encoder BEFORE wrapping the model with torch.compile — + # the wrapper steals references to model.vision_tower / model.embed_vision, + # and we want those references to stay valid no matter what we do with + # ``model`` afterwards. (Building it after compile also works because + # torch.compile proxies attribute access to _orig_mod, but doing it first + # is clearer.) + vision_encoder = None + if args.image_path: + if config.vision_config is None: + parser.error( + "Loaded model has no vision_config; cannot run with --image-path." + ) + vision_encoder = _build_vision_encoder(model, config) + if not args.no_compile: print("Compiling model with torch.compile...") model = torch.compile(model, mode="default") - print(f"\nPrompt: {args.prompt}") - print("-" * 40) + return model, vision_encoder - t0 = time.perf_counter() - output = generate( + +def _run_generation(model, vision_encoder, tokenizer, args, prompt_str: str) -> str: + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + + if args.image_path: + return generate_with_image( + model, + vision_encoder, + tokenizer, + args.prompt, + args.image_path, + max_vision_soft_tokens=args.max_vision_soft_tokens, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + eos_token_ids=eos_token_ids, + ) + + return generate( model, tokenizer, prompt_str, @@ -253,6 +478,30 @@ def main() -> None: temperature=args.temperature, eos_token_ids=eos_token_ids, ) + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") + + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(_resolve_tokenizer_path(args, parser)) + prompt_str = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) + + model, config = _load_model_from_args(args, parser) + model, vision_encoder = _prepare_model_for_inference(model, config, args, parser) + + print(f"\nPrompt: {args.prompt}") + if args.image_path: + print(f"Image: {args.image_path}") + print("-" * 40) + + t0 = time.perf_counter() + output = _run_generation(model, vision_encoder, tokenizer, args, prompt_str) elapsed = time.perf_counter() - t0 print(output) diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index 657c79e0c4c..ddcaffa8753 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -51,6 +51,13 @@ Gemma4MLP, ) from executorch.examples.models.gemma4_31b.sampler import sample +from executorch.examples.models.gemma4_31b.vision_tower import ( + Gemma4MultimodalEmbedder, + Gemma4VisionConfig, + Gemma4VisionTower, + hf_vision_key_map, + hf_vision_per_layer_key_map, +) from torch.nn import functional as F @@ -141,6 +148,12 @@ class Gemma4_31BConfig: # Hybrid attention pattern layer_types: list = field(default_factory=list) + # Vision tower config. Gemma 4 31B is multimodal-by-default — the + # vision_tower / embed_vision submodules are ALWAYS instantiated. Must be + # supplied at construction (parsed from the HF config.json's + # "vision_config" block by ``from_hf_config``). + vision_config: Gemma4VisionConfig = None + # Runtime max_seq_len: int = 4096 @@ -168,6 +181,17 @@ def from_hf_config(config_path: str) -> "Gemma4_31BConfig": sliding_rope = rope_params.get("sliding_attention", {}) full_rope = rope_params.get("full_attention", {}) + # Parse vision_config from the original (non-text) section of the file. + # Gemma 4 31B is multimodal-by-default — missing vision_config in the + # checkpoint is a corrupt-checkpoint condition. + vision_config = Gemma4VisionConfig.from_hf_config(config_path) + if not isinstance(vision_config, Gemma4VisionConfig): + raise ValueError( + f"{config_path} has no 'vision_config' block. Gemma 4 31B is " + "multimodal-by-default; a checkpoint without vision_config is " + "considered corrupt." + ) + return Gemma4_31BConfig( vocab_size=cfg.get("vocab_size", 262144), hidden_size=cfg.get("hidden_size", 5376), @@ -188,6 +212,7 @@ def from_hf_config(config_path: str) -> "Gemma4_31BConfig": tie_word_embeddings=cfg.get("tie_word_embeddings", True), sliding_window=cfg.get("sliding_window", 1024), layer_types=cfg.get("layer_types", []), + vision_config=vision_config, ) @@ -420,6 +445,16 @@ def __init__(self, config: Gemma4_31BConfig): # Held separately so it can be untied + quantized at export time. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Vision tower / multimodal embedder. Attached as submodules of the LM + # (so they load with the same checkpoint). The text-only forward never + # touches them; they are exported separately as their own + # ExportedProgram by export.py. Vision is mandatory — Gemma 4 31B is + # multimodal-by-default and config.vision_config must be set. + self.vision_tower = Gemma4VisionTower(config.vision_config) + self.embed_vision = Gemma4MultimodalEmbedder( + config.vision_config, config.hidden_size + ) + # Constants (registered as buffers so they move with .to(device)). self.register_buffer( "embed_normalizer", @@ -466,24 +501,19 @@ def _build_masks( return sliding_mask, full_mask - def forward( + def _run_blocks( self, - tokens: torch.LongTensor, + x: torch.Tensor, input_pos: torch.LongTensor, - temperature: torch.Tensor, + temperature: Optional[torch.Tensor], ) -> torch.Tensor: - """Run the model. - - Args: - tokens: (B, T) token IDs. - input_pos: (T,) absolute positions for RoPE / KV cache. - temperature: 1-D float tensor for Gumbel-max sampling. + """Shared inner stack: 60 decoder layers + final RMSNorm + lm_head + softcap. - Returns: - (B, 1) sampled token IDs as float. + ``x`` is the layer-0 input (already embedding-scaled). When + ``temperature is None`` (eager), returns full (B, T, V) softcapped + logits; otherwise returns the (B, 1) sampled token from the last query + position. """ - x = self.embed_tokens(tokens) * self.embed_normalizer - sliding_mask, full_mask = self._build_masks(input_pos) for layer in self.layers: x = layer(x, input_pos, sliding_mask, full_mask) @@ -494,23 +524,107 @@ def forward( last = torch.tanh(last / cap) * cap return sample(last, temperature) + def forward( + self, + inputs_embeds: torch.Tensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the model on PRE-COMPUTED EMBEDDINGS (unified prefill / decode-on-embeds). + + This is the canonical forward for the new 4-method export contract: + + prefill = forward(inputs_embeds [1,T,5376] bf16, input_pos [T] i64, + temperature [1] f32) -> sampled [1,1] f32 + + Used for BOTH text-only and image+text. The runner is responsible for + producing ``inputs_embeds``: call ``embed_text(tokens)`` for the + text-token rows and (for image input) overwrite the rows at + ``image_token_id`` placeholders with the corresponding rows of + ``Gemma4_31BVisionTower(pixel_values, pixel_position_ids)[0]`` — + which are NOT pre-scaled by ``sqrt(hidden_size)`` (matches HF). + + Args: + inputs_embeds: (B, T, hidden_size) tensor — already embedding-scaled. + input_pos: (T,) absolute positions for RoPE / KV cache. + temperature: optional 1-D float tensor controlling on-device sampling. + When None (eager), returns full (B, T, V) softcapped logits; + when set, returns the (B, 1) sampled token from the last + query position via Gumbel-max. + + Returns: + (B, 1) token IDs when sampling, else (B, T, V) float32 logits. + """ + return self._run_blocks(inputs_embeds, input_pos, temperature) + + # ---------------- multimodal entry points ---------------- + + def embed_text(self, tokens: torch.LongTensor) -> torch.Tensor: + """Pure text-embedding lookup + ``sqrt(hidden_size)`` scale. + + Returns ``embed_tokens(tokens) * sqrt(hidden_size)`` cast to bfloat16. + Used by the runner to build the ``inputs_embeds`` tensor passed into + ``forward`` (the unified prefill): compute this for every text token, + then OVERWRITE the rows at ``image_token_id`` placeholders with + vision-tower output rows. Exported as its own ExecuTorch method. + + Returns: + (B, T, hidden_size) bf16. + """ + return (self.embed_tokens(tokens) * self.embed_normalizer).to(torch.bfloat16) + + def decode_forward( + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Token-input single-step decode (the exported `decode` method). + + Equivalent to ``forward(embed_text(tokens), input_pos, temperature)`` + but in a single fused method to keep the decode export self-contained + (no need to chain ``embed_text`` -> ``forward`` from the C++ runner on + the per-token decode hot path). Used for BOTH text-only and image+text + decoding — by the time decode runs, all positions are sampled tokens. + + Args: + tokens: (B, T=1) token IDs. + input_pos: (T=1,) absolute position. + temperature: required (decode is the sampling path). + + Returns: + (B, 1) sampled token IDs. + """ + x = self.embed_tokens(tokens) * self.embed_normalizer + return self._run_blocks(x, input_pos, temperature) + # ---------------- checkpoint loading ---------------- @staticmethod def from_hf_checkpoint( - model_dir: str, max_seq_len: int = 4096 + model_dir: str, + max_seq_len: int = 4096, ) -> tuple["Gemma4_31B", Gemma4_31BConfig]: """Build the model on `meta` and load weights from the HF safetensors checkpoint. Uses lazy shard-by-shard loading + assign=True so peak memory stays at roughly one shard's worth of weights. + + Vision is always loaded — Gemma 4 31B is multimodal-by-default. The + checkpoint's config.json MUST contain a ``vision_config`` block and the + safetensors shards MUST contain vision keys, or this raises. + + Args: + model_dir: directory containing config.json + safetensors shards. + max_seq_len: max sequence length for KV cache sizing. """ config = Gemma4_31BConfig.from_hf_config(os.path.join(model_dir, "config.json")) config.max_seq_len = max_seq_len print( f"Building Gemma4_31B on meta (layers={config.num_hidden_layers}, " - f"hidden={config.hidden_size}, max_seq_len={max_seq_len})..." + f"hidden={config.hidden_size}, max_seq_len={max_seq_len}, " + f"vision=on)..." ) with torch.device("meta"): model = Gemma4_31B(config) @@ -522,6 +636,15 @@ def from_hf_checkpoint( if "lm_head.weight" not in state_dict and "embed_tokens.weight" in state_dict: state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"] + # Vision keys are mandatory — the model declares vision_tower / + # embed_vision submodules and the checkpoint must populate them. + if not any(k.startswith("vision_tower.") for k in state_dict): + raise ValueError( + f"Checkpoint at {model_dir} is missing vision_tower.* keys. " + "Gemma 4 31B is multimodal-by-default; a text-only checkpoint " + "is not supported." + ) + missing, unexpected = model.load_state_dict( state_dict, strict=False, assign=True ) @@ -577,20 +700,39 @@ def from_hf_checkpoint( "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.down_proj.weight", } -# Multimodal keys we deliberately ignore for the text-only export. -_IGNORED_PREFIXES = ( +# Multimodal HF prefixes the vision-aware key map handles. The model always +# has vision_tower / embed_vision submodules (multimodal-by-default), so these +# keys are always remapped and always consumed by load_state_dict. +_VISION_PREFIXES = ( "model.vision_tower.", "model.embed_vision.", ) def _hf_to_model_key(hf_key: str) -> Optional[str]: + """Map an HF state-dict key to our model's flat key, or None if it should be skipped. + + Vision keys (``model.vision_tower.*``, ``model.embed_vision.*``) are + always remapped through the vision key map and always consumed by the + receiving model (multimodal-by-default). + """ # Gemma4ForConditionalGeneration stores the LM under model.language_model.* norm = hf_key if norm.startswith("model.language_model."): norm = norm.replace("model.language_model.", "model.", 1) - if norm.startswith(_IGNORED_PREFIXES): + if norm.startswith(_VISION_PREFIXES): + # Vision-aware remap (fixed and per-layer patterns). + fixed = hf_vision_key_map() + if norm in fixed: + return fixed[norm] + for hf_pat, model_pat in hf_vision_per_layer_key_map().items(): + regex = re.escape(hf_pat).replace(r"\{\}", r"(\d+)") + m = re.fullmatch(regex, norm) + if m: + return model_pat.replace("{}", m.group(1), 1) + # An unknown vision key — silently skip (keeps loader robust to + # checkpoint additions like audio). return None for hf_pat, model_pat in _HF_KEY_MAP.items(): @@ -606,7 +748,11 @@ def _hf_to_model_key(hf_key: str) -> Optional[str]: def _load_and_remap_checkpoint(model_dir: str, config: Gemma4_31BConfig) -> dict: - """Stream-load safetensors shards and remap keys to model state_dict keys.""" + """Stream-load safetensors shards and remap keys to model state_dict keys. + + Vision keys are always remapped — the model always has the matching + submodules, so every vision tensor in the checkpoint is consumed. + """ from safetensors import safe_open index_path = os.path.join(model_dir, "model.safetensors.index.json") @@ -633,7 +779,7 @@ def _load_and_remap_checkpoint(model_dir: str, config: Gemma4_31BConfig) -> dict # layer_scalar in checkpoint is shape (1,) bf16 — keep as-is. state_dict[model_key] = tensor if skipped > 0: - print(f" Skipped {skipped} non-text keys (vision tower, etc.)") + print(f" Skipped {skipped} unknown / vision-extra keys") return state_dict @@ -675,6 +821,25 @@ def materialize_runtime_buffers( "inv_freq", attn._compute_inv_freq(device=device), persistent=False ) + # Vision tower RoPE: recompute inv_freq for the vision encoder. This + # buffer is non-persistent (not in the HF checkpoint) and is built on the + # meta device during model construction, so the meta-buffer zeroing loop + # above leaves it at all-zeros. We must recompute it with real values. + vision_rotary = getattr(model, "vision_tower", None) + if vision_rotary is not None: + rotary_emb = model.vision_tower.encoder.rotary_emb + head_dim = rotary_emb.head_dim + rope_theta = rotary_emb.rope_theta + spatial_dim = head_dim // 2 + vision_inv_freq = 1.0 / ( + rope_theta + ** ( + torch.arange(0, spatial_dim, 2, device=device, dtype=torch.float32) + / spatial_dim + ) + ) + rotary_emb.register_buffer("inv_freq", vision_inv_freq, persistent=False) + model.register_buffer( "embed_normalizer", torch.tensor(config.hidden_size**0.5, device=device), diff --git a/examples/models/gemma4_31b/pack_vision.py b/examples/models/gemma4_31b/pack_vision.py new file mode 100644 index 00000000000..1690bc24370 --- /dev/null +++ b/examples/models/gemma4_31b/pack_vision.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Vision-tower quantization + packing helpers. + +This module is functionally ported from +``examples/models/gemma4/export_gemma4.py::_quantize_position_embedding_table`` +(the E2B/E4B vision PE-int8 packer); the math is identical, only the +hidden-size constants differ. See the per-function ``Ported from`` notes. + +Public API: + + * ``quantize_vision_position_table(vision_tower)`` -- in-place swap of + the patch embedder's bf16 ``position_embedding_table`` Parameter with + two buffers ``_pet_int8`` (per-channel int8) and ``_pet_scale`` + (fp32). The Gemma 4 31B vision PE table is (2, 10240, 1152) ≈ 47 MB + bf16 → ~12 MB int8 + scale. Cosine sim vs bf16 reference > 0.999999 + in upstream Gemma 4 (E2B/E4B) experiments. Mirrors + ``examples/models/gemma4/export_gemma4.py::_quantize_position_embedding_table`` + but operates on our own ported PatchEmbedder (which has hidden_size + 1152 instead of E2B's 768 and exposes ``_position_embeddings`` as the + same instance method we monkey-patch here). + + * ``pack_vision_patch_embedder(patch_embedder, weights)`` — Gemma4-specific + module packer that handles ``_pet_int8`` / ``_pet_scale`` plain tensors. + This lets the generic ``load_and_pack_for_*`` APIs stream safetensors as + usual while using the existing ``packers`` input for model-specific state + adaptation. + + * ``install_int8_pe_dispatch(vision_tower)`` — same monkey-patch / + buffer-shape installation but without quantizing existing data. + + * ``collect_vision_state_dict(vision_tower, embed_vision)`` — return a + flat dict of all vision-side tensors (linears, norms, multimodal + projector, plus the int8 PE buffers). All linears + norms are bf16; + the PE table is in its int8/scale form. + + * ``has_vision_keys(safetensors_path)`` — peek at a saved checkpoint + to detect whether it carries vision tensors. Used by + load_prequantized_model so the new load path is purely additive. + +These keys ride alongside the quantized LM in the same safetensors file +because torchao's ``flatten_tensor_state_dict`` accepts a mixed dict of +quantized subclass tensors + plain tensors and lists every name in +``metadata['tensor_names']``. The existing +the generic loaders iterate that list and route plain tensors through +``pack_one``; the Gemma4 patch-embedder packer handles ``_pet_int8`` / +``_pet_scale`` before the default register-buffer fallback. +""" + +from __future__ import annotations + +import types + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Anything in the model whose flat key starts with one of these is a +# vision-side tensor. Used by quantize_and_save.py and the loader to +# branch additively without touching the text-decoder code path. +VISION_PREFIXES: tuple[str, ...] = ("vision_tower.", "embed_vision.") + + +# --------------------------------------------------------------------------- +# Position-embedding-table quantization (the only "real" quantization on the +# vision side — every other vision weight stays bf16). +# --------------------------------------------------------------------------- + + +def _patch_position_embeddings_int8(patch_embedder: nn.Module) -> None: + """Monkey-patch ``_position_embeddings`` to dequantize + index the int8 PE table. + + Ported from + ``examples/models/gemma4/export_gemma4.py::_position_embeddings_int8`` + (Gemma 4 E2B/E4B int8 PE table lookup). + + Uses the same one-hot-matmul shape that HF's + ``Gemma4VisionPatchEmbedder._position_embeddings`` produces, except we + dequantize ``_pet_int8 * _pet_scale`` first. We also stick to + ``F.embedding`` for the per-axis lookup to keep the graph tiny — that's + also what the text-decoder vision_tower port uses (so the numerics are + bit-for-bit those of the bf16 reference, modulo the int8 round-trip + on the table itself). + """ + + def _position_embeddings( + self, + pixel_position_ids: torch.Tensor, # [B, P, 2] + padding_positions: torch.Tensor, # [B, P] (True = padding) + ) -> torch.Tensor: + # Dequantize lazily so the bf16 graph stays bf16. (2, 10240, 1152) + table = self._pet_int8.to(self._pet_scale.dtype) * self._pet_scale + clamped = pixel_position_ids.clamp(min=0).long() + emb_x = F.embedding(clamped[..., 0], table[0]) + emb_y = F.embedding(clamped[..., 1], table[1]) + pos_emb = (emb_x + emb_y).to(self.input_proj.weight.dtype) + zero = torch.zeros_like(pos_emb) + return torch.where(padding_positions.unsqueeze(-1), zero, pos_emb) + + patch_embedder._position_embeddings = types.MethodType( + _position_embeddings, patch_embedder + ) + + +def quantize_vision_position_table( + vision_tower: nn.Module, + *, + verbose: bool = False, +) -> None: + """Replace ``vision_tower.patch_embedder.position_embedding_table`` with + int8 per-channel data + fp32 scale buffers, and patch the lookup method. + + Ported from + ``examples/models/gemma4/export_gemma4.py::_quantize_position_embedding_table`` + (the E2B/E4B vision PE-int8 packer). Same per-channel quant math; only the + hidden-size constant differs (1152 here vs 768 in the E2B port). + + Idempotent: a second call is a no-op. + """ + pe = vision_tower.patch_embedder + pet = getattr(pe, "position_embedding_table", None) + if pet is None: + return # already quantized + + if pet.device.type == "meta": + raise RuntimeError( + "quantize_vision_position_table requires a real (non-meta) " + "position_embedding_table tensor. Load the HF weights first." + ) + + pet_fp = pet.data.to(torch.float32) + # Per-channel along the last (hidden) dim — same axis as the gemma4 + # E2B/E4B reference. + scale = pet_fp.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0 + qdata = torch.round(pet_fp / scale).clamp(-128, 127).to(torch.int8) + scale = scale.to(torch.float32) + + # Drop the parameter. After this, named_parameters() no longer yields it. + del pe.position_embedding_table + + pe.register_buffer("_pet_int8", qdata, persistent=True) + pe.register_buffer("_pet_scale", scale, persistent=True) + + _patch_position_embeddings_int8(pe) + + if verbose: + bf16_mb = pet.numel() * 2 / (1024 * 1024) + new_mb = (qdata.numel() + scale.numel() * 4) / (1024 * 1024) + print( + f" vision PE table: bf16 -> int8 per-channel " + f"({bf16_mb:.1f} MB -> {new_mb:.1f} MB)" + ) + + +def install_int8_pe_dispatch( + vision_tower: nn.Module, + *, + verbose: bool = False, +) -> None: + """Build-on-meta companion to ``quantize_vision_position_table``. + + Swaps the freshly-constructed ``position_embedding_table`` Parameter + for zero placeholder buffers ``_pet_int8`` / ``_pet_scale`` and + monkey-patches the lookup method. + """ + pe = vision_tower.patch_embedder + if not hasattr(pe, "position_embedding_table"): + # Already swapped (idempotent). + if not hasattr(pe, "_pet_int8"): + raise RuntimeError( + "install_int8_pe_dispatch: patch_embedder has neither the " + "original position_embedding_table nor _pet_int8/_pet_scale." + ) + return + + # Inspect shape from the existing parameter (works on meta). + pet = pe.position_embedding_table + shape = tuple(pet.shape) # (2, position_embedding_size, hidden_size) + del pe.position_embedding_table + pe.register_buffer( + "_pet_int8", + torch.zeros(shape, dtype=torch.int8, device="meta"), + persistent=True, + ) + pe.register_buffer( + "_pet_scale", + torch.zeros((shape[0], shape[1], 1), dtype=torch.float32, device="meta"), + persistent=True, + ) + _patch_position_embeddings_int8(pe) + + if verbose: + print(" vision PE table: int8 dispatch installed (placeholder buffers)") + + +def pack_vision_patch_embedder( + patch_embedder: nn.Module, + weights: dict[str, torch.Tensor], +) -> bool: + """Install/load Gemma4 vision PE int8 buffers via generic packers.""" + if not any(k in weights for k in ("_pet_int8", "_pet_scale")): + return False + + if hasattr(patch_embedder, "position_embedding_table"): + dummy_tower = types.SimpleNamespace(patch_embedder=patch_embedder) + install_int8_pe_dispatch(dummy_tower) + + for name, value in weights.items(): + if name not in ("_pet_int8", "_pet_scale"): + return False + patch_embedder.register_buffer(name, value, persistent=True) + return True + + +# --------------------------------------------------------------------------- +# Vision state-dict collection (bf16 passthrough + int8 PE buffers). +# --------------------------------------------------------------------------- + + +def collect_vision_state_dict( + vision_tower: nn.Module, + embed_vision: nn.Module, + *, + dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """Return a flat dict of all vision-side tensors ready to mix into the + saved safetensors. + + Caller is responsible for invoking ``quantize_vision_position_table`` + on ``vision_tower`` first if the int8 PE recipe is desired (this is + what quantize_and_save.py does). + + Output keys: + * ``vision_tower.*`` — every Parameter and persistent buffer of + ``vision_tower`` (linears + norms in ``dtype``; the PE buffers + are kept in their native int8 / fp32 dtypes). + * ``embed_vision.*`` — the multimodal projector linear and norm. + + Norms / linear weights are cast to ``dtype`` (bf16 by default) so the + file-level dtype mix matches the LM's quantized + bf16 plain + tensors. Integer PE buffers and fp32 PE scale are passed through + unchanged. + """ + state: dict[str, torch.Tensor] = {} + + def _maybe_cast(name: str, t: torch.Tensor) -> torch.Tensor: + # PE buffers + scale: keep native dtype. + if name.endswith("._pet_int8") or name.endswith("._pet_scale"): + return t.detach().contiguous() + if t.dtype.is_floating_point: + return t.detach().to(dtype).contiguous() + return t.detach().contiguous() + + # vision_tower parameters + for sub_fqn, param in vision_tower.named_parameters(): + key = f"vision_tower.{sub_fqn}" + state[key] = _maybe_cast(key, param.data) + # vision_tower buffers (only persistent ones — std_bias/std_scale, _pet_*) + persistent = set(vision_tower.state_dict().keys()) + for sub_fqn, buf in vision_tower.named_buffers(): + if sub_fqn not in persistent: + continue + key = f"vision_tower.{sub_fqn}" + if key in state: + continue + state[key] = _maybe_cast(key, buf.data) + + # embed_vision parameters + for sub_fqn, param in embed_vision.named_parameters(): + key = f"embed_vision.{sub_fqn}" + state[key] = _maybe_cast(key, param.data) + # embed_vision buffers (RMSNormNoWeight has none; defensive walk anyway) + persistent_ev = set(embed_vision.state_dict().keys()) + for sub_fqn, buf in embed_vision.named_buffers(): + if sub_fqn not in persistent_ev: + continue + key = f"embed_vision.{sub_fqn}" + if key in state: + continue + state[key] = _maybe_cast(key, buf.data) + + return state + + +# --------------------------------------------------------------------------- +# Load-side helpers +# --------------------------------------------------------------------------- + + +def has_vision_keys(safetensors_path: str) -> bool: + """Return True iff the file contains any ``vision_tower.*`` / + ``embed_vision.*`` key. + + Used by ``load_prequantized_model`` so the existing text-only load path + keeps working byte-for-byte when the checkpoint was saved with + ``--no-vision`` (or by an old quantize_and_save.py). + """ + from safetensors import safe_open + + with safe_open(safetensors_path, framework="pt", device="cpu") as f: + for k in f.keys(): + if k.startswith(VISION_PREFIXES): + return True + return False + + +__all__ = [ + "VISION_PREFIXES", + "quantize_vision_position_table", + "install_int8_pe_dispatch", + "collect_vision_state_dict", + "has_vision_keys", +] diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py index 95abc43546a..23258ab86e4 100644 --- a/examples/models/gemma4_31b/quant/pack.py +++ b/examples/models/gemma4_31b/quant/pack.py @@ -17,9 +17,10 @@ import torch import torch.nn as nn -# Packer signature: receives the module + a dict of its quantized weights -# (keyed by attribute name), modifies module in-place. -ModulePackerFn = Callable[[nn.Module, dict[str, torch.Tensor]], None] +# Packer signature: receives the module + a dict of its weights keyed by +# attribute name, modifies module in-place, and may return True when it handled +# the assignment. Returning None preserves the legacy behavior. +ModulePackerFn = Callable[[nn.Module, dict[str, torch.Tensor]], bool | None] def _is_quantized(value: torch.Tensor) -> bool: @@ -84,7 +85,8 @@ def pack_one( """Pack a single weight into ``model``. Quantized subclass tensors are dispatched to the packer for the parent - module's type. Plain tensors are assigned directly. + module's type. Plain tensors are first offered to a parent-module packer; + if it does not return True, they are assigned directly. """ parts = fqn.rsplit(".", 1) parent_fqn = parts[0] if len(parts) > 1 else "" @@ -100,7 +102,12 @@ def pack_one( ) packer(parent, {attr: value}) else: - if isinstance(getattr(parent, attr, None), nn.Parameter): + existing = getattr(parent, attr, None) + if not isinstance(existing, nn.Parameter): + packer = packers.get(type(parent)) + if packer is not None and packer(parent, {attr: value}) is True: + return + if isinstance(existing, nn.Parameter): setattr(parent, attr, nn.Parameter(value, requires_grad=False)) else: parent.register_buffer(attr, value) diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py index e654e12f637..5e69b104273 100644 --- a/examples/models/gemma4_31b/quantize_and_save.py +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -8,7 +8,8 @@ Produces a safetensors file containing torchao tensor subclasses (``Int4Tensor``, ``IntxUnpackedToInt8Tensor``) that can be loaded and -packed for any backend via ``load_and_pack_for_cuda`` or ``pack_model``. +packed for any backend via the generic ``load_and_pack_for_*`` APIs with +Gemma4-specific custom packers. The default recipe runs on CPU. The sensitive recipe requires CUDA for HQQ asymmetric quantization. @@ -27,6 +28,9 @@ import torch.nn as nn from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.pack_vision import ( + quantize_vision_position_table, +) from executorch.examples.models.gemma4_31b.quant import ( QuantConfig, quantize_model, @@ -35,14 +39,21 @@ ) # --------------------------------------------------------------------------- -# Production recipes for Gemma 4 31B. +# Production recipes for Gemma 4 31B (vision + text in one rule set). # -# Layer sensitivity: +# Layer sensitivity (text decoder): # - v_proj and down_proj are the most sensitive to quantization error # (first/last quarter of layers especially so). # - q_proj, k_proj, o_proj, gate_proj, up_proj tolerate 4-bit well. # - embed_tokens is an index lookup — INT8 per-axis is nearly lossless. # - Norms and layer_scalar are tiny and must stay unquantized. +# +# Vision modality: +# - Vision tower linears are small + accuracy-sensitive; they stay bf16. +# - The vision multimodal projector (``embed_vision.*``) also stays bf16. +# - The patch_embedder's position_embedding_table is the one "real" quant +# on the vision side: bf16 → INT8 per-channel, applied explicitly before +# the generic quantize_model parameter walk. _INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") _INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") @@ -52,21 +63,39 @@ ) _EDGE_LAYERS = set(range(15)) | set(range(45, 60)) +# Shared vision rules: every vision-side weight stays bf16. The PE table is +# absent from the parameter walk after quantize_gemma4_vision_position_table +# replaces it with int8 buffers. + + +def quantize_gemma4_vision_position_table(model: nn.Module) -> None: + vision_tower = getattr(model, "vision_tower", None) + if vision_tower is not None: + quantize_vision_position_table(vision_tower) + + +_VISION_RULES = [ + QuantRule(r"vision_tower\..*", None), + QuantRule(r"embed_vision\..*", None), +] + GEMMA4_31B_DEFAULT_RECIPE = QuantRecipe( rules=[ QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + *_VISION_RULES, QuantRule(r".*norm\.weight", None), QuantRule(r".*\.weight", _INT4), - ] + ], ) GEMMA4_31B_SENSITIVE_RECIPE = QuantRecipe( rules=[ QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + *_VISION_RULES, QuantRule(r".*norm\.weight", None), QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), QuantRule(r".*\.weight", _INT4_HQQ), - ] + ], ) _RECIPES = { @@ -114,7 +143,13 @@ def main() -> None: print("Untying embed_tokens / lm_head...") model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + # Single quantization entry point. ``quantize_model`` handles both + # modalities in one pass: + # - text decoder linears -> INT4 / INT8 per the recipe; + # - vision tower + embed_vision linears -> stay bf16 (recipe rule); + # - vision PE table -> INT8 per-channel (explicit pre-quantization call). print(f"Quantizing with recipe '{args.quant_recipe}'...") + quantize_gemma4_vision_position_table(model) state_dict = quantize_model(model, recipe, verbose=True) os.makedirs(args.output, exist_ok=True) diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 505d6f7bdc1..d68c6c051ad 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -25,16 +25,13 @@ import torch import torch.nn as nn from executorch.examples.models.gemma4_31b.export import ( + _get_packers, export_and_lower, load_prequantized_model, ) from executorch.examples.models.gemma4_31b.inference import _move_to_cuda, generate from executorch.examples.models.gemma4_31b.model import Gemma4_31B -from executorch.examples.models.gemma4_31b.quant import ( - DEFAULT_CUDA_PACKERS, - pack_model, - quantize_model, -) +from executorch.examples.models.gemma4_31b.quant import pack_model, quantize_model from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( build_hf_checkpoint, DEFAULT_RECIPE, @@ -114,16 +111,21 @@ def test_chunked_prefill_matches_sequential(self): for i in range(prompt_len): tok = prompt[:, i : i + 1] pos = torch.tensor([i], dtype=torch.long, device="cuda") - token_seq = model_seq(tok, pos, temp) + # T=1 token-input single-step (model.forward now takes embeds). + token_seq = model_seq.decode_forward(tok, pos, temp) with torch.no_grad(): chunk1 = prompt[:, :buf_size] pos1 = torch.arange(buf_size, dtype=torch.long, device="cuda") - model_chunk(chunk1, pos1, temp) + # Multi-token prefill: embed_text -> embeds-forward (the same fusion + # the exported token-input prefill performs). + model_chunk.forward(model_chunk.embed_text(chunk1), pos1, temp) chunk2 = prompt[:, buf_size:] pos2 = torch.arange(buf_size, prompt_len, dtype=torch.long, device="cuda") - token_chunk = model_chunk(chunk2, pos2, temp) + token_chunk = model_chunk.forward( + model_chunk.embed_text(chunk2), pos2, temp + ) self.assertEqual( int(token_seq.item()), @@ -156,11 +158,16 @@ def test_export_from_hf_checkpoint(self): ckpt_dir, max_seq_len=TINY_CONFIG.max_seq_len ) model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + from executorch.examples.models.gemma4_31b.pack_vision import ( + quantize_vision_position_table, + ) + + quantize_vision_position_table(model.vision_tower) state_dict = quantize_model(model, DEFAULT_RECIPE) with torch.device("meta"): model = Gemma4_31B(config) - pack_model(model, state_dict, DEFAULT_CUDA_PACKERS) + pack_model(model, state_dict, _get_packers("cuda")) model.eval() export_and_lower(model, config, out_dir) @@ -185,7 +192,9 @@ def _forward(self): tok = torch.tensor([[1]], dtype=torch.long, device="cuda") pos = torch.tensor([0], dtype=torch.long, device="cuda") temp = torch.tensor([1.0], dtype=torch.float32, device="cuda") - return self.model(tok, pos, temp) + # model.forward takes inputs_embeds now; use the token-input + # single-step decode entry point. + return self.model.decode_forward(tok, pos, temp) def test_int4_weights_preserved(self): """Packing passes Int4Tensor through without conversion.""" @@ -216,7 +225,7 @@ def test_deterministic(self): pos = torch.tensor([0], dtype=torch.long, device="cuda") temp = torch.tensor([1.0], dtype=torch.float32, device="cuda") torch.manual_seed(99) - out2 = model2(tok, pos, temp) + out2 = model2.decode_forward(tok, pos, temp) self.assertEqual(int(out1.item()), int(out2.item())) def test_embedding_works(self): diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index a8d9d9cbe34..35ccd7c61cd 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -16,23 +16,27 @@ import json import os +import re import tempfile import unittest import torch import torch.nn as nn - from executorch.examples.models.gemma4_31b.model import ( Gemma4_31B, Gemma4_31BConfig, RingKVCache, ) +from executorch.examples.models.gemma4_31b.pack_vision import ( + quantize_vision_position_table, +) from executorch.examples.models.gemma4_31b.quant import ( QuantConfig, quantize_model, QuantRecipe, QuantRule, ) +from executorch.examples.models.gemma4_31b.vision_tower import Gemma4VisionConfig from safetensors import safe_open from safetensors.torch import save_file from torchao.prototype.safetensors.safetensors_support import ( @@ -63,6 +67,18 @@ final_logit_softcapping=30.0, tie_word_embeddings=True, sliding_window=16, + vision_config=Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=16, + patch_size=4, + pooling_kernel_size=2, + position_embedding_size=16, + standardize=True, + ), max_seq_len=64, ) @@ -77,9 +93,13 @@ DEFAULT_RECIPE = QuantRecipe( rules=[ QuantRule(r"embed_tokens\.weight", QUANT_8W_PER_AXIS), + # Vision side stays bf16; the PE table is quantized explicitly before + # calling quantize_model. + QuantRule(r"vision_tower\..*", None), + QuantRule(r"embed_vision\..*", None), QuantRule(r".*norm\.weight", None), QuantRule(r".*\.weight", QUANT_4W), - ] + ], ) @@ -99,6 +119,7 @@ def decode(self, ids): def config_dict() -> dict: cfg = TINY_CONFIG + vc = cfg.vision_config return { "vocab_size": cfg.vocab_size, "hidden_size": cfg.hidden_size, @@ -123,13 +144,30 @@ def config_dict() -> dict: "tie_word_embeddings": cfg.tie_word_embeddings, "sliding_window": cfg.sliding_window, "layer_types": cfg.layer_types, + "vision_config": { + "hidden_size": vc.hidden_size, + "intermediate_size": vc.intermediate_size, + "num_hidden_layers": vc.num_hidden_layers, + "num_attention_heads": vc.num_attention_heads, + "num_key_value_heads": vc.num_key_value_heads, + "head_dim": vc.head_dim, + "hidden_activation": vc.hidden_activation, + "rms_norm_eps": vc.rms_norm_eps, + "patch_size": vc.patch_size, + "pooling_kernel_size": vc.pooling_kernel_size, + "position_embedding_size": vc.position_embedding_size, + "max_position_embeddings": vc.max_position_embeddings, + "rope_parameters": {"rope_theta": vc.rope_theta}, + "standardize": vc.standardize, + "use_clipped_linears": vc.use_clipped_linears, + "default_output_length": vc.default_output_length, + }, } def build_random_tiny_model() -> Gemma4_31B: torch.manual_seed(42) - model = Gemma4_31B(TINY_CONFIG) - model.to(dtype=torch.bfloat16) + model = Gemma4_31B(TINY_CONFIG).to(dtype=torch.bfloat16) for p in model.parameters(): if p.device.type != "meta": p.data.normal_(0, 0.02) @@ -140,6 +178,7 @@ def build_random_tiny_model() -> Gemma4_31B: def save_checkpoint(output_dir: str): model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantize_vision_position_table(model.vision_tower) state_dict = quantize_model(model, DEFAULT_RECIPE) os.makedirs(output_dir, exist_ok=True) td, md = flatten_tensor_state_dict(state_dict) @@ -152,7 +191,24 @@ def build_hf_checkpoint(output_dir: str) -> None: model = build_random_tiny_model() sd = model.state_dict() sd.pop("lm_head.weight", None) - hf_sd = {f"model.language_model.{k}": v.contiguous() for k, v in sd.items()} + + # HF wraps each vision encoder-layer attn/mlp projection in + # Gemma4ClippableLinear, which exposes the weight under a ".linear." + # segment. Mirror that here so the synthetic checkpoint's keys match + # hf_vision_per_layer_key_map() (which drops ".linear." when remapping). + _clippable = re.compile( + r"vision_tower\.encoder\.layers\.\d+\." + r"(self_attn\.[qkvo]_proj|mlp\.(gate|up|down)_proj)\.weight$" + ) + + def _to_hf_key(key: str) -> str: + if key.startswith(("vision_tower.", "embed_vision.")): + if _clippable.fullmatch(key): + key = key[: -len(".weight")] + ".linear.weight" + return f"model.{key}" + return f"model.language_model.{key}" + + hf_sd = {_to_hf_key(key): value.contiguous() for key, value in sd.items()} save_file(hf_sd, os.path.join(output_dir, "model.safetensors")) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(config_dict(), f) @@ -170,6 +226,7 @@ def test_roundtrip_preserves_weights(self): model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantize_vision_position_table(model.vision_tower) state_dict = quantize_model(model, DEFAULT_RECIPE) with tempfile.TemporaryDirectory() as tmpdir: @@ -182,6 +239,9 @@ def test_roundtrip_preserves_weights(self): loaded, _ = unflatten_tensor_state_dict(loaded_tensors, loaded_meta) self.assertEqual(set(state_dict.keys()), set(loaded.keys())) + self.assertIn("vision_tower.patch_embedder._pet_int8", loaded) + self.assertIn("vision_tower.patch_embedder._pet_scale", loaded) + self.assertIn("embed_vision.embedding_projection.weight", loaded) for fqn in state_dict: orig = state_dict[fqn] got = loaded[fqn] @@ -201,6 +261,7 @@ def test_embedding_quantized_as_int8(self): model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantize_vision_position_table(model.vision_tower) state_dict = quantize_model(model, DEFAULT_RECIPE) self.assertIn("embed_tokens.weight", state_dict) @@ -306,5 +367,21 @@ def test_ignored_key_returns_none(self): self.assertIsNone(gguf_to_model_key("rope_freqs.weight")) +class TestVisionConfigRequired(unittest.TestCase): + def test_from_hf_config_raises_when_vision_config_missing(self): + """Gemma 4 31B is multimodal-by-default \u2014 a config.json missing the + ``vision_config`` block must be treated as a corrupt checkpoint and + ``Gemma4_31BConfig.from_hf_config`` must raise.""" + # Reuse the text-only tiny config but strip the vision_config block. + cfg_payload = config_dict() + cfg_payload.pop("vision_config", None) + with tempfile.TemporaryDirectory() as d: + cfg_path = os.path.join(d, "config.json") + with open(cfg_path, "w") as f: + json.dump(cfg_payload, f) + with self.assertRaises(ValueError): + Gemma4_31BConfig.from_hf_config(cfg_path) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_vision_quant_roundtrip.py b/examples/models/gemma4_31b/tests/test_vision_quant_roundtrip.py new file mode 100644 index 00000000000..ee267d4cd80 --- /dev/null +++ b/examples/models/gemma4_31b/tests/test_vision_quant_roundtrip.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Validation gates for the gemma4_31b vision-tower quantization recipe. + +All tests build a random-init bf16 baseline of our own model/tower as +the reference. No external HF checkpoint is required. + +Tests: + +* ``test_pe_int8_quantize_and_install_roundtrip`` -- snapshot a random- + init bf16 tower's output, apply ``quantize_vision_position_table``, + collect the state dict, reinstall it on a freshly-built tower via + ``install_int8_pe_dispatch``, and verify the output round-trips with + cosine_sim > 0.999. + +* ``test_unified_recipe_preserves_vision_bf16_and_quantizes_pe`` -- + build a tiny Gemma4_31B with vision attached, run ``quantize_model`` + with the unified recipe, and verify the vision linears stay bf16 + while the PE table is swapped to int8 buffers. + +* ``test_has_vision_keys_*`` -- ``has_vision_keys`` sniff test on plain + safetensors files. +""" + +from __future__ import annotations + +import os +import sys +import tempfile + +import pytest +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig +from executorch.examples.models.gemma4_31b.pack_vision import ( + collect_vision_state_dict, + has_vision_keys, + install_int8_pe_dispatch, + quantize_vision_position_table, +) +from executorch.examples.models.gemma4_31b.vision_tower import ( + Gemma4_31BVisionTower, + Gemma4VisionConfig, +) +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import flatten_tensor_state_dict + + +# --------------------------------------------------------------------------- +# Test 1 -- unified recipe leaves vision linears bf16 and quantizes the PE table. +# --------------------------------------------------------------------------- + + +def _tiny_recipe(hidden_size: int): + """Tiny analogue of GEMMA4_31B_DEFAULT_RECIPE for unit-test models. + + Production recipe uses group_size=hidden_size (5376) for the per-axis + embedding INT8 quant; that doesn't fit a 64-d test model. Functional + shape is identical: INT8 per-axis embed, skip vision side + norms, + INT4 elsewhere. + """ + from executorch.examples.models.gemma4_31b.quant import ( + QuantConfig, + QuantRecipe, + QuantRule, + ) + + int4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + int8_per_axis = QuantConfig( + bits=8, group_size=hidden_size, symmetric=True, method="min_max" + ) + return QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", int8_per_axis), + # Vision modality stays bf16; PE table is quantized explicitly before + # calling quantize_model. + QuantRule(r"vision_tower\..*", None), + QuantRule(r"embed_vision\..*", None), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.weight", int4), + ], + ) + + +def test_unified_recipe_preserves_vision_bf16_and_quantizes_pe(): + """Build a model WITH vision attached, run the unified recipe, and check: + + * vision_tower.* and embed_vision.* linear weights are saved (no + detach hack required); + * NO vision linear is quantized to Int4Tensor / + IntxUnpackedToInt8Tensor (they all stay bf16); + * the PE table has been swapped to int8 buffers (_pet_int8 + + _pet_scale) by quantize_model itself. + + The unified ``quantize_model`` API handles vision + text in a single + pass, replacing the old ``del model.vision_tower`` pattern. + """ + from executorch.examples.models.gemma4_31b.quant import quantize_model + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + cfg = Gemma4_31BConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + num_global_key_value_heads=1, + global_head_dim=64, + attention_k_eq_v=True, + sliding_window=8, + max_seq_len=32, + vision_config=Gemma4VisionConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + patch_size=4, + pooling_kernel_size=2, + position_embedding_size=32, + standardize=True, + ), + ) + torch.manual_seed(0) + model = Gemma4_31B(cfg).to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + # Sanity: vision is attached. + assert hasattr(model, "vision_tower") + assert hasattr(model, "embed_vision") + + quantize_vision_position_table(model.vision_tower) + state_dict = quantize_model(model, _tiny_recipe(cfg.hidden_size)) + + # Vision-side linears must be present AND must stay bf16 (plain Tensor, + # not a quantized subclass). + vision_param_keys = [ + k + for k in state_dict + if (k.startswith("vision_tower.") or k.startswith("embed_vision.")) + and k.endswith(".weight") + ] + assert vision_param_keys, "expected vision_tower / embed_vision weights to be saved" + for k in vision_param_keys: + v = state_dict[k] + assert not isinstance(v, (Int4Tensor, IntxUnpackedToInt8Tensor)), ( + f"vision weight {k} was quantized ({type(v).__name__}); " + "the vision_tower / embed_vision recipe rules should keep it bf16" + ) + assert ( + v.dtype == torch.bfloat16 + ), f"vision weight {k} dtype is {v.dtype}, expected bfloat16" + + # PE table is swapped to int8 buffers by quantize_model. + assert "vision_tower.patch_embedder._pet_int8" in state_dict + assert state_dict["vision_tower.patch_embedder._pet_int8"].dtype == torch.int8 + assert "vision_tower.patch_embedder._pet_scale" in state_dict + assert state_dict["vision_tower.patch_embedder._pet_scale"].dtype == torch.float32 + # Sanity: the bf16 PE Parameter is gone from the saved keys. + assert "vision_tower.patch_embedder.position_embedding_table" not in state_dict + + +# --------------------------------------------------------------------------- +# Test 2 -- has_vision_keys() detects both kinds of saves. +# --------------------------------------------------------------------------- + + +def test_has_vision_keys_text_only(): + """A safetensors with no vision keys -> has_vision_keys returns False.""" + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + plain = { + "layers.0.input_layernorm.weight": torch.randn(8, dtype=torch.bfloat16) + } + td, md = flatten_tensor_state_dict(plain) + save_file(td, path, metadata=md) + assert has_vision_keys(path) is False + + +def test_has_vision_keys_with_vision(): + """A safetensors with vision_tower.* -> has_vision_keys returns True.""" + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + plain = { + "embed_tokens.weight": torch.randn(16, 8, dtype=torch.bfloat16), + "vision_tower.encoder.layers.0.input_layernorm.weight": torch.randn( + 4, dtype=torch.bfloat16 + ), + } + td, md = flatten_tensor_state_dict(plain) + save_file(td, path, metadata=md) + assert has_vision_keys(path) is True + + +# --------------------------------------------------------------------------- +# Test 3 -- quantize PE table + reinstall round-trips against the bf16 ref. +# --------------------------------------------------------------------------- + + +def test_pe_int8_quantize_and_install_roundtrip(): + """End-to-end PE-int8 round-trip on a random-init bf16 tower. + + Snapshot the bf16 reference output, apply + ``quantize_vision_position_table`` in place, then collect the state + dict and reinstall it on a freshly-built tower via + ``install_int8_pe_dispatch`` + ``load_state_dict``. Cosine sim vs the + bf16 reference must exceed 0.999 after the int8 swap, and the + reload-then-forward path must match the in-place quantized forward to + > 0.99999. + """ + cfg = Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=16, + patch_size=4, + pooling_kernel_size=2, + position_embedding_size=16, + standardize=True, + ) + torch.manual_seed(0) + tower = Gemma4_31BVisionTower(cfg, text_hidden_size=64).to(dtype=torch.bfloat16) + for p in tower.parameters(): + p.data.normal_(0, 0.02) + + # Snapshot reference output PRE-quant. + g = torch.Generator().manual_seed(0) + pv = torch.rand(1, 16, cfg.patch_dim, generator=g, dtype=torch.bfloat16) + coords = torch.arange(4) + yy, xx = torch.meshgrid(coords, coords, indexing="ij") + pp = torch.stack([xx.flatten(), yy.flatten()], -1).unsqueeze(0).long() + with torch.no_grad(): + ref_emb, ref_mask = tower(pv, pp) + + # Quantize PE table in-place. + quantize_vision_position_table(tower.vision_tower, verbose=False) + assert hasattr(tower.vision_tower.patch_embedder, "_pet_int8") + assert hasattr(tower.vision_tower.patch_embedder, "_pet_scale") + assert not hasattr(tower.vision_tower.patch_embedder, "position_embedding_table") + + with torch.no_grad(): + post_emb, post_mask = tower(pv, pp) + assert torch.equal(ref_mask, post_mask) + cos = torch.nn.functional.cosine_similarity( + ref_emb.flatten().float(), post_emb.flatten().float(), dim=0 + ).item() + assert cos > 0.999, f"PE-int8 round-trip cosine {cos} too low" + + # Collect and reinstall on a fresh tower. + state = collect_vision_state_dict(tower.vision_tower, tower.embed_vision) + assert "vision_tower.patch_embedder._pet_int8" in state + assert state["vision_tower.patch_embedder._pet_int8"].dtype == torch.int8 + assert "vision_tower.patch_embedder._pet_scale" in state + assert state["vision_tower.patch_embedder._pet_scale"].dtype == torch.float32 + + fresh = Gemma4_31BVisionTower(cfg, text_hidden_size=64).to(dtype=torch.bfloat16) + install_int8_pe_dispatch(fresh.vision_tower, verbose=False) + # Replace meta-buffers with real loaded ones. + pe = fresh.vision_tower.patch_embedder + pe._pet_int8 = state["vision_tower.patch_embedder._pet_int8"].clone() + pe._pet_scale = state["vision_tower.patch_embedder._pet_scale"].clone() + # Load the rest via load_state_dict (skip _pet_* -- we already set them). + nested = { + k: v + for k, v in state.items() + if not k.endswith("._pet_int8") and not k.endswith("._pet_scale") + } + missing, unexpected = fresh.load_state_dict(nested, strict=False) + blocking_missing = [m for m in missing if "patch_embedder._pet_" not in m] + # encoder.rotary_emb.inv_freq is non-persistent -> may be in `missing` + blocking_missing = [m for m in blocking_missing if not m.endswith(".inv_freq")] + assert not blocking_missing, f"Reinstall missing keys: {blocking_missing}" + + with torch.no_grad(): + reloaded_emb, reloaded_mask = fresh(pv, pp) + assert torch.equal(reloaded_mask, post_mask) + cos2 = torch.nn.functional.cosine_similarity( + reloaded_emb.flatten().float(), post_emb.flatten().float(), dim=0 + ).item() + assert cos2 > 0.99999, f"Reinstall round-trip cosine {cos2} too low" + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-v", "-s"])) diff --git a/examples/models/gemma4_31b/tests/test_vision_tower.py b/examples/models/gemma4_31b/tests/test_vision_tower.py new file mode 100644 index 00000000000..ee4fe2f4269 --- /dev/null +++ b/examples/models/gemma4_31b/tests/test_vision_tower.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Validation gate for the Gemma 4 31B vision tower port. + +Vision is always on; there is exactly one model shape (text + vision). + +Tests: + 1. ``test_forward_signature_takes_inputs_embeds`` -- forward signature is the + unified-prefill shape. + 2. ``test_multimodal_methods_present`` -- embed_text + decode_forward exist; + legacy prefill_image is GONE. + 3. ``test_decode_forward_equivalent_to_embed_then_forward`` -- the fused + decode_forward matches forward(embed_text(tokens)) within bf16 tolerance + on a tiny, random-init model (no external checkpoint required). + 4. ``test_vision_tower_random_init_forward_smoke`` -- random-init our ported + vision tower (bf16), run the forward, and verify shapes / finite outputs. + +No external HF checkpoint is required. Earlier revisions of this file +compared our port against ``Gemma4ForConditionalGeneration`` loaded from +disk, but that gated the entire suite on a hardcoded path. The +quantization-roundtrip tests now use a random-init bf16 baseline of OUR +tower as the reference, which is sufficient for catching regressions in +the port without needing the upstream weights. +""" + +from __future__ import annotations + +import sys + +import pytest +import torch + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig +from executorch.examples.models.gemma4_31b.vision_tower import ( + Gemma4_31BVisionTower, + Gemma4VisionConfig, +) + + +# --------------------------------------------------------------------------- +# Shared tiny config (no external checkpoint required). +# --------------------------------------------------------------------------- + + +def _tiny_vision_config() -> Gemma4VisionConfig: + return Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=16, + patch_size=4, + pooling_kernel_size=2, + position_embedding_size=16, + standardize=True, + ) + + +def _tiny_model_config() -> Gemma4_31BConfig: + """A minimal Gemma4_31B config sized so the test runs in seconds on CPU.""" + return Gemma4_31BConfig( + vocab_size=256, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + num_global_key_value_heads=1, + global_head_dim=16, + attention_k_eq_v=True, + sliding_window=16, + max_seq_len=64, + layer_types=["sliding_attention", "full_attention"], + vision_config=_tiny_vision_config(), + ) + + +def _make_vision_inputs( + batch: int, + grid: int, + patch_dim: int, + seed: int = 0, + dtype: torch.dtype = torch.bfloat16, +): + """Build a deterministic (pixel_values, pixel_position_ids) pair. + + ``num_patches = grid*grid`` must divide cleanly by + ``pooling_kernel_size**2`` (the pooler reshape constraint). + """ + g = torch.Generator().manual_seed(seed) + num_patches = grid * grid + pixel_values = torch.rand(batch, num_patches, patch_dim, generator=g, dtype=dtype) + coords = torch.arange(grid, dtype=torch.long) + yy, xx = torch.meshgrid(coords, coords, indexing="ij") + pos = torch.stack([xx.flatten(), yy.flatten()], dim=-1) # [G*G, 2] + pixel_position_ids = pos.unsqueeze(0).expand(batch, -1, -1).contiguous() + return pixel_values, pixel_position_ids + + +# --------------------------------------------------------------------------- +# Method shape contract: 4-method unified design (orchestrator pin #4). +# --------------------------------------------------------------------------- + + +def test_forward_signature_takes_inputs_embeds(): + """`Gemma4_31B.forward` is the unified prefill: takes inputs_embeds, NOT tokens.""" + import inspect + + sig = inspect.signature(Gemma4_31B.forward) + params = list(sig.parameters.values()) + assert [p.name for p in params] == [ + "self", + "inputs_embeds", + "input_pos", + "temperature", + ], f"forward signature: {[p.name for p in params]}" + assert params[3].default is None, "temperature must default to None" + + +def test_multimodal_methods_present(): + """Confirm the 4-method contract: forward + decode_forward + embed_text exist; + the legacy prefill_image method is removed.""" + import inspect + + for name in ("embed_text", "decode_forward", "forward"): + assert hasattr(Gemma4_31B, name), f"Gemma4_31B missing required method: {name}" + assert not hasattr( + Gemma4_31B, "prefill_image" + ), "prefill_image should be removed in the 4-method contract" + embed_sig = inspect.signature(Gemma4_31B.embed_text) + assert [p.name for p in embed_sig.parameters.values()] == [ + "self", + "tokens", + ], f"embed_text signature: {embed_sig}" + decode_sig = inspect.signature(Gemma4_31B.decode_forward) + names = [p.name for p in decode_sig.parameters.values()] + assert names == [ + "self", + "tokens", + "input_pos", + "temperature", + ], f"decode_forward signature: {decode_sig}" + + +def test_decode_forward_equivalent_to_embed_then_forward(): + """Eager equivalence: ``decode_forward(tokens, ...)`` must produce the same + output as ``forward(embed_text(tokens), ...)`` on the same input. Both + paths internally call ``sample()`` (Gumbel-max), so the RNG is seeded + identically before each call to keep the noise term equal across runs. + """ + cfg = _tiny_model_config() + torch.manual_seed(0) + model = Gemma4_31B(cfg).eval() + tokens = torch.randint(0, cfg.vocab_size, (1, 8), dtype=torch.long) + input_pos = torch.arange(8, dtype=torch.long) + # Small temperature -> near-greedy Gumbel-max. With matching RNG seeds + # on both branches the sampled token IDs must agree exactly. + temperature = torch.tensor([1.0], dtype=torch.float32) + + with torch.no_grad(): + torch.manual_seed(123) + out_decode = model.decode_forward(tokens, input_pos, temperature).clone() + # Re-run via embed_text + forward. KV slots are overwritten with the + # same values (input_pos starts at 0 again). + torch.manual_seed(123) + inputs_embeds_fp32 = model.embed_text(tokens).to(torch.float32) + out_via_embeds = model.forward(inputs_embeds_fp32, input_pos, temperature) + + assert torch.equal( + out_decode, out_via_embeds + ), f"decode_forward vs embed+forward diverged: {out_decode} vs {out_via_embeds}" + + +# --------------------------------------------------------------------------- +# Random-init forward smoke test for the vision tower. +# --------------------------------------------------------------------------- + + +def test_vision_tower_random_init_forward_smoke(): + """Random-init our ported vision tower (bf16) and verify the forward + produces a finite tensor of the expected shape on a small fixed input. + + This replaces the old HF-parity gate; we no longer require the + upstream Gemma4 checkpoint to validate that the tower wiring is + structurally correct. Numerical parity against HF still belongs in a + separate offline gate that has access to the bf16 reference weights. + """ + torch.manual_seed(0) + dtype = torch.bfloat16 + + cfg = _tiny_vision_config() + text_hidden_size = 64 + grid = 4 # 4x4 = 16 patches, divides cleanly by pooling_kernel_size**2 = 4 + + tower = ( + Gemma4_31BVisionTower(cfg, text_hidden_size=text_hidden_size).to(dtype).eval() + ) + # Spread the random weights a bit so the smoke test exercises non-zero + # activations through every sub-module. + for p in tower.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + + pixel_values, pixel_position_ids = _make_vision_inputs( + batch=1, grid=grid, patch_dim=cfg.patch_dim, dtype=dtype + ) + + with torch.no_grad(): + emb, mask = tower(pixel_values, pixel_position_ids) + + # The pooler collapses pks*pks patches into one soft token. + expected_soft_tokens = (grid * grid) // (cfg.pooling_kernel_size**2) + assert emb.shape == ( + 1, + expected_soft_tokens, + text_hidden_size, + ), f"unexpected vision_tower output shape: {emb.shape}" + assert mask.shape == ( + 1, + expected_soft_tokens, + ), f"unexpected pooler mask shape: {mask.shape}" + assert emb.dtype == dtype + assert torch.isfinite(emb).all(), "vision_tower produced non-finite values" + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-v", "-s"])) diff --git a/examples/models/gemma4_31b/vision_tower.py b/examples/models/gemma4_31b/vision_tower.py new file mode 100644 index 00000000000..91bff2be6c1 --- /dev/null +++ b/examples/models/gemma4_31b/vision_tower.py @@ -0,0 +1,759 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gemma 4 31B vision tower — self-contained PyTorch port. + +This module mirrors the vision tower of HuggingFace's +``transformers.models.gemma4.modeling_gemma4`` (Gemma4VisionModel + +Gemma4MultimodalEmbedder), plus the standardize step and the post-pool soft-token +projection. It contains no transformers imports at runtime, so it is safe to +``torch.export(strict=True)`` and ship in the ExecuTorch binary. + +Final 4-method export contract (locked-in by orchestrator, pin #4) +================================================================== + +The exported .pte ships with **4 methods**. ``forward`` (used by the exported +``prefill`` method) takes pre-computed embeddings, so a single code path covers +both text-only and image+text. The runner stitches the inputs together using +``embed_text`` (and, for images, ``vision_encoder``) before calling ``prefill``. + + 1. ``embed_text(tokens) -> embeds [B,T,5376] bf16`` + Pure embed_tokens lookup + ``sqrt(hidden_size)`` scale, returned as bf16. + 2. ``vision_encoder(pixel_values [B,P,768] f32, pixel_position_ids [B,P,2] i64) + -> (image_embeds [B,N,5376] bf16, mask [B,N] bool)`` + This module's ``Gemma4_31BVisionTower``. + 3. ``prefill(inputs_embeds [B,T,5376] bf16, input_pos [T] i64, temperature [1] f32) + -> sampled [B,1] f32`` + UNIFIED. Used for BOTH text-only and image+text. Maps to + ``Gemma4_31B.forward``. + 4. ``decode(tokens [B,1] i64, input_pos [1] i64, temperature [1] f32) + -> sampled [B,1] f32`` + Single-token decode, token-input. Maps to ``Gemma4_31B.decode_forward``. + +Multimodal prefill flow (runner-side): + + text_embeds = embed_text(tokens) # [1, T, 5376] bf16 + image_embeds, mask = vision_encoder(pixel_values, pixel_position_ids) + # In-place splice: for every i where tokens[i] == image_token_id (258880), + # overwrite text_embeds[:, i, :] with the next valid row of image_embeds + # (skipping rows where mask is False — those are padding soft tokens). + sampled = prefill(text_embeds, input_pos, temperature) + # then per-token decode loop using `decode(tokens, input_pos, temperature)`. + +Vision tower output is NOT pre-scaled by ``sqrt(hidden_size)`` (matches HF). Only +``embed_text`` applies that scale, so text rows of ``inputs_embeds`` are scaled +and image rows are not — same convention HF uses. + +Numerical contract +================== + +For a fp32 reference run: + + cosine_sim( hf_wrapper(pixel_values, pixel_position_ids), + Gemma4_31BVisionTower(pixel_values, pixel_position_ids) ) > 0.99999 + +(See ``tests/test_vision_tower.py``.) +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass + +import torch +import torch.nn as nn + +# Reuse the gemma4 text-decoder primitives that are numerically identical +# between the LM and the vision tower: Gemma4MLP (same SwiGLU GELU-tanh block, +# same gate/up/down submodule names) and rotate_half (same HF-style rotary +# helper). RMSNorm uses torch's nn.RMSNorm directly -- numerically identical to +# HF's Gemma4RMSNorm (float32 upcast + pow(mean_squared, -0.5)); the weightless +# V-norm / pre-projection norm use nn.RMSNorm(..., elementwise_affine=False). +from executorch.examples.models.gemma4.text_decoder import Gemma4MLP, rotate_half +from torch.nn import functional as F + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class Gemma4VisionConfig: + """Mirror of HF ``Gemma4VisionConfig`` for the bits we actually use.""" + + hidden_size: int = 1152 + intermediate_size: int = 4304 + num_hidden_layers: int = 27 + num_attention_heads: int = 16 + num_key_value_heads: int = 16 + head_dim: int = 72 + hidden_activation: str = "gelu_pytorch_tanh" + rms_norm_eps: float = 1e-6 + patch_size: int = 16 + pooling_kernel_size: int = 3 + position_embedding_size: int = 10240 + max_position_embeddings: int = 131072 + rope_theta: float = 100.0 + standardize: bool = True + use_clipped_linears: bool = ( + False # 31B doesn't clip — checkpoint has no clamp params. + ) + default_output_length: int = 280 # vision_soft_tokens_per_image at top level + + # Channels per spatial axis in the patchified input (RGB * patch_size^2). + in_channels: int = 3 + + @staticmethod + def from_hf_config(hf_cfg_path_or_dict) -> "Gemma4VisionConfig": + """Build from the top-level HF config dict OR a path to ``config.json``. + + Reads the ``vision_config`` block (and ``vision_soft_tokens_per_image``). + Returns ``None`` if there is no vision_config in the file. + """ + if isinstance(hf_cfg_path_or_dict, str): + with open(hf_cfg_path_or_dict, "r") as f: + top = json.load(f) + else: + top = hf_cfg_path_or_dict + + vc = top.get("vision_config", None) + if vc is None: + return None + + rope_params = vc.get("rope_parameters", {}) or {} + rope_theta = rope_params.get("rope_theta", 100.0) + + default_output_length = vc.get( + "default_output_length", + top.get("vision_soft_tokens_per_image", 280), + ) + + return Gemma4VisionConfig( + hidden_size=vc.get("hidden_size", 1152), + intermediate_size=vc.get("intermediate_size", 4304), + num_hidden_layers=vc.get("num_hidden_layers", 27), + num_attention_heads=vc.get("num_attention_heads", 16), + num_key_value_heads=vc.get("num_key_value_heads", 16), + head_dim=vc.get("head_dim", 72), + hidden_activation=vc.get("hidden_activation", "gelu_pytorch_tanh"), + rms_norm_eps=vc.get("rms_norm_eps", 1e-6), + patch_size=vc.get("patch_size", 16), + pooling_kernel_size=vc.get("pooling_kernel_size", 3), + position_embedding_size=vc.get("position_embedding_size", 10240), + max_position_embeddings=vc.get("max_position_embeddings", 131072), + rope_theta=rope_theta, + standardize=vc.get("standardize", True), + use_clipped_linears=vc.get("use_clipped_linears", False), + default_output_length=default_output_length, + ) + + @property + def patch_dim(self) -> int: + return self.in_channels * self.patch_size * self.patch_size + + +# --------------------------------------------------------------------------- +# Patch embedder +# --------------------------------------------------------------------------- + + +class Gemma4VisionPatchEmbedder(nn.Module): + """HF ``Gemma4VisionPatchEmbedder``: rescale → linear → 2D position lookup.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.position_embedding_size = config.position_embedding_size + + # Linear from patch_dim (3*16*16=768) to hidden_size, no bias. + self.input_proj = nn.Linear(config.patch_dim, self.hidden_size, bias=False) + # 2 axes × position_embedding_size positions × hidden_size. + self.position_embedding_table = nn.Parameter( + torch.zeros(2, self.position_embedding_size, self.hidden_size) + ) + + def _position_embeddings( + self, + pixel_position_ids: torch.Tensor, # [B, P, 2] + padding_positions: torch.Tensor, # [B, P] (True = padding) + ) -> torch.Tensor: + """2D positional lookup. Numerically identical to HF's one_hot @ table form, + but uses ``F.embedding`` for clarity / speed.""" + clamped = pixel_position_ids.clamp(min=0).long() # [B, P, 2] + # axis 0 is x (column), axis 1 is y (row). + emb_x = F.embedding(clamped[..., 0], self.position_embedding_table[0]) + emb_y = F.embedding(clamped[..., 1], self.position_embedding_table[1]) + pos_emb = emb_x + emb_y + # Zero-out padding patches. + pos_emb = torch.where( + padding_positions.unsqueeze(-1), torch.zeros_like(pos_emb), pos_emb + ) + return pos_emb + + def forward( + self, + pixel_values: torch.Tensor, # [B, P, patch_dim] + pixel_position_ids: torch.Tensor, # [B, P, 2] + padding_positions: torch.Tensor, # [B, P] + ) -> torch.Tensor: + # Rescale [0,1] → [-1,1] (HF does ``2*(x-0.5)``). + pixel_values = 2 * (pixel_values - 0.5) + hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype)) + position_embeddings = self._position_embeddings( + pixel_position_ids, padding_positions + ) + return hidden_states + position_embeddings + + +# --------------------------------------------------------------------------- +# Pooler +# --------------------------------------------------------------------------- + + +class Gemma4VisionPooler(nn.Module): + """HF ``Gemma4VisionPooler``: zero out padding, optional 2D avg-pool, * sqrt(d).""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.root_hidden_size = self.hidden_size**0.5 + + def _avg_pool_by_positions( + self, + hidden_states: torch.Tensor, # [B, P, D] + pixel_position_ids: torch.Tensor, # [B, P, 2] + length: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_seq_len = hidden_states.shape[1] + k = int((input_seq_len // length) ** 0.5) + k_squared = k * k + if k_squared * length != input_seq_len: + raise ValueError( + f"Cannot pool {hidden_states.shape} to {length}: k={k}^2 * length={length} " + f"must equal {input_seq_len}." + ) + + # Padding patches contribute zero (their hidden states are masked to zero + # before this is called). Clamp -1's so one_hot doesn't explode. + clamped_positions = pixel_position_ids.clamp(min=0) + max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") + kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] + weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared + output = weights.transpose(1, 2) @ hidden_states.float() + mask = torch.logical_not((weights == 0).all(dim=1)) + return output.to(hidden_states.dtype), mask + + def forward( + self, + hidden_states: torch.Tensor, # [B, P, D] + pixel_position_ids: torch.Tensor, # [B, P, 2] + padding_positions: torch.Tensor, # [B, P] + output_length: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + if output_length > hidden_states.shape[1]: + raise ValueError( + f"Cannot output more soft tokens (requested {output_length}) than there are " + f"patches ({hidden_states.shape[1]})." + ) + + hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0) + + if hidden_states.shape[1] != output_length: + hidden_states, padding_positions = self._avg_pool_by_positions( + hidden_states, pixel_position_ids, output_length + ) + # If no pooling is needed, padding_positions is already True=padding; + # the wrapper expects pooler_mask = True=valid, so flip below. + else: + padding_positions = ~padding_positions # now True = valid + + hidden_states = hidden_states * self.root_hidden_size + return hidden_states, padding_positions + + +# --------------------------------------------------------------------------- +# Attention (with 2D RoPE) +# +# The SwiGLU GELU-tanh MLP is shared with the text decoder (Gemma4MLP) and the +# rotary helper is the shared ``rotate_half`` -- both imported above. The vision +# config always uses ``gelu_pytorch_tanh``, which is exactly what Gemma4MLP +# implements. +# --------------------------------------------------------------------------- + + +def _apply_rotary_pos_emb( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 2 +) -> torch.Tensor: + """HF ``apply_rotary_pos_emb`` for a single spatial axis. + + Input: + x: [B, P, H, head_dim/ndim] + cos: [B, P, head_dim/ndim] + sin: [B, P, head_dim/ndim] + Returns: same shape as x. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (x * cos) + (rotate_half(x) * sin) + + +def _apply_multidimensional_rope( + x: torch.Tensor, # [B, P, H, head_dim] + cos: torch.Tensor, # [B, P, head_dim] + sin: torch.Tensor, # [B, P, head_dim] + position_ids: torch.Tensor, # [B, P, ndim] (unused except to read ndim) +) -> torch.Tensor: + """Mirror of HF ``apply_multidimensional_rope`` for ndim=2 (image x,y). + + Splits ``x`` (and cos/sin) into ``ndim`` equal chunks along last dim, applies + RoPE per chunk with the corresponding (cos, sin) chunk, concatenates. + """ + ndim = position_ids.shape[-1] + num_input_channels = x.shape[-1] + num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim)) + if num_rotated_channels_per_dim <= 0: + raise ValueError( + f"num_rotated_channels_per_dim must be > 0; got " + f"num_input_channels={num_input_channels} ndim={ndim}" + ) + split_sizes = [num_rotated_channels_per_dim] * ndim + x_parts = torch.split(x, split_sizes, dim=-1) + cos_parts = torch.split(cos, split_sizes, dim=-1) + sin_parts = torch.split(sin, split_sizes, dim=-1) + y_parts = [ + _apply_rotary_pos_emb(x_parts[k], cos_parts[k], sin_parts[k], unsqueeze_dim=2) + for k in range(ndim) + ] + return torch.cat(y_parts, dim=-1) + + +class Gemma4VisionRotaryEmbedding(nn.Module): + """HF ``Gemma4VisionRotaryEmbedding`` (default RoPE only — vision uses theta=100). + + Computes (cos, sin) per spatial axis and concatenates them so each half of + the head_dim gets rotated by its own axis. + """ + + inv_freq: torch.Tensor # for type checkers + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.head_dim = config.head_dim + self.rope_theta = config.rope_theta + # head_dim is split across 2 spatial axes → spatial_dim = head_dim // 2. + spatial_dim = self.head_dim // 2 + # range(0, spatial_dim, 2) gives spatial_dim // 2 frequencies. + # NOTE: HF divides by ``spatial_dim`` (not head_dim) for the exponent. + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, spatial_dim, 2, dtype=torch.int64).float() / spatial_dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward( + self, + x: torch.Tensor, + position_ids: torch.Tensor, # [B, P, 2] + ) -> tuple[torch.Tensor, torch.Tensor]: + all_cos: list[torch.Tensor] = [] + all_sin: list[torch.Tensor] = [] + # [n_freqs] -> [B, n_freqs, 1] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + for i in range(2): + # [B, P] + dim_position_ids = position_ids[:, :, i] + # [B, 1, P] + dim_position_ids_expanded = dim_position_ids[:, None, :].float() + # [B, n_freqs, P] -> [B, P, n_freqs] + freqs = (inv_freq_expanded @ dim_position_ids_expanded).transpose(1, 2) + emb = torch.cat( + (freqs, freqs), dim=-1 + ) # [B, P, 2*n_freqs] = [B, P, head_dim/2] + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype) # [B, P, head_dim] + sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype) + return cos, sin + + +class Gemma4VisionAttention(nn.Module): + """Multi-head bidirectional attention with QK-norm and 2D RoPE. ``scaling=1``.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.head_dim = config.head_dim + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + self.scaling = 1.0 # QK-norm absorbs 1/sqrt(d) — identical to HF. + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + # Q/K-norm have a learnable scale, V-norm does not. + self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = nn.RMSNorm( + self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False + ) + + def forward( + self, + hidden_states: torch.Tensor, # [B, P, D] + cos: torch.Tensor, # [B, P, head_dim] + sin: torch.Tensor, # [B, P, head_dim] + position_ids: torch.Tensor, # [B, P, 2] + attention_mask: torch.Tensor, # [B, 1, T_q, T_kv] (additive, fp32-able) + ) -> torch.Tensor: + B, P, _ = hidden_states.shape + + # Project & reshape to [B, P, H, head_dim] + q = self.q_proj(hidden_states).view(B, P, self.num_heads, self.head_dim) + q = self.q_norm(q) + q = _apply_multidimensional_rope(q, cos, sin, position_ids) + q = q.transpose(1, 2) # [B, H, P, head_dim] + + k = self.k_proj(hidden_states).view(B, P, self.num_kv_heads, self.head_dim) + k = self.k_norm(k) + k = _apply_multidimensional_rope(k, cos, sin, position_ids) + k = k.transpose(1, 2) + + v = self.v_proj(hidden_states).view(B, P, self.num_kv_heads, self.head_dim) + v = self.v_norm(v) + v = v.transpose(1, 2) + + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=self.scaling, + enable_gqa=(self.num_heads != self.num_kv_heads), + ) + # [B, H, P, head_dim] -> [B, P, H*head_dim] + attn_out = ( + attn_out.transpose(1, 2) + .contiguous() + .view(B, P, self.num_heads * self.head_dim) + ) + return self.o_proj(attn_out) + + +# --------------------------------------------------------------------------- +# Encoder layer / encoder +# --------------------------------------------------------------------------- + + +class Gemma4VisionEncoderLayer(nn.Module): + """Norm-sandwich encoder block (same skeleton as the LM): pre/post norms + around both self-attn and the SwiGLU MLP.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.self_attn = Gemma4VisionAttention(config) + self.mlp = Gemma4MLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + h = self.input_layernorm(hidden_states) + h = self.self_attn(h, cos, sin, position_ids, attention_mask) + h = self.post_attention_layernorm(h) + hidden_states = residual + h + + residual = hidden_states + h = self.pre_feedforward_layernorm(hidden_states) + h = self.mlp(h) + h = self.post_feedforward_layernorm(h) + hidden_states = residual + h + + return hidden_states + + +class Gemma4VisionEncoder(nn.Module): + """Stack of N encoder layers. Builds (cos, sin) once and a bidirectional + additive attention mask from ``padding_positions`` (True = padding).""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.rotary_emb = Gemma4VisionRotaryEmbedding(config) + self.layers = nn.ModuleList( + [Gemma4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def _build_attention_mask( + self, + valid_mask: torch.Tensor, # [B, P], True = valid + dtype: torch.dtype, + ) -> torch.Tensor: + """Bidirectional bool mask: each query attends to all valid key positions. + + Returns a BOOL mask of shape ``[B, 1, P, P]`` where ``True == attend``. + + Shape note: the executorch CUDA-backend triton SDPA kernel + (``executorch/backends/cuda/triton/kernels/sdpa.py:_prepare_mask_params``) + rejects masks that are NOT bool or NOT exactly ``[B, 1, L_q, L_kv]`` + (no broadcast over the L_q dim is allowed). We therefore materialize + the L_q dim with ``expand`` here even though, mathematically, every + query has the same per-key attendance (bidirectional encoder). + + ``dtype`` is accepted for backwards-API symmetry but is unused — the + mask is always bool. + """ + del dtype # mask is always bool — see docstring + B, P = valid_mask.shape + # [B, P] -> [B, 1, 1, P] -> [B, 1, P, P] (materialized, no broadcast). + kv_valid = valid_mask[:, None, None, :].expand(B, 1, P, P).contiguous() + return kv_valid.to(torch.bool) + + def forward( + self, + inputs_embeds: torch.Tensor, # [B, P, D] + valid_mask: torch.Tensor, # [B, P], True = valid + pixel_position_ids: torch.Tensor, # [B, P, 2] + ) -> torch.Tensor: + attention_mask = self._build_attention_mask(valid_mask, inputs_embeds.dtype) + cos, sin = self.rotary_emb(inputs_embeds, pixel_position_ids) + hidden_states = inputs_embeds + for layer in self.layers: + hidden_states = layer( + hidden_states, cos, sin, pixel_position_ids, attention_mask + ) + return hidden_states + + +# --------------------------------------------------------------------------- +# Multimodal embedder (vision-side projection into LM space) +# --------------------------------------------------------------------------- + + +class Gemma4MultimodalEmbedder(nn.Module): + """HF ``Gemma4MultimodalEmbedder`` — pre-projection RMSNorm (no scale) + followed by a single linear projection from vision hidden_size to text + hidden_size.""" + + def __init__(self, vision_config: Gemma4VisionConfig, text_hidden_size: int): + super().__init__() + self.embedding_pre_projection_norm = nn.RMSNorm( + vision_config.hidden_size, + eps=vision_config.rms_norm_eps, + elementwise_affine=False, + ) + self.embedding_projection = nn.Linear( + vision_config.hidden_size, text_hidden_size, bias=False + ) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + return self.embedding_projection( + self.embedding_pre_projection_norm(inputs_embeds) + ) + + +# --------------------------------------------------------------------------- +# Vision tower (sub-components container, no top-level forward needed) +# --------------------------------------------------------------------------- + + +class Gemma4VisionTower(nn.Module): + """Container matching the HF ``Gemma4VisionModel`` structure: + patch_embedder + encoder + pooler + (std_bias, std_scale) when standardize. + + HF state-dict keys map to this module under the prefix + ``model.vision_tower.*`` → ``vision_tower.*``. + """ + + std_bias: torch.Tensor + std_scale: torch.Tensor + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.patch_embedder = Gemma4VisionPatchEmbedder(config) + self.encoder = Gemma4VisionEncoder(config) + self.pooler = Gemma4VisionPooler(config) + + if config.standardize: + # HF stores these as buffers, not parameters. + self.register_buffer("std_bias", torch.zeros(config.hidden_size)) + self.register_buffer("std_scale", torch.ones(config.hidden_size)) + + +# --------------------------------------------------------------------------- +# Top-level: vision tower + embed_vision wrapper +# --------------------------------------------------------------------------- + + +class Gemma4_31BVisionTower(nn.Module): + """End-to-end vision tower: pixels (pre-patchified) → text-space embeddings. + + Combines ``Gemma4VisionTower`` and ``Gemma4MultimodalEmbedder`` and replicates + the forward path of HF's ``Gemma4VisionModel`` followed by the LM-side + ``embed_vision``. Returns ``(embeddings, pooler_mask)`` with padding rows + zeroed (rather than stripped) so the output shape is fixed and exportable. + + Args (forward): + pixel_values: ``[B, P, patch_dim]`` (pre-patchified, range [0,1]). + pixel_position_ids: ``[B, P, 2]`` (x, y); ``-1`` marks padding. + + Returns: + embeddings: ``[B, output_length, text_hidden_size]`` + pooler_mask: ``[B, output_length]`` (True = valid soft token) + """ + + def __init__( + self, + vision_config: Gemma4VisionConfig, + text_hidden_size: int, + ): + super().__init__() + self.vision_config = vision_config + self.text_hidden_size = text_hidden_size + # Flat names match the HF state-dict prefixes "model.vision_tower.*" and + # "model.embed_vision.*" after the LM-side rename done in model.py. + self.vision_tower = Gemma4VisionTower(vision_config) + self.embed_vision = Gemma4MultimodalEmbedder(vision_config, text_hidden_size) + + def forward( + self, + pixel_values: torch.Tensor, # [B, P, patch_dim], dtype matches weights + pixel_position_ids: torch.Tensor, # [B, P, 2] + ) -> tuple[torch.Tensor, torch.Tensor]: + cfg = self.vision_config + pks = cfg.pooling_kernel_size + output_length = pixel_values.shape[1] // (pks * pks) + + padding_positions = (pixel_position_ids == -1).all(dim=-1) # [B, P] + valid_mask = ~padding_positions + + inputs_embeds = self.vision_tower.patch_embedder( + pixel_values, pixel_position_ids, padding_positions + ) + encoder_out = self.vision_tower.encoder( + inputs_embeds=inputs_embeds, + valid_mask=valid_mask, + pixel_position_ids=pixel_position_ids, + ) + hidden_states, pooler_mask = self.vision_tower.pooler( + hidden_states=encoder_out, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + + if cfg.standardize: + hidden_states = ( + hidden_states - self.vision_tower.std_bias + ) * self.vision_tower.std_scale + # Re-zero padding rows — HF strips them; we keep the shape but mask + # so embed_vision produces zero rows there (RMSNorm of 0 → 0). + hidden_states = hidden_states.masked_fill(~pooler_mask.unsqueeze(-1), 0.0) + + embeddings = self.embed_vision(hidden_states) + return embeddings, pooler_mask + + +# --------------------------------------------------------------------------- +# HF→our key map for vision_tower + embed_vision +# --------------------------------------------------------------------------- + +# Mapping of HF state-dict keys to ours. The 31B checkpoint stores its vision +# linears as ``....linear.weight`` because HF wraps each linear in +# ``Gemma4ClippableLinear`` (which exposes ``.linear``). Since the 31B vision +# config has ``use_clipped_linears=False``, the wrapper has no extra params and +# we drop the ``.linear`` segment when remapping into our flat layout. +_HF_VISION_KEY_MAP_FIXED = { + # vision_tower (top-level constants) + "model.vision_tower.std_bias": "vision_tower.std_bias", + "model.vision_tower.std_scale": "vision_tower.std_scale", + # patch_embedder + "model.vision_tower.patch_embedder.input_proj.weight": "vision_tower.patch_embedder.input_proj.weight", + "model.vision_tower.patch_embedder.position_embedding_table": "vision_tower.patch_embedder.position_embedding_table", + # embed_vision projector + "model.embed_vision.embedding_projection.weight": "embed_vision.embedding_projection.weight", +} + +_HF_VISION_KEY_MAP_PER_LAYER = { + # norms + "model.vision_tower.encoder.layers.{}.input_layernorm.weight": "vision_tower.encoder.layers.{}.input_layernorm.weight", + "model.vision_tower.encoder.layers.{}.post_attention_layernorm.weight": "vision_tower.encoder.layers.{}.post_attention_layernorm.weight", + "model.vision_tower.encoder.layers.{}.pre_feedforward_layernorm.weight": "vision_tower.encoder.layers.{}.pre_feedforward_layernorm.weight", + "model.vision_tower.encoder.layers.{}.post_feedforward_layernorm.weight": "vision_tower.encoder.layers.{}.post_feedforward_layernorm.weight", + # attention projections (.linear segment dropped) + "model.vision_tower.encoder.layers.{}.self_attn.q_proj.linear.weight": "vision_tower.encoder.layers.{}.self_attn.q_proj.weight", + "model.vision_tower.encoder.layers.{}.self_attn.k_proj.linear.weight": "vision_tower.encoder.layers.{}.self_attn.k_proj.weight", + "model.vision_tower.encoder.layers.{}.self_attn.v_proj.linear.weight": "vision_tower.encoder.layers.{}.self_attn.v_proj.weight", + "model.vision_tower.encoder.layers.{}.self_attn.o_proj.linear.weight": "vision_tower.encoder.layers.{}.self_attn.o_proj.weight", + # qk-norm + "model.vision_tower.encoder.layers.{}.self_attn.q_norm.weight": "vision_tower.encoder.layers.{}.self_attn.q_norm.weight", + "model.vision_tower.encoder.layers.{}.self_attn.k_norm.weight": "vision_tower.encoder.layers.{}.self_attn.k_norm.weight", + # mlp (.linear segment dropped) + "model.vision_tower.encoder.layers.{}.mlp.gate_proj.linear.weight": "vision_tower.encoder.layers.{}.mlp.gate_proj.weight", + "model.vision_tower.encoder.layers.{}.mlp.up_proj.linear.weight": "vision_tower.encoder.layers.{}.mlp.up_proj.weight", + "model.vision_tower.encoder.layers.{}.mlp.down_proj.linear.weight": "vision_tower.encoder.layers.{}.mlp.down_proj.weight", +} + + +def hf_vision_key_map() -> dict[str, str]: + """Return the fixed (non-per-layer) part of the HF→our key map. + + Per-layer patterns are returned via ``hf_vision_per_layer_key_map()``; + callers expand the ``{}`` placeholder with the layer index. + """ + return dict(_HF_VISION_KEY_MAP_FIXED) + + +def hf_vision_per_layer_key_map() -> dict[str, str]: + return dict(_HF_VISION_KEY_MAP_PER_LAYER) + + +__all__ = [ + "Gemma4VisionConfig", + "Gemma4VisionPatchEmbedder", + "Gemma4VisionPooler", + "Gemma4VisionAttention", + "Gemma4VisionRotaryEmbedding", + "Gemma4VisionEncoderLayer", + "Gemma4VisionEncoder", + "Gemma4MultimodalEmbedder", + "Gemma4VisionTower", + "Gemma4_31BVisionTower", + "hf_vision_key_map", + "hf_vision_per_layer_key_map", +]