[gemma4_31b_vision support] vision tower model definition with prequant update #19932

Draft
Gasoonjia wants to merge 1 commit into main from g4-vision-quant
Conversation

@Gasoonjia

Contributor

No description provided.

@pytorch-bot

pytorch-bot Bot commented Jun 2, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19932

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 5dea98c with merge base ac3003e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 2, 2026
@github-actions

github-actions Bot commented Jun 2, 2026


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

…ision (g4-vision-quant)

First branch of the g4-vision decomposition. Adds the vision tower, vision
quantization, and the Python eager (CUDA) vision inference path, plus tests.
The ExecuTorch export + C++ runtime stay TEXT-ONLY with the exact same exported
contract as `main` (token-input `prefill` + `decode`), so `main.cpp` and
`CMakeLists.txt` are byte-identical to `main` and the existing runner keeps
working unchanged. CUDA/MLX/GGUF vision are added in the following branches.

Scope (built on top of `main`)
------------------------------
Vision model + quant + Python eager:
- vision_tower.py: full Gemma 4 vision tower (patch embedder, 2D-RoPE encoder,
  pooler, multimodal embedder, Gemma4_31BVisionTower wrapper, HF key maps).
- model.py: multimodal-by-default model (vision_tower / embed_vision attached),
  embeds-based `forward(inputs_embeds, ...)`, `embed_text`, `decode_forward`,
  `_run_blocks`, vision RoPE buffer materialization, vision-aware HF loading.
- pack_vision.py: backend-agnostic INT8 vision position-table quant + patch
  embedder packer; quant/pack.py buffer-registration tweak for the PE int8
  buffers; quantize_and_save.py made vision-aware.
- inference.py: Python eager vision path (`generate_with_image`,
  `_build_vision_encoder`, `--image-path`, `--max-vision-soft-tokens`); text
  loop uses `decode_forward` (forward now takes embeds).
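
To make the INT8 position-table step concrete, here is a minimal pure-Python sketch of symmetric per-row INT8 quantization. The per-row scale scheme and the helper names `quantize_rows_int8` / `dequantize_rows` are assumptions for illustration only; this message does not state which scheme pack_vision.py actually uses.

```python
from typing import List, Tuple


def quantize_rows_int8(table: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row INT8 quantization (assumed scheme, not taken from pack_vision.py)."""
    q_rows, scales = [], []
    for row in table:
        max_abs = max((abs(v) for v in row), default=0.0)
        # An all-zero row gets scale 1.0 so the division below stays defined.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Round to the nearest integer and clamp to the symmetric range.
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_rows(q_rows: List[List[int]], scales: List[float]) -> List[List[float]]:
    """Recover approximate float values from INT8 codes and per-row scales."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


# Toy position table: three rows, one of them all zeros.
table = [[0.5, -1.0, 0.25], [0.0, 0.0, 0.0], [2.0, 1.0, -2.0]]
q_rows, scales = quantize_rows_int8(table)
restored = dequantize_rows(q_rows, scales)
max_err = max(abs(a - b) for ra, rb in zip(table, restored) for a, b in zip(ra, rb))
```

With round-to-nearest, the per-element reconstruction error is bounded by half of that row's scale, which is the kind of property a quant round-trip test can check.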

Text-only ET export (main-compatible contract via a fake-prefill wrapper):
- export.py: exports the SAME two methods as `main` -- token-input `prefill`
  (T>=2, dynamic) and `decode` (T=1). The token-input `prefill` is realized
  with a temporary bound-method that fuses `embed_text` + the embeds-`forward`
  into one method (mirroring the MLX `_mlx_model_forward(tokens) =
  prefill(embed_text(tokens))` pattern); `decode` maps to `decode_forward`.
  The vision head is dropped (`_drop_vision_head`) before lowering, so the
  exported .pte method names + signatures are identical to `main`. Loading
  stays vision-aware (the prequant checkpoint carries vision weights). No
  `embed_text` / `vision_encoder` method, no `--max-vision-soft-tokens`.
- main.cpp / CMakeLists.txt: unchanged from `main` (verified byte-identical).
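
The fake-prefill wrapper described above can be sketched roughly as follows. `ToyModel` and `_token_prefill` are hypothetical stand-ins written in pure Python; the real export.py works on the actual model and its internals may differ. The point is only the pattern: temporarily bind a token-input method that fuses `embed_text` with the embeds-based `forward`, then remove it again.

```python
import math
import types
from typing import List


class ToyModel:
    """Stand-in for the multimodal model; all internals here are hypothetical."""

    def __init__(self, hidden_size: int = 4):
        self.hidden_size = hidden_size

    def embed_text(self, tokens: List[int]) -> List[List[float]]:
        # Toy embedding: one row per token, scaled by sqrt(hidden_size).
        scale = math.sqrt(self.hidden_size)
        return [[float(t) * scale] * self.hidden_size for t in tokens]

    def forward(self, inputs_embeds: List[List[float]]) -> List[float]:
        # Toy "logits": the sum of each embedding row.
        return [sum(row) for row in inputs_embeds]

    def decode_forward(self, tokens: List[int]) -> List[float]:
        # T=1 token-input decode path.
        assert len(tokens) == 1
        return self.forward(self.embed_text(tokens))


def _token_prefill(self, tokens: List[int]) -> List[float]:
    # Fused method: tokens -> embeds -> forward, so callers never see embeds.
    return self.forward(self.embed_text(tokens))


model = ToyModel()
# Bind the fused method only for the duration of the "export".
model.prefill = types.MethodType(_token_prefill, model)
try:
    prefill_out = model.prefill([1, 2, 3])   # T >= 2
    decode_out = model.decode_forward([4])   # T = 1
finally:
    del model.prefill  # the model keeps its embeds-based contract afterwards
```

Keeping the binding temporary means the eager model still exposes only the embeds-based `forward`, while the exported artifact sees the token-input signature.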

Refactors landed at the model/quant layer:
- R1: vision_tower.py reuses the shared `Gemma4MLP` and `rotate_half` from
  examples/models/gemma4/text_decoder (drops the duplicate `Gemma4VisionMLP`
  and local `_rotate_half`; submodule names gate/up/down_proj unchanged so HF
  key maps stay valid).
- R3: drops the `RMSNorm` / `RMSNormNoWeight` stopgap from
  text_decoder/__init__.py; vision_tower.py uses `nn.RMSNorm(...)` and
  `nn.RMSNorm(..., elementwise_affine=False)` inline. text_decoder/__init__.py
  is now byte-identical to `main`.
- R2 (Python half): the 5 chat-template special-token IDs +
  `build_vision_input_ids` are extracted into a shared module
  examples/models/gemma4/chat_template.py; inference.py imports from it. The
  C++ header half lands in branch 2.

Docs (review finding #3): model.py `forward` and vision_tower.py state the
scaling convention explicitly -- `embed_text` scales text rows by
sqrt(hidden_size); the vision-tower output is NOT pre-scaled (matches HF).
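
A small sketch of that convention, assuming vision rows are spliced into placeholder positions of the text embedding sequence. The `merge_embeds` helper, the boolean slot mask, and the toy embedding table are hypothetical; only the scaling rule (text rows scaled by sqrt(hidden_size), vision rows left as-is) comes from the description above.

```python
import math
from typing import Dict, List


def embed_text(token_ids: List[int], table: Dict[int, List[float]], hidden_size: int) -> List[List[float]]:
    # Text rows are scaled by sqrt(hidden_size).
    scale = math.sqrt(hidden_size)
    return [[v * scale for v in table[t]] for t in token_ids]


def merge_embeds(text_rows: List[List[float]], vision_rows: List[List[float]], is_image_slot: List[bool]) -> List[List[float]]:
    """Place vision rows into image slots WITHOUT rescaling them (hypothetical helper)."""
    vision_iter = iter(vision_rows)
    merged = []
    for row, is_img in zip(text_rows, is_image_slot):
        merged.append(list(next(vision_iter)) if is_img else row)
    return merged


hidden_size = 4
# Toy table; token 2 plays the role of an image placeholder.
table = {0: [1.0] * 4, 1: [0.5] * 4, 2: [0.0] * 4}
text_rows = embed_text([0, 2, 1], table, hidden_size)
vision_rows = [[0.1, 0.2, 0.3, 0.4]]  # vision-tower output, not pre-scaled
merged = merge_embeds(text_rows, vision_rows, [False, True, False])
```

If the vision rows were also multiplied by sqrt(hidden_size), they would be scaled differently from what the description says HF does.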

Test fixes:
- test_cuda_pipeline.py: updated for the new model contract (review finding #2)
  -- T=1 calls use `decode_forward`; multi-token prefill uses
  `forward(embed_text(tokens), ...)`. Fixes the previously-failing
  TestInt4Inference::test_inference_produces_valid_output.
- test_pipeline.py: added vision-aware quant/save/load coverage +
  TestVisionConfigRequired. Also fixed a pre-existing fixture bug in
  build_hf_checkpoint -- vision encoder-layer attn/mlp projections are now
  written with the HF `.linear.` segment (Gemma4ClippableLinear) so they match
  hf_vision_per_layer_key_map(); without this the --model-dir export test
  silently skipped 7 vision projection weights.
- New: tests/test_vision_tower.py, tests/test_vision_quant_roundtrip.py.
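
The fixture bug above is an instance of a general failure mode: a loader that ignores unmapped keys will silently drop weights when the checkpoint key layout drifts from the key map. The sketch below reproduces that with made-up key names; the key strings, the q/k/v/o projection names, and `remap_keys` are hypothetical and are not the output of hf_vision_per_layer_key_map().

```python
from typing import Dict, List, Tuple

ATTN = ["q_proj", "k_proj", "v_proj", "o_proj"]   # assumed attention projection names
MLP = ["gate_proj", "up_proj", "down_proj"]


def expected_key_map() -> Dict[str, str]:
    # The map expects the `.linear.` segment in every projection key.
    m = {f"self_attn.{p}.linear.weight": f"attn.{p}.weight" for p in ATTN}
    m.update({f"mlp.{p}.linear.weight": f"mlp.{p}.weight" for p in MLP})
    return m


def fixture_keys(with_linear_segment: bool) -> Dict[str, float]:
    # Build a toy checkpoint with or without the `.linear.` segment.
    seg = ".linear" if with_linear_segment else ""
    keys = [f"self_attn.{p}{seg}.weight" for p in ATTN]
    keys += [f"mlp.{p}{seg}.weight" for p in MLP]
    return {k: 0.0 for k in keys}


def remap_keys(state_dict: Dict[str, float], key_map: Dict[str, str]) -> Tuple[Dict[str, float], List[str]]:
    """Remap known keys and report the ones that would otherwise be dropped silently."""
    loaded, skipped = {}, []
    for hf_key, value in state_dict.items():
        if hf_key in key_map:
            loaded[key_map[hf_key]] = value
        else:
            skipped.append(hf_key)
    return loaded, skipped


loaded_old, skipped_old = remap_keys(fixture_keys(False), expected_key_map())
loaded_new, skipped_new = remap_keys(fixture_keys(True), expected_key_map())
```

Returning the skipped keys lets a test assert that the list is empty instead of relying on the load succeeding.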

Verification
------------
- CPU: pytest test_vision_tower.py test_vision_quant_roundtrip.py
  test_pipeline.py -> 20 passed.
- CUDA (2x A100): pytest test_cuda_pipeline.py -> 8 passed, incl. both export
  tests (text-only prefill+decode .pte/.ptd produced and serialized) and the
  chunked-prefill / int4 inference contract tests.
- flake8 + ufmt clean on all changed/new files.
- Not run here: the MLX export+runtime e2e test and the Python eager
  --image-path smoke test. Both are left to the user (no MLX hardware was
  available; the image checkpoint is user-provided).