Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions examples/models/gemma4/chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Shared Gemma 4 chat-template special-token IDs and the image+text input builder.

Single source of truth for the multimodal chat-template token layout. The
Python eager runner (``gemma4_31b/inference.py``) imports from here, and the C++
runner mirrors these exact values in ``gemma4/runner/chat_template.h``; any
change here must be applied to that header as well so the two implementations
do not drift.

Image+text turn layout (matches the Gemma 4 HF chat template):

<bos><start_of_turn>user\n<boi><image>*N<eoi>{prompt}<end_of_turn>\n
<start_of_turn>model\n
"""

# Gemma 4 special token IDs (match the tokenizer + the C++ runner constants).
BOS_ID = 2
TURN_START_ID = 105  # <start_of_turn>
TURN_END_ID = 106  # <end_of_turn>
BOI_TOKEN_ID = 255999  # <start_of_image>
IMAGE_TOKEN_ID = 258880  # <image> soft-token placeholder
EOI_TOKEN_ID = 258882  # <end_of_image>


def build_vision_input_ids(
    tokenizer,
    prompt: str,
    num_vision_tokens: int,
    bos_id: int = BOS_ID,
) -> list[int]:
    """Build the chat-template token sequence for an image+text turn.

    Produces the same layout the C++ runner builds in
    ``gemma4/runner/chat_template.h::build_vision_input_ids``:

        <bos><start_of_turn>user\\n<boi><image>*N<eoi>{prompt}<end_of_turn>\\n
        <start_of_turn>model\\n

    Args:
        tokenizer: a ``tokenizers.Tokenizer``-like object exposing
            ``encode(str).ids``.
        prompt: the user text prompt.
        num_vision_tokens: number of ``<image>`` soft-token placeholders to
            insert (one per valid vision soft token). Must be >= 0.
        bos_id: beginning-of-sequence id (defaults to the Gemma 4 BOS).

    Returns:
        The flat list of token IDs for the turn.

    Raises:
        ValueError: if ``num_vision_tokens`` is negative.
    """
    # ``[x] * n`` silently yields an empty list for negative ``n``, which would
    # emit a <boi><eoi> pair with no placeholders. Reject it explicitly.
    if num_vision_tokens < 0:
        raise ValueError(
            f"num_vision_tokens must be non-negative, got {num_vision_tokens}"
        )

    # NOTE(review): ``tokenizers.Tokenizer.encode`` defaults to
    # ``add_special_tokens=True``. If the tokenizer's post-processor injects
    # <bos>, every text fragment below would carry a stray BOS -- confirm the
    # tokenizer passed in does not add special tokens.
    def _encode(text: str) -> list[int]:
        return list(tokenizer.encode(text).ids)

    # The literal is evaluated left to right, so the text fragments are encoded
    # in the order they appear in the turn.
    return [
        bos_id,
        TURN_START_ID,
        *_encode("user\n"),
        BOI_TOKEN_ID,
        *([IMAGE_TOKEN_ID] * num_vision_tokens),
        EOI_TOKEN_ID,
        *_encode(prompt),
        TURN_END_ID,
        *_encode("\n"),
        TURN_START_ID,
        *_encode("model\n"),
    ]


# Public API of this module: the special-token ID constants plus the
# image+text input builder. Names not listed here are not exported by
# ``from ... import *``.
__all__ = [
"BOS_ID",
"TURN_START_ID",
"TURN_END_ID",
"BOI_TOKEN_ID",
"IMAGE_TOKEN_ID",
"EOI_TOKEN_ID",
"build_vision_input_ids",
]
154 changes: 133 additions & 21 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,41 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd).
"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd) — text-only contract.

Two methods are exported and lowered together so they share KV-cache buffers:
- "decode": T=1, static shape, returns the next sampled token.
- "prefill": T>=2, dynamic shape, returns the next sampled token.
Two methods are exported and lowered together so they share KV-cache buffers,
EXACTLY matching the pre-vision ``main`` contract (so the C++ runner and
CMake build need no changes):

- "decode": (tokens [1,1] i64, input_pos [1] i64, temperature [1] f32)
-> sampled [1,1] f32. T=1, static shape.
- "prefill": (tokens [1,T] i64, input_pos [T] i64, temperature [1] f32)
-> sampled [1,1] f32. T>=2, dynamic shape.

Vision is loaded but NOT exported here
=====================================

Gemma 4 31B is multimodal-by-default: ``model.py`` always constructs the
``vision_tower`` / ``embed_vision`` submodules and the quantized checkpoint
carries vision weights, so loading + quantization + packing are all
vision-aware (INT8 vision position table, vision patch-embedder packer).

The exported ``.pte`` however is **text-only**: the vision head is dropped just
before lowering (``_drop_vision_head``) and the two exported methods use
TOKEN inputs, identical in name and signature to ``main``. The token-input
``prefill`` is realized with a temporary bound-method that fuses
``embed_text`` + the embeddings-``forward`` into one method — mirroring the MLX
``_mlx_model_forward(tokens) = prefill(embed_text(tokens))`` pattern. ``decode``
maps to the existing token-input ``decode_forward``.

The embeddings-based 4-method vision contract (``embed_text``,
``vision_encoder``, embeds-``prefill``, ``decode``) is introduced on top of this
in the ``g4-vision-cuda`` branch.

Three input paths:
--prequantized <dir> Load a quantized checkpoint (from quantize_and_save.py)
and pack for the target backend. No re-quantization.
This is the primary path (checkpoint includes vision).
--gguf <file> Load a GGUF file (e.g., Q4_K_M from the community).
--model-dir <hf> Load bf16 checkpoint, quantize, pack, and export
in one shot.
Expand All @@ -34,6 +60,11 @@
Gemma4_31BConfig,
materialize_runtime_buffers,
)
from executorch.examples.models.gemma4_31b.pack_vision import (
pack_vision_patch_embedder,
quantize_vision_position_table,
)
from executorch.examples.models.gemma4_31b.vision_tower import Gemma4VisionPatchEmbedder


# ---------------------------------------------------------------------------
Expand All @@ -45,7 +76,12 @@ def load_prequantized_model(
max_seq_len: int = 4096,
backend: str = "cuda",
) -> tuple[Gemma4_31B, Gemma4_31BConfig]:
"""Load a quantized checkpoint and pack for the target backend."""
"""Load a quantized checkpoint and pack for the target backend.

The checkpoint contains vision keys (Gemma 4 31B is multimodal-by-default)
so packing installs the vision patch-embedder packer too. The vision head
is dropped at export time — see ``_drop_vision_head``.
"""
config = Gemma4_31BConfig.from_hf_config(
os.path.join(prequantized_dir, "config.json")
)
Expand Down Expand Up @@ -84,6 +120,9 @@ def load_and_quantize(
model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone())

print(f"Quantizing with recipe '{recipe_name}'...")
# Pre-quantize the vision position-embedding table to INT8 before the INT4
# text recipe runs (the PE table is a large [2, N, D] buffer, not a Linear).
quantize_vision_position_table(model.vision_tower)
state_dict = quantize_model(model, recipe, verbose=True)

print(f"Packing for {backend}...")
Expand All @@ -107,31 +146,84 @@ def _get_packers(backend: str) -> dict:
if backend == "cuda":
from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS

return DEFAULT_CUDA_PACKERS
return {
**DEFAULT_CUDA_PACKERS,
Gemma4VisionPatchEmbedder: pack_vision_patch_embedder,
}
if backend == "mlx":
from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS

return DEFAULT_MLX_PACKERS
return {
**DEFAULT_MLX_PACKERS,
Gemma4VisionPatchEmbedder: pack_vision_patch_embedder,
}
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
)


def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None:
packers = _get_packers(backend)
if backend == "cuda":
from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda

load_and_pack_for_cuda(path, model)
load_and_pack_for_cuda(path, model, packers=packers)
elif backend == "mlx":
from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_mlx

load_and_pack_for_mlx(path, model)
load_and_pack_for_mlx(path, model, packers=packers)
else:
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
)


# ---------------------------------------------------------------------------
# Text-only contract helpers
#
# The 2 exported methods live on the SAME Gemma4_31B instance. torch.export
# only takes an nn.Module (not a bound method), so we temporarily shadow
# ``model.forward`` with the per-method bound method for the duration of the
# export call. Critically the model's CLASS identity does NOT change, so all
# ExportedPrograms share identical mutable-buffer FQNs
# (``layers.X.self_attn.kv_cache.k_cache`` ...). This is what lets
# ``to_executorch(share_mutable_buffers=True)`` unify the prefill / decode
# KV-cache buffers under ONE runtime tensor.


class _BoundMethodForward:
"""Context manager: temporarily set ``model.forward`` to a bound method."""

def __init__(self, model: nn.Module, bound_method) -> None:
self._model = model
self._bound = bound_method

def __enter__(self):
self._model.forward = self._bound # instance-attr; shadows the class method
return self._model

def __exit__(self, *exc):
# ``del`` restores the class method via the descriptor protocol.
try:
del self._model.forward
except AttributeError:
pass
return False


def _drop_vision_head(model: nn.Module) -> None:
"""Detach the multimodal head before text-only lowering.

Gemma 4 31B is multimodal-by-default, so ``vision_tower`` / ``embed_vision``
are always constructed and loaded. The text-only ``.pte`` does not ship
them, so we delete them here BEFORE ``materialize_runtime_buffers`` (which
only rebuilds the vision RoPE table when ``vision_tower`` is still attached).
"""
for name in ("vision_tower", "embed_vision"):
if hasattr(model, name):
delattr(model, name)


# ---------------------------------------------------------------------------
# Export + lower

Expand Down Expand Up @@ -162,6 +254,7 @@ def export_and_lower(

def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None:
import gc
import types

import torch._inductor.config as inductor_config

Expand All @@ -182,17 +275,30 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
# Register Int4Tensor dispatch → executorch_cuda::int4_plain_mm shim
import executorch.backends.cuda.int4_dispatch # noqa: F401

# Text-only contract: drop the multimodal head, then materialize buffers
# (materialize skips the vision RoPE table once vision_tower is gone).
_drop_vision_head(model)
materialize_runtime_buffers(model, dtype=torch.bfloat16)

# Int4Tensor weights are used directly — no format conversion.
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).
# Both decode and prefill share the same nibble-packed weights.

# Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M).
max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)

# ---------- prefill (tokens, T>=2, dynamic) ----------
# ``main``'s contract is a TOKEN-input prefill, but the model's ``forward``
# takes inputs_embeds. Realize the token-input prefill with a temporary
# bound-method that fuses embed_text + the embeds-forward into one method
# (mirroring MLX ``_mlx_model_forward(tokens) = prefill(embed_text(tokens))``).
def _text_prefill(self, tokens, input_pos, temperature):
return self._run_blocks(self.embed_text(tokens), input_pos, temperature)

seq_dim = Dim("seq_len", min=5, max=max_prefill)
print(f"Exporting prefill (T in [2, {max_prefill}])...")
with torch.no_grad():
print(f"Exporting prefill (T in [2, {max_prefill}], tokens input)...")
with _BoundMethodForward(
model, types.MethodType(_text_prefill, model)
), torch.no_grad():
prefill_ep = export(
model,
(
Expand All @@ -204,9 +310,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
strict=True,
)

# Decode (T=1): same Int4Tensor weights, same format. No transform needed.
print("Exporting decode (T=1)...")
with torch.no_grad():
# ---------- decode (tokens, T=1, static) ----------
# Bound-method swap (NOT __class__ swap) so the model instance's class —
# and therefore its mutable-buffer FQN identity — is preserved across both
# ExportedPrograms. This is what lets share_mutable_buffers unify the
# prefill / decode KV-cache buffers under one runtime tensor.
print("Exporting decode (T=1, tokens input)...")
with _BoundMethodForward(model, model.decode_forward), torch.no_grad():
decode_ep = export(
model,
(
Expand Down Expand Up @@ -292,13 +402,12 @@ def _export_mlx(
output_dir: str,
use_turboquant: bool = False,
) -> None:
"""Export to .pte via torch.export + MLX backend.
"""Export to .pte via torch.export + MLX backend (text-only contract).

Unlike CUDA (which exports separate decode/prefill methods with an
Int4Tensor dispatch override), MLX uses a single method with dynamic
sequence length. No int4_dispatch import — IntxUnpackedToInt8Tensor's
default dispatch produces the ``dequantize_affine → linear`` pattern
that MLX's QuantizedLinearHandler matches.
Single token-input method with dynamic sequence length; MLX samples on the
host so there is no temperature input. ``mlx_source_transformations``
installs the token-input ``forward`` used here. The vision head is dropped
first (text-only contract).

When ``use_turboquant=True``, full-attention layers swap to
``MLXTurboQuantKVCache`` for ~3.8× KV cache memory savings. Sliding
Expand All @@ -320,6 +429,9 @@ def _export_mlx(
from executorch.exir.passes import MemoryPlanningPass
from torch.export import Dim, export

# Text-only contract: drop the vision head before transforming / exporting.
_drop_vision_head(model)

mlx_source_transformations(
model, dtype=torch.bfloat16, use_turboquant=use_turboquant
)
Expand Down
Loading
Loading