
fix(tokenizer): prefer lm_head weight as vocab_size source #962

Open
contrapuntal wants to merge 1 commit into jundot:main from contrapuntal:fix/tokenizer-prefer-lm-head-vocab-size

Conversation

@contrapuntal

What this fixes

Users running grammar-constrained output ([grammar] extra) on GLM-4.6V — and on several other VLM families — currently hit a shape-mismatch crash on the first constrained-decode step, with no clear error about why. The xgrammar bitmask is allocated for the wrong vocabulary, then doesn't fit the model's actual logits tensor.

The root cause is omlx/utils/tokenizer.py::resolve_vocab_size() returning the wrong number, which then propagates to every consumer that sizes logits-aligned buffers from it. Two such consumers exist today: omlx/api/grammar.py:46 (xgrammar bitmask) and omlx/scheduler.py:1459 (scheduler bookkeeping).
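
A hypothetical minimal reproduction of the failure mode (NumPy stands in for the real MLX/xgrammar objects; all names here are illustrative, not omlx's actual code):

```python
import numpy as np

config_vocab = 257152   # stale top-level ModelConfig.vocab_size default
logits_vocab = 151552   # width of the logits the model actually emits

bitmask = np.zeros(config_vocab, dtype=bool)       # sized from resolve_vocab_size()
logits = np.zeros(logits_vocab, dtype=np.float32)  # sized by the model itself

# First constrained-decode step: applying a 257152-wide mask to a
# 151552-wide logits row fails with a boolean-index shape mismatch.
logits[~bitmask] = -np.inf
```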

Affected models

Three mlx-vlm ModelConfig dataclasses ship a hard-coded top-level vocab_size default that doesn't match the inner LM's actual vocab when config.json omits the top-level key (as it typically does for these models):

| mlx-vlm dataclass | ModelConfig.vocab_size default | Actual text_config.vocab_size |
| --- | --- | --- |
| glm4v (GLM-4.6V) | 257152 | 151552 |
| glm4v_moe | 257152 | per-model from config.json; e.g. 151552 for GLM-4.6V-Flash |
| gemma3 | 257152 | 262208 |

resolve_vocab_size() previously read the top-level value first, so it returned the wrong vocab on any of the above.
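
A toy reconstruction of why the stale default wins (hypothetical dataclasses mirroring the mlx-vlm pattern, not its actual code):

```python
from dataclasses import dataclass, field

@dataclass
class TextConfig:
    vocab_size: int = 151552          # inner LM vocab (GLM-4.6V)

@dataclass
class ModelConfig:
    text_config: TextConfig = field(default_factory=TextConfig)
    vocab_size: int = 257152          # hard-coded top-level default

# config.json omits the top-level key, so the default survives loading:
cfg = ModelConfig()
assert cfg.vocab_size == 257152              # what the old order returned
assert cfg.text_config.vocab_size == 151552  # what the model actually emits
```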

Concrete public model that reproduces today: mlx-community/GLM-4.6V-Flash-bf16 with [grammar] enabled — the bitmask allocation reads 257152, the logits tensor is 151552 wide, and the constrained-decode step crashes on the shape mismatch.

Fix

Re-order the resolution to prefer authoritative sources (sketched in code after the list):

  1. lm_head.weight.shape[0] — authoritative; this is the exact vocabulary the model emits logits over, regardless of whatever the config dataclass says. Probed under _language_model.lm_head (VLM adapter), language_model.lm_head (raw mlx-vlm), and lm_head (raw mlx-lm) so the lookup works for every wrapping shape we ship.
  2. text_config.vocab_size — inner LM vocab on VLM composite configs. Correct when lm_head is unavailable (rare).
  3. model.config.vocab_size / model.args.vocab_size — original top-level fallback. Still consulted last for compatibility with pure-LLM configs that don't have a text_config.
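
A minimal sketch of the new order, simplified from the actual implementation in omlx/utils/tokenizer.py (edge-case handling elided):

```python
from typing import Optional

def resolve_vocab_size(model) -> Optional[int]:
    """Sketch of the new resolution order; not the PR's literal code."""
    # 1. lm_head.weight.shape[0], probed under the three wrapping shapes.
    for path in ("_language_model.lm_head", "language_model.lm_head", "lm_head"):
        obj = model
        for attr in path.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        weight = getattr(obj, "weight", None)
        if weight is not None:
            try:
                return int(weight.shape[0])          # exact logits width
            except (TypeError, IndexError, ValueError):
                pass                                 # malformed shape: keep falling back

    # 2. text_config.vocab_size on VLM composite configs (object or dict form).
    for cfg in (getattr(model, "config", None), getattr(model, "args", None)):
        text_config = getattr(cfg, "text_config", None)
        size = (text_config.get("vocab_size") if isinstance(text_config, dict)
                else getattr(text_config, "vocab_size", None))
        if size:
            return int(size)

    # 3. Top-level vocab_size, the original fallback for pure-LLM configs.
    for cfg in (getattr(model, "config", None), getattr(model, "args", None)):
        size = getattr(cfg, "vocab_size", None)
        if size:
            return int(size)

    return None
```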

For pure-LLM models the new order returns the same value the old order did, because lm_head.weight.shape[0] and config.vocab_size agree. The change only affects models where the two disagree, which is exactly the broken case.

Why not fix this in mlx-vlm

The ModelConfig dataclass defaults in glm4v, glm4v_moe, and gemma3 are arguably wrong (they should be None or derived from text_config), and an upstream fix would obsolete most of this PR. But:

  • The defaults exist on purpose for some mlx-vlm internals that instantiate ModelConfig without a config.json, so changing them is a behavior question, not just a bug fix — out of scope here.
  • Even with mlx-vlm fixed, lm_head.weight.shape[0] is still the more authoritative source — config can drift from weights, weights can't drift from themselves. The new resolution order also defends against runtime config-weight drift on any future model where the two could disagree, not just the specific dataclasses currently broken.

Verification

Added TestResolveVocabSize to tests/test_utils_tokenizer.py covering the new resolution order, the three lm_head wrapping shapes, fallback to text_config (object and dict forms), fallback to top-level config.vocab_size / args.vocab_size, the model = None case, and the malformed-shape defensive path. 13 cases total; all pass.
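
The shape of one representative case (hypothetical stand-in objects; the real suite in tests/test_utils_tokenizer.py covers all 13):

```python
from types import SimpleNamespace
from omlx.utils.tokenizer import resolve_vocab_size

def test_prefers_lm_head_over_stale_config():
    # lm_head says 151552; the top-level config still carries the stale 257152.
    lm_head = SimpleNamespace(weight=SimpleNamespace(shape=(151552, 4096)))
    model = SimpleNamespace(
        language_model=SimpleNamespace(lm_head=lm_head),
        config=SimpleNamespace(vocab_size=257152),
    )
    assert resolve_vocab_size(model) == 151552
```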

```
$ pytest tests/test_utils_tokenizer.py -q
41 passed
$ pytest tests/test_grammar.py::TestGetModelVocabSize -q
4 passed
```

End-to-end:

  • GLM-4.6V-Flash with [grammar]: structured output now produces 151552-wide bitmasks matching the logits; previously raised a shape-mismatch on first constrained step.
  • Pure-LLM model (Llama 3.1): resolve_vocab_size returns the same int as before; lm_head.weight.shape[0] matches config.vocab_size.
  • VLM with correct top-level config (Qwen2-VL, vocab_size: int = 32000 in its ModelConfig): same int as before.

Risk

The new lm_head probe is gated by getattr chains and a defensive shape extraction (try/except around shape[0]), so it returns None (falling back to the old config path) on any model that doesn't expose lm_head in any of the three wrapping shapes, or whose weight.shape is not subscriptable. No code that previously succeeded should now fail.

Tied embeddings: when a model ties weights, lm_head.weight is embed_tokens.weight, which still matches the logits vocab dimension exactly.

Padded vocabularies: when a model pads its vocab to a multiple of 64 (or similar) for tensor-core efficiency, lm_head.weight.shape[0] returns the padded size — which is exactly what logits-aligned buffers (e.g. xgrammar bitmasks) need to allocate to avoid shape mismatches against the actual logits tensor. The new resolution order is strictly safer than the old config-first order in this case too.

Commit message

Several mlx-vlm ModelConfig dataclasses (glm4v, glm4v_moe, gemma3)
hard-code a top-level vocab_size default that mismatches the inner
language model's vocab when config.json omits the top-level key.
Example: GLM-4.6V has text_config.vocab_size=151552 but
ModelConfig.vocab_size=257152 as a dataclass default. Code sizing
logits-aligned buffers (e.g. xgrammar bitmasks) from the top-level
value produced a shape mismatch against the real 151552 logits.

Resolution order becomes:
  1. lm_head.weight.shape[0] — authoritative; the exact vocabulary
     the model emits logits over.
  2. text_config.vocab_size — inner LM vocab on VLM composite configs.
  3. config.vocab_size / args.vocab_size — top-level fallback.
