54 changes: 47 additions & 7 deletions omlx/utils/tokenizer.py
@@ -39,10 +39,23 @@ def unwrap_tokenizer(tokenizer):


def resolve_vocab_size(model: Any) -> int | None:
"""Extract vocab_size from a model's config/args, handling nested configs.

Tries ``model.config.vocab_size``, then ``model.args.vocab_size``,
then ``text_config.vocab_size`` for VLM composite models (e.g. Qwen3.5).
"""Extract vocab_size from a model, preferring the authoritative source.

Resolution order:
1. The ``lm_head`` weight's first dimension (authoritative — this is the
exact vocabulary the model emits logits over).
2. ``text_config.vocab_size`` when present (the inner language model's
vocab on VLM composite configs).
3. ``model.config.vocab_size`` / ``model.args.vocab_size`` (top-level).

Why lm_head and text_config come first for VLMs: several mlx-vlm
``ModelConfig`` dataclasses (e.g. glm4v, glm4v_moe, gemma3) hard-code a
top-level ``vocab_size`` default that does not match the inner LM vocab
when ``config.json`` omits the top-level key. For example, GLM-4.6V has
``text_config.vocab_size=151552`` but ``ModelConfig.vocab_size=257152``
as a dataclass default. Code that sizes logits-aligned buffers (e.g.
grammar bitmasks) from the top-level value therefore ends up with a shape
mismatch against the model's real 151552-wide logits.

Args:
model: An MLX model object (LLM, VLM, or any object with config/args).
@@ -52,18 +65,45 @@ def resolve_vocab_size(model: Any) -> int | None:
"""
if model is None:
return None

# 1. lm_head weight — authoritative for any model that exposes one.
# VLM adapters wrap the language model under ``_language_model``;
# raw mlx-lm/mlx-vlm models expose ``lm_head`` directly or under
# ``language_model``.
for path in (
("_language_model", "lm_head"),
("language_model", "lm_head"),
("lm_head",),
):
obj: Any = model
for name in path:
obj = getattr(obj, name, None)
if obj is None:
break
weight = getattr(obj, "weight", None) if obj is not None else None
shape = getattr(weight, "shape", None)
try:
first_dim = shape[0] if shape is not None else None
except (TypeError, IndexError):
first_dim = None
if isinstance(first_dim, int):
return int(first_dim)

# 2 & 3. Config-based fallbacks.
for attr in ('config', 'args'):
config = getattr(model, attr, None)
if config is None:
continue
vs = getattr(config, 'vocab_size', None)
if isinstance(vs, int):
return vs
text_cfg = getattr(config, 'text_config', None)
if isinstance(text_cfg, dict):
vs = text_cfg.get('vocab_size')
elif text_cfg is not None:
vs = getattr(text_cfg, 'vocab_size', None)
else:
vs = None
if isinstance(vs, int):
return vs
vs = getattr(config, 'vocab_size', None)
if isinstance(vs, int):
return vs
return None
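For context, a minimal sketch of the mismatch the new resolution order guards against. It assumes the import path `omlx.utils.tokenizer` matches the file above; the `SimpleNamespace` stub and the bitmask sizing are illustrative only, not oMLX APIs or part of this diff.

```python
from types import SimpleNamespace

import numpy as np

from omlx.utils.tokenizer import resolve_vocab_size

# GLM-4.6V-style composite: the top-level dataclass default (257152)
# disagrees with the inner language model's real vocab (151552).
model = SimpleNamespace(
    language_model=SimpleNamespace(
        lm_head=SimpleNamespace(weight=SimpleNamespace(shape=(151552, 4096)))
    ),
    config=SimpleNamespace(
        vocab_size=257152,
        text_config=SimpleNamespace(vocab_size=151552),
    ),
)

vocab_size = resolve_vocab_size(model)
assert vocab_size == 151552  # the lm_head probe wins over the wrong top-level default

# A logits-aligned buffer (e.g. a grammar bitmask) sized from this value
# matches the real logits width; sizing it from model.config.vocab_size
# (257152) would not.
logits = np.zeros((1, 151552), dtype=np.float32)
bitmask = np.zeros(vocab_size, dtype=bool)
assert bitmask.shape[0] == logits.shape[-1]
```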
152 changes: 152 additions & 0 deletions tests/test_utils_tokenizer.py
@@ -9,6 +9,7 @@
is_gemma4_model,
is_harmony_model,
is_qwen3_model,
resolve_vocab_size,
)


@@ -193,3 +194,154 @@ def test_apply_fix_preserves_other_keys(self):
assert result["use_fast"] is True
assert result["padding_side"] == "left"
assert result["eos_token"] == "<|im_end|>"


def _stub_lm_head(vocab_size: int):
"""Build a stub lm_head whose weight has shape (vocab_size, hidden)."""
from types import SimpleNamespace

return SimpleNamespace(weight=SimpleNamespace(shape=(vocab_size, 16)))


class TestResolveVocabSize:
"""Test cases for resolve_vocab_size function.

Covers the three-level resolution order:
1. lm_head.weight.shape[0] (authoritative)
2. text_config.vocab_size (inner LM vocab on VLM composite)
3. config.vocab_size / args.vocab_size (top-level fallback)
"""

def test_returns_none_for_none_model(self):
"""resolve_vocab_size(None) returns None without crashing."""
assert resolve_vocab_size(None) is None

def test_lm_head_authoritative_pure_llm(self):
"""Pure LLM with model.lm_head: returns lm_head's first dim."""
from types import SimpleNamespace

model = SimpleNamespace(
lm_head=_stub_lm_head(32000),
config=SimpleNamespace(vocab_size=32000),
)
assert resolve_vocab_size(model) == 32000

def test_lm_head_authoritative_via_language_model(self):
"""Raw mlx-vlm wrapping: model.language_model.lm_head."""
from types import SimpleNamespace

model = SimpleNamespace(
language_model=SimpleNamespace(lm_head=_stub_lm_head(151552)),
config=SimpleNamespace(vocab_size=257152),
)
assert resolve_vocab_size(model) == 151552

def test_lm_head_authoritative_via_underscore_language_model(self):
"""oMLX VLM-adapter wrapping: model._language_model.lm_head."""
from types import SimpleNamespace

model = SimpleNamespace(
_language_model=SimpleNamespace(lm_head=_stub_lm_head(151552)),
config=SimpleNamespace(vocab_size=257152),
)
assert resolve_vocab_size(model) == 151552

def test_lm_head_overrides_wrong_top_level_vocab_size(self):
"""The bug case: GLM-4.6V-style mismatch where ModelConfig
defaults to 257152 but actual logits vocab is 151552. The
lm_head probe wins over the top-level config value.
"""
from types import SimpleNamespace

model = SimpleNamespace(
language_model=SimpleNamespace(lm_head=_stub_lm_head(151552)),
config=SimpleNamespace(
vocab_size=257152,
text_config=SimpleNamespace(vocab_size=151552),
),
)
assert resolve_vocab_size(model) == 151552

def test_text_config_fallback_when_lm_head_missing(self):
"""No lm_head exposed anywhere: fall back to text_config.vocab_size."""
from types import SimpleNamespace

model = SimpleNamespace(
config=SimpleNamespace(
vocab_size=257152,
text_config=SimpleNamespace(vocab_size=151552),
),
)
assert resolve_vocab_size(model) == 151552

def test_text_config_dict_fallback(self):
"""text_config can be a dict (some configs use dataclass-as-dict)."""
from types import SimpleNamespace

model = SimpleNamespace(
config=SimpleNamespace(
vocab_size=None,
text_config={"vocab_size": 151552},
),
)
assert resolve_vocab_size(model) == 151552

def test_top_level_config_fallback(self):
"""No lm_head, no text_config: top-level config.vocab_size."""
from types import SimpleNamespace

model = SimpleNamespace(config=SimpleNamespace(vocab_size=32000))
assert resolve_vocab_size(model) == 32000

def test_args_attr_fallback(self):
"""Older mlx-lm models expose vocab via args, not config."""
from types import SimpleNamespace

model = SimpleNamespace(args=SimpleNamespace(vocab_size=32000))
assert resolve_vocab_size(model) == 32000

def test_malformed_shape_falls_back_gracefully(self):
"""A weight whose .shape is non-subscriptable (e.g. a bare object)
must not crash; the resolver should fall back to config.

Bad shape lives under `_language_model.lm_head` (the first wrapping
path probed) so the test directly exercises the try/except guard
for shape=object() rather than falling through earlier paths
whose shape=None happens to short-circuit safely.
"""
from types import SimpleNamespace

model = SimpleNamespace(
_language_model=SimpleNamespace(
lm_head=SimpleNamespace(weight=SimpleNamespace(shape=object()))
),
config=SimpleNamespace(vocab_size=32000),
)
assert resolve_vocab_size(model) == 32000

def test_none_shape_falls_back_gracefully(self):
"""Weight present but .shape is None: skip lm_head, fall back."""
from types import SimpleNamespace

model = SimpleNamespace(
lm_head=SimpleNamespace(weight=SimpleNamespace(shape=None)),
config=SimpleNamespace(vocab_size=32000),
)
assert resolve_vocab_size(model) == 32000

def test_empty_shape_falls_back_gracefully(self):
"""Weight present but .shape is empty: skip lm_head, fall back."""
from types import SimpleNamespace

model = SimpleNamespace(
lm_head=SimpleNamespace(weight=SimpleNamespace(shape=())),
config=SimpleNamespace(vocab_size=32000),
)
assert resolve_vocab_size(model) == 32000

def test_returns_none_when_all_sources_unavailable(self):
"""No lm_head, no config, no args: returns None."""
from types import SimpleNamespace

model = SimpleNamespace()
assert resolve_vocab_size(model) is None
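As a side note, the three graceful-fallback cases above could also be expressed as a single parametrized test. This is only an illustrative sketch, not part of the diff, and it assumes the same `omlx.utils.tokenizer` import path as the test module.

```python
import pytest
from types import SimpleNamespace

from omlx.utils.tokenizer import resolve_vocab_size


# Illustrative consolidation: any unusable lm_head shape (non-subscriptable,
# None, or empty) should fall back to config.vocab_size.
@pytest.mark.parametrize("bad_shape", [object(), None, ()])
def test_unusable_lm_head_shape_falls_back(bad_shape):
    model = SimpleNamespace(
        _language_model=SimpleNamespace(
            lm_head=SimpleNamespace(weight=SimpleNamespace(shape=bad_shape))
        ),
        config=SimpleNamespace(vocab_size=32000),
    )
    assert resolve_vocab_size(model) == 32000
```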