From 4b269711033e09ad19226f18a7a403d543c06977 Mon Sep 17 00:00:00 2001 From: Connor Date: Sat, 6 Jun 2026 19:42:47 -0700 Subject: [PATCH 1/5] Add Qwen3 MoE lesson Signed-off-by: Connor1996 --- batch-main.py | 16 +- book/src/SUMMARY.md | 1 + book/src/week3-03-moe.md | 294 ++++++++++++++++++++++++++++++ pyproject.toml | 1 + src/tiny_llm/__init__.py | 1 + src/tiny_llm/models.py | 2 + src/tiny_llm/moe.py | 36 ++++ src/tiny_llm/quantize.py | 3 + src/tiny_llm_ref/__init__.py | 1 + src/tiny_llm_ref/embedding.py | 18 +- src/tiny_llm_ref/moe.py | 90 +++++++++ src/tiny_llm_ref/quantize.py | 3 + src/tiny_llm_ref/qwen3_week3.py | 69 ++++--- tests_refsol/test_week_3_day_2.py | 6 +- tests_refsol/test_week_3_day_3.py | 131 +++++++++++++ 15 files changed, 644 insertions(+), 28 deletions(-) create mode 100644 book/src/week3-03-moe.md create mode 100644 src/tiny_llm/moe.py create mode 100644 src/tiny_llm_ref/moe.py create mode 100644 tests_refsol/test_week_3_day_3.py diff --git a/batch-main.py b/batch-main.py index 2379c7f3..62be4619 100644 --- a/batch-main.py +++ b/batch-main.py @@ -35,9 +35,11 @@ random.shuffle(prompts) parser.add_argument("--solution", type=str, default="tiny_llm") +parser.add_argument("--loader", type=str, choices=["week2", "week3"], default="week2") parser.add_argument("--device", type=str, default="gpu") parser.add_argument("--batch-size", type=int, default=5) parser.add_argument("--prefill-step", type=int, default=128) +parser.add_argument("--max-seq-len", type=int, default=512) parser.add_argument("--enable-flash-attn", action="store_true") parser.add_argument("--enable-thinking", action="store_true") args = parser.parse_args() @@ -57,11 +59,20 @@ mlx_model, tokenizer = load(args.model) with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): + dispatch_kwargs = {} + if args.loader == "week2": + dispatch_kwargs["enable_flash_attn"] = args.enable_flash_attn + elif args.enable_flash_attn: + print("--enable-flash-attn is only used by the week2 loader; ignoring it") + print( - f"Using week2 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}" + f"Using {args.loader} loader with thinking={args.enable_thinking} for {args.model}" ) tiny_llm_model = models.dispatch_model( - args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + args.model, + mlx_model, + week=int(args.loader.removeprefix("week")), + **dispatch_kwargs, ) encoded_prompts = [] for idx, prompt in enumerate(prompts): @@ -81,6 +92,7 @@ tiny_llm_model, tokenizer, encoded_prompts, + max_seq_len=args.max_seq_len, batch_size=args.batch_size, prefill_step=args.prefill_step, ) diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index fb0ee4d7..590dbc9d 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -21,6 +21,7 @@ - [Week 3: Serving]() - [Paged Attention, Part 1](./week3-01-paged-attention-part1.md) - [Paged Attention, Part 2](./week3-02-paged-attention-part2.md) + - [Mixture of Experts](./week3-03-moe.md) --- diff --git a/book/src/week3-03-moe.md b/book/src/week3-03-moe.md new file mode 100644 index 00000000..9938096c --- /dev/null +++ b/book/src/week3-03-moe.md @@ -0,0 +1,294 @@ +# Week 3 Day 3: Mixture of Experts + +In this chapter, we will implement the feed-forward shape of **Mixture of +Experts**, or **MoE**, for the Qwen3 family. + +So far, every transformer block in tiny-llm has used the same dense Qwen3 MLP: + +```plain +x -> gate_proj +x -> up_proj +SiLU(gate_proj(x)) * up_proj(x) -> down_proj +``` + +That is a SwiGLU MLP. Every token visits the same weights. + +MoE changes only the feed-forward half of the transformer block. Instead of one +dense MLP, the model owns many expert MLPs. A small router chooses which experts +each token should use: + +```plain +token hidden state -> router -> top-k experts -> weighted expert outputs +``` + +The attention path does not change. KV cache does not change. The sparse work is +inside the MLP half of the block. + +**Readings** + +- [Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer](https://arxiv.org/abs/1701.06538) +- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668) +- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) + +## Dense MLP vs MoE MLP + +The dense Qwen3 MLP from Week 1 has one set of weights: + +```plain +w_gate: hidden_dim, dim +w_up: hidden_dim, dim +w_down: dim, hidden_dim +``` + +A Qwen3-MoE sparse block has a bank of those weights: + +```plain +expert_gate: num_experts, moe_hidden_dim, dim +expert_up: num_experts, moe_hidden_dim, dim +expert_down: num_experts, dim, moe_hidden_dim +``` + +The router produces one score per expert: + +```plain +router_logits: B, L, num_experts +router_probs: softmax(router_logits) +``` + +Then the model picks `num_experts_per_tok` experts for each token: + +```plain +expert_ids: B, L, num_experts_per_tok +expert_scores: B, L, num_experts_per_tok +``` + +For each token, only those selected experts run. Their outputs are weighted and +summed: + +```plain +output[token] = sum(score_i * expert_i(token)) +``` + +That is the central MoE idea: the model can contain many parameters, but each +token activates only a small subset of them. + +## Qwen3-MoE Shape + +Qwen3-MoE keeps the same attention structure as Qwen3, including QK norm, GQA, +RoPE, and the same KV cache interface. It replaces some dense MLP layers with a +sparse MoE block. + +The useful pieces are: + +- `gate`: a router linear layer from hidden size to `num_experts` +- `switch_mlp`: many SwiGLU experts with `moe_intermediate_size` +- `num_experts_per_tok`: how many experts a token uses +- `norm_topk_prob`: whether selected expert scores are renormalized +- `decoder_sparse_step` and `mlp_only_layers`: which layers are sparse vs dense + +There is no shared expert in the Qwen3-MoE block we are following. The sparse +feed-forward output is just the weighted top-k expert mixture. + +## The MLX Primitive + +MLX does not give us a single high-level MoE block in `mlx.nn`. The relevant +primitive for this chapter is `mx.gather_qmm`: it performs quantized matrix +multiplication while selecting a different matrix for each row. + +For MoE, that means: + +```plain +token rows: N, D +expert ids: N +weights: E, O, D packed as 4-bit QuantizedWeights +output: N, O +``` + +The row with `expert_ids[i] = e` should multiply by `weights[e]`. + +When the expert ids are sorted, pass `sorted_indices=True`. Keep the inverse +order from the sort so the result can be restored to the original token order. + +## Router Step + +The router is just a quantized linear layer: + +```python +router_logits = quantized_linear(x, w_router) +router_probs = softmax(router_logits, axis=-1) +``` + +For a batch of tokens: + +```plain +x: B, L, D +router_logits: B, L, E +router_probs: B, L, E +``` + +where `E = num_experts`. + +Qwen3-MoE then uses top-k selection: + +```python +expert_ids = argpartition(-router_probs, k)[:k] +expert_scores = take_along_axis(router_probs, expert_ids) +``` + +If `norm_topk_prob` is true, renormalize `expert_scores` so the selected scores +sum to 1 for each token. + +## Expert Step + +Each expert is the same kind of SwiGLU MLP we already know: + +```plain +expert(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)) +``` + +The implementation should build token-expert jobs, group them by expert, and run +the expert projections with `mx.gather_qmm`: + +```plain +selected expert ids -> expanded token-expert rows +expanded rows -> sort/group by expert id +grouped expert rows -> grouped gate/up projection +SiLU(gate) * up -> grouped down projection +restore original token/top-k order -> weighted sum +``` + +The reorder is part of the model implementation. It keeps all token rows for the +same expert contiguous so the expert bank can be applied with grouped matrix +multiplication. + +## Task 1: Grouped Expert Linear + +``` +src/tiny_llm/moe.py +``` + +Implement `grouped_expert_linear`. This is the MLX-shaped core of MoE. + +The function accepts: + +```plain +x: ..., D +w_experts: QuantizedWeights for num_experts, output_dim, D +expert_ids: ... +``` + +It returns: + +```plain +out: ..., output_dim +``` + +The implementation should: + +```plain +1. flatten token rows and expert ids, +2. sort rows by expert id, +3. call mx.gather_qmm with sorted_indices=True, +4. restore the original order. +``` + +For the grouped matmul, the shape should look like: + +```python +out = mx.gather_qmm( + mx.expand_dims(grouped_rows, -2), + w_experts.weight, + w_experts.scales, + w_experts.biases, + lhs_indices=mx.arange(grouped_rows.shape[0]), + rhs_indices=grouped_expert_ids, + transpose=True, + group_size=w_experts.group_size, + bits=w_experts.bits, + mode=w_experts.mode, + sorted_indices=True, +).squeeze(-2) +``` + +This task maps to the same idea as `QuantizedSwitchLinear` in `mlx-lm`: each +token row uses a different packed expert matrix, and the expert ids choose the +right matrix. + +## Task 2: Router Top-k + +``` +src/tiny_llm/moe.py +``` + +Implement `route_topk`. It accepts hidden states and router weights, then +returns: + +- router probabilities +- selected expert ids +- selected expert scores + +Use `quantized_linear` and `softmax`. Use `mx.argpartition` to select the top +`num_experts_per_tok` experts, then `mx.take_along_axis` to gather their scores. + +Keep `norm_topk_prob` as an argument because Qwen3-MoE stores this behavior in +the model config. + +## Task 3: Qwen3 Sparse MoE Block + +``` +src/tiny_llm/moe.py +``` + +Implement `Moe` by composing Task 1 and Task 2: + +```plain +hidden states -> route_topk +hidden states + expert ids -> grouped gate projection +hidden states + expert ids -> grouped up projection +SiLU(gate) * up -> grouped down projection +weighted sum over num_experts_per_tok +``` + +This completes the Qwen3-MoE sparse feed-forward block. There is no shared expert +branch in this block. + +## Task 4: Integrate Qwen3-MoE Layers + +``` +src/tiny_llm/qwen3_week3.py +src/tiny_llm/models.py +``` + +Add a Qwen3-MoE loader path that reuses the Week 3 Qwen3 attention and paged KV +cache behavior, but swaps selected block MLPs for `Moe`. + +The model wrapper should: + +- keep Qwen3 attention unchanged, +- use regular `Qwen3MLP` for `mlp_only_layers`, +- use `Moe` for sparse layers selected by + `decoder_sparse_step`, +- load router and expert weights as `QuantizedWeights` from the Qwen3-MoE MLX + model, +- preserve the same decode call shape: + +```python +logits = model(tokens, offset, cache) +``` + +No scheduler API change in `src/tiny_llm/batch.py` is required for correctness. + +Run this task through the normal generation entrypoints instead of adding a +separate unit test. For example: + +```bash +hf download Qwen/Qwen3-30B-A3B-MLX-4bit + +pdm run main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ + --prompt "Give me a short introduction to mixture of experts." + +pdm run batch-main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ + --batch-size 2 --prefill-step 16 +``` + +{{#include copyright.md}} diff --git a/pyproject.toml b/pyproject.toml index f02f3178..64079abf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ clean-ext-ref.working_dir = "src/extensions_ref" main.cmd = "python main.py" main-week1.cmd = "python main.py --loader week1" main-week2.cmd = "python main.py --loader week2" +main-week3.cmd = "python main.py --loader week3" bench.cmd = "python bench.py" batch-main.cmd = "python batch-main.py" test.cmd = "python scripts/dev-tools.py test" diff --git a/src/tiny_llm/__init__.py b/src/tiny_llm/__init__.py index 17790ffb..0fae713d 100644 --- a/src/tiny_llm/__init__.py +++ b/src/tiny_llm/__init__.py @@ -15,3 +15,4 @@ from .paged_kv_cache import * from .batch import * from .models import * +from .moe import * diff --git a/src/tiny_llm/models.py b/src/tiny_llm/models.py index a9acc1b0..a8568ea9 100644 --- a/src/tiny_llm/models.py +++ b/src/tiny_llm/models.py @@ -13,6 +13,8 @@ def shortcut_name_to_full_name(shortcut_name: str): return "Qwen/Qwen3-1.7B-MLX-4bit" elif lower_shortcut_name == "qwen3-4b": return "Qwen/Qwen3-4B-MLX-4bit" + elif lower_shortcut_name in ("qwen3-30b-a3b", "qwen3-moe-30b-a3b"): + return "Qwen/Qwen3-30B-A3B-MLX-4bit" else: return shortcut_name diff --git a/src/tiny_llm/moe.py b/src/tiny_llm/moe.py new file mode 100644 index 00000000..7b5019b5 --- /dev/null +++ b/src/tiny_llm/moe.py @@ -0,0 +1,36 @@ +import mlx.core as mx + +from .quantize import QuantizedWeights + + +def grouped_expert_linear( + x: mx.array, + w_experts: QuantizedWeights, + expert_ids: mx.array, +) -> mx.array: + pass + + +def route_topk( + x: mx.array, + w_router: QuantizedWeights, + top_k: int, + norm_topk_prob: bool = False, +) -> tuple[mx.array, mx.array, mx.array]: + pass + + +class Moe: + def __init__( + self, + w_router: QuantizedWeights, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + num_experts_per_tok: int, + norm_topk_prob: bool = False, + ): + pass + + def __call__(self, x: mx.array) -> mx.array: + pass diff --git a/src/tiny_llm/quantize.py b/src/tiny_llm/quantize.py index 09ef1932..0605e6d5 100644 --- a/src/tiny_llm/quantize.py +++ b/src/tiny_llm/quantize.py @@ -21,12 +21,14 @@ def __init__( group_size: int, bits: int, weight: mx.array, + mode: str = "affine", ): self.scales = scales self.biases = biases self.group_size = group_size self.bits = bits self.weight = weight + self.mode = mode @staticmethod def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": @@ -36,6 +38,7 @@ def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": group_size=mlx_layer.group_size, bits=mlx_layer.bits, weight=mlx_layer.weight, + mode=getattr(mlx_layer, "mode", "affine"), ) diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index 17790ffb..0fae713d 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -15,3 +15,4 @@ from .paged_kv_cache import * from .batch import * from .models import * +from .moe import * diff --git a/src/tiny_llm_ref/embedding.py b/src/tiny_llm_ref/embedding.py index 0ecf8424..5a71d027 100644 --- a/src/tiny_llm_ref/embedding.py +++ b/src/tiny_llm_ref/embedding.py @@ -4,7 +4,12 @@ class Embedding: - def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + weight: mx.array, + ): self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.weight = weight @@ -17,18 +22,25 @@ def as_linear(self, x: mx.array) -> mx.array: class QuantizedEmbedding: - def __init__(self, vocab_size: int, embedding_dim: int, weight: QuantizedWeights): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + weight: QuantizedWeights, + ): self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.weight = weight def __call__(self, x: mx.array) -> mx.array: + biases = self.weight.biases[x] if self.weight.biases is not None else None return mx.dequantize( self.weight.weight[x], self.weight.scales[x], - self.weight.biases[x], + biases, self.weight.group_size, self.weight.bits, + mode=self.weight.mode, ) def as_linear(self, x: mx.array) -> mx.array: diff --git a/src/tiny_llm_ref/moe.py b/src/tiny_llm_ref/moe.py new file mode 100644 index 00000000..c74b8fe0 --- /dev/null +++ b/src/tiny_llm_ref/moe.py @@ -0,0 +1,90 @@ +import mlx.core as mx + +from .basics import silu +from .quantize import QuantizedWeights, quantized_linear + + +def grouped_expert_linear( + x: mx.array, + w_experts: QuantizedWeights, + expert_ids: mx.array, +) -> mx.array: + *leading_dims, D = x.shape + flat_x = x.reshape(-1, D) + flat_expert_ids = expert_ids.reshape(-1) + sort_idx = mx.argsort(flat_expert_ids) + inv_sort_idx = mx.argsort(sort_idx) + + grouped_x = flat_x[sort_idx] + grouped_expert_ids = flat_expert_ids[sort_idx] + out = mx.gather_qmm( + mx.expand_dims(grouped_x, -2), + w_experts.weight, + w_experts.scales, + w_experts.biases, + lhs_indices=mx.arange(grouped_x.shape[0]), + rhs_indices=grouped_expert_ids, + transpose=True, + group_size=w_experts.group_size, + bits=w_experts.bits, + mode=w_experts.mode, + sorted_indices=True, + ).squeeze(-2) + out_dim = w_experts.weight.shape[-2] + return out[inv_sort_idx].reshape(*leading_dims, out_dim) + + +def route_topk( + x: mx.array, + w_router: QuantizedWeights, + top_k: int, + norm_topk_prob: bool = False, +) -> tuple[mx.array, mx.array, mx.array]: + router_logits = quantized_linear(x, w_router) + router_probs = mx.softmax(router_logits, axis=-1, precise=True) + expert_ids = mx.argpartition(-router_probs, kth=top_k - 1, axis=-1)[..., :top_k] + expert_scores = mx.take_along_axis(router_probs, expert_ids, axis=-1) + if norm_topk_prob: + expert_scores = expert_scores / mx.sum(expert_scores, axis=-1, keepdims=True) + return router_probs, expert_ids, expert_scores + + +class Moe: + def __init__( + self, + w_router: QuantizedWeights, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + num_experts_per_tok: int, + norm_topk_prob: bool = False, + ): + self.w_router = w_router + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + + def __call__(self, x: mx.array) -> mx.array: + B, L, D = x.shape + _, expert_ids, expert_scores = route_topk( + x, + self.w_router, + top_k=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + ) + expanded_x = mx.broadcast_to( + mx.expand_dims(x, -2), + (B, L, self.num_experts_per_tok, D), + ).reshape(-1, D) + flat_expert_ids = expert_ids.reshape(-1) + + gate = grouped_expert_linear(expanded_x, self.w_gate, flat_expert_ids) + up = grouped_expert_linear(expanded_x, self.w_up, flat_expert_ids) + expert_output = grouped_expert_linear( + silu(gate) * up, + self.w_down, + flat_expert_ids, + ).reshape(B, L, self.num_experts_per_tok, D) + return mx.sum(expert_output * mx.expand_dims(expert_scores, -1), axis=-2) diff --git a/src/tiny_llm_ref/quantize.py b/src/tiny_llm_ref/quantize.py index d98b3a22..8ce5140c 100644 --- a/src/tiny_llm_ref/quantize.py +++ b/src/tiny_llm_ref/quantize.py @@ -11,12 +11,14 @@ def __init__( group_size: int, bits: int, weight: mx.array, + mode: str = "affine", ): self.scales = scales self.biases = biases self.group_size = group_size self.bits = bits self.weight = weight + self.mode = mode @staticmethod def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": @@ -26,6 +28,7 @@ def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": group_size=mlx_layer.group_size, bits=mlx_layer.bits, weight=mlx_layer.weight, + mode=getattr(mlx_layer, "mode", "affine"), ) diff --git a/src/tiny_llm_ref/qwen3_week3.py b/src/tiny_llm_ref/qwen3_week3.py index 614c4ae0..29fb30e6 100644 --- a/src/tiny_llm_ref/qwen3_week3.py +++ b/src/tiny_llm_ref/qwen3_week3.py @@ -5,9 +5,10 @@ from .layer_norm import RMSNorm from .positional_encoding import RoPE from typing import Any -from .embedding import Embedding -from .quantize import dequantize_linear, QuantizedWeights, quantized_linear +from .embedding import QuantizedEmbedding +from .quantize import QuantizedWeights, quantized_linear from .kv_cache import TinyKvCache +from .moe import Moe from .paged_kv_cache import TinyKvPagedCache, TinyKvPagedPool @@ -122,7 +123,6 @@ def __init__( num_kv_heads: int, hidden_size: int, head_dim: int, - intermediate_size: int, rms_norm_eps: float, wq: QuantizedWeights, wk: QuantizedWeights, @@ -130,17 +130,15 @@ def __init__( wo: QuantizedWeights, q_norm: mx.array, k_norm: mx.array, - w_gate: QuantizedWeights, - w_up: QuantizedWeights, - w_down: QuantizedWeights, w_input_layernorm: mx.array, w_post_attention_layernorm: mx.array, + mlp: Qwen3MLP | Moe, max_seq_len: int = 32768, theta: int = 1000000, ): self.num_attention_heads = num_attention_heads self.hidden_size = hidden_size - self.mlp = Qwen3MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + self.mlp = mlp self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm( hidden_size, w_post_attention_layernorm, eps=rms_norm_eps @@ -175,6 +173,14 @@ def __call__( return out +def is_qwen3_moe_sparse_layer(args: Any, layer_idx: int) -> bool: + return ( + getattr(args, "num_experts", 0) > 0 + and layer_idx not in getattr(args, "mlp_only_layers", []) + and (layer_idx + 1) % getattr(args, "decoder_sparse_step", 1) == 0 + ) + + class Qwen3ModelWeek3: def __init__( self, @@ -191,10 +197,10 @@ def __init__( precision = mx.bfloat16 self.precision = precision - self.embedding = Embedding( + self.embedding = QuantizedEmbedding( vocab_size=self.vocab_size, embedding_dim=self.hidden_size, - weight=dequantize_linear(mlx_model.model.embed_tokens), + weight=QuantizedWeights.from_mlx_layer(mlx_model.model.embed_tokens), ) self.layers_inner = [] @@ -211,22 +217,43 @@ def __init__( wo = QuantizedWeights.from_mlx_layer( mlx_model.model.layers[i].self_attn.o_proj ) - w_gate = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate_proj - ) - w_up = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.up_proj - ) - w_down = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.down_proj - ) + if is_qwen3_moe_sparse_layer(mlx_model.args, i): + mlp = Moe( + w_router=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate + ), + w_gate=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.gate_proj + ), + w_up=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.up_proj + ), + w_down=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.down_proj + ), + num_experts_per_tok=mlx_model.args.num_experts_per_tok, + norm_topk_prob=mlx_model.args.norm_topk_prob, + ) + else: + mlp = Qwen3MLP( + mlx_model.args.hidden_size, + mlx_model.args.intermediate_size, + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.up_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj + ), + ) layer = Qwen3TransformerBlock( num_attention_heads=mlx_model.args.num_attention_heads, num_kv_heads=mlx_model.args.num_key_value_heads, hidden_size=mlx_model.args.hidden_size, head_dim=mlx_model.args.head_dim, - intermediate_size=mlx_model.args.intermediate_size, rms_norm_eps=mlx_model.args.rms_norm_eps, wq=wq, wk=wk, @@ -234,13 +261,11 @@ def __init__( wo=wo, q_norm=mlx_model.model.layers[i].self_attn.q_norm.weight, k_norm=mlx_model.model.layers[i].self_attn.k_norm.weight, - w_gate=w_gate, - w_up=w_up, - w_down=w_down, w_input_layernorm=mlx_model.model.layers[i].input_layernorm.weight, w_post_attention_layernorm=mlx_model.model.layers[ i ].post_attention_layernorm.weight, + mlp=mlp, max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, ) diff --git a/tests_refsol/test_week_3_day_2.py b/tests_refsol/test_week_3_day_2.py index d86ef2e4..efcd2e77 100644 --- a/tests_refsol/test_week_3_day_2.py +++ b/tests_refsol/test_week_3_day_2.py @@ -203,5 +203,9 @@ def test_task_3_incremental_decode_matches_week2_with_paged_attention(): week2_out = week2_out - mx.logsumexp(week2_out, keepdims=True) week3_out = week3_out - mx.logsumexp(week3_out, keepdims=True) assert_allclose( - week3_out, week2_out, precision=mx.bfloat16, rtol=1e-3, atol=1e-3 + week3_out, + week2_out, + precision=mx.bfloat16, + rtol=1e-3, + atol=1e-3, ) diff --git a/tests_refsol/test_week_3_day_3.py b/tests_refsol/test_week_3_day_3.py new file mode 100644 index 00000000..7bf82abd --- /dev/null +++ b/tests_refsol/test_week_3_day_3.py @@ -0,0 +1,131 @@ +from types import SimpleNamespace + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.qwen3_moe import ( + Qwen3MoeSparseMoeBlock as MlxQwen3MoeSparseMoeBlock, +) +from mlx_lm.models.switch_layers import SwitchLinear + +from .tiny_llm_base import ( + Moe, + QuantizedWeights, + grouped_expert_linear, + route_topk, +) +from .utils import assert_allclose + + +def test_task_1_grouped_expert_linear(): + mx.random.seed(1) + scale = 0.25 + x = mx.random.normal(shape=(2, 3, 128), dtype=mx.float16) * scale + w_experts = mx.random.normal(shape=(3, 64, 128), dtype=mx.float16) * scale + expert_ids = mx.array( + [ + [2, 0, 1], + [1, 2, 0], + ], + dtype=mx.uint32, + ) + + ref = SwitchLinear( + input_dims=w_experts.shape[-1], + output_dims=w_experts.shape[-2], + num_experts=w_experts.shape[0], + bias=False, + ) + ref.weight = w_experts + ref = ref.to_quantized(group_size=128, bits=4) + + out = grouped_expert_linear( + x, + QuantizedWeights.from_mlx_layer(ref), + expert_ids, + ) + expected = ref(mx.expand_dims(x, -2), expert_ids).squeeze(-2) + + assert out.shape == (2, 3, 64) + assert_allclose(out, expected, precision=mx.float16) + + +def test_task_2_router_topk(): + mx.random.seed(2) + scale = 0.25 + x = mx.random.normal(shape=(2, 2, 128), dtype=mx.float16) * scale + ref = nn.Linear(128, 4, bias=False) + ref.weight = mx.random.normal(shape=(4, 128), dtype=mx.float16) * scale + ref = ref.to_quantized(group_size=128, bits=4) + + router_probs, expert_ids, expert_scores = route_topk( + x, + QuantizedWeights.from_mlx_layer(ref), + top_k=2, + ) + _, _, normalized_scores = route_topk( + x, + QuantizedWeights.from_mlx_layer(ref), + top_k=2, + norm_topk_prob=True, + ) + + expected_probs = mx.softmax(ref(x), axis=-1, precise=True) + expected_ids = mx.argpartition(-expected_probs, kth=1, axis=-1)[..., :2] + expected_scores = mx.take_along_axis(expected_probs, expected_ids, axis=-1) + expected_normalized_scores = expected_scores / expected_scores.sum( + axis=-1, + keepdims=True, + ) + + assert router_probs.shape == (2, 2, 4) + assert expert_ids.shape == (2, 2, 2) + assert expert_scores.shape == (2, 2, 2) + assert expert_ids.tolist() == expected_ids.tolist() + assert_allclose(router_probs, expected_probs, precision=mx.float16) + assert_allclose(expert_scores, expected_scores, precision=mx.float16) + assert_allclose( + normalized_scores, + expected_normalized_scores, + precision=mx.float16, + ) + + +def test_task_3_moe(): + mx.random.seed(3) + scale = 0.25 + x = mx.random.normal(shape=(2, 3, 128), dtype=mx.float16) * scale + ref = MlxQwen3MoeSparseMoeBlock( + SimpleNamespace( + hidden_size=128, + moe_intermediate_size=128, + num_experts=3, + num_experts_per_tok=2, + norm_topk_prob=True, + ) + ) + ref.gate.weight = mx.random.normal(shape=(3, 128), dtype=mx.float16) * scale + ref.switch_mlp.gate_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + ref.switch_mlp.up_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + ref.switch_mlp.down_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + nn.quantize(ref, group_size=128, bits=4) + + moe = Moe( + w_router=QuantizedWeights.from_mlx_layer(ref.gate), + w_gate=QuantizedWeights.from_mlx_layer(ref.switch_mlp.gate_proj), + w_up=QuantizedWeights.from_mlx_layer(ref.switch_mlp.up_proj), + w_down=QuantizedWeights.from_mlx_layer(ref.switch_mlp.down_proj), + num_experts_per_tok=2, + norm_topk_prob=True, + ) + + out = moe(x) + expected = ref(x) + + assert out.shape == x.shape + assert_allclose(out, expected, precision=mx.float16) From be422a97dc8919fbcb018c43100e198a27e0c5a7 Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Mon, 8 Jun 2026 23:55:00 -0700 Subject: [PATCH 2/5] Simplify MoE quantized weight metadata Signed-off-by: Connor1996 --- book/src/week2-02-quantized-matmul.md | 3 +- book/src/week3-03-moe.md | 69 +++++++++++++++++---------- src/tiny_llm/quantize.py | 18 ++++--- src/tiny_llm_ref/embedding.py | 12 +---- src/tiny_llm_ref/moe.py | 1 - src/tiny_llm_ref/quantize.py | 18 ++++--- 6 files changed, 70 insertions(+), 51 deletions(-) diff --git a/book/src/week2-02-quantized-matmul.md b/book/src/week2-02-quantized-matmul.md index 0bcc22aa..7c0c16fc 100644 --- a/book/src/week2-02-quantized-matmul.md +++ b/book/src/week2-02-quantized-matmul.md @@ -243,7 +243,8 @@ The token embedding table should also stay quantized in Week 2. Add a `QuantizedEmbedding` wrapper with two call patterns: - `embedding(input_ids)` is a row lookup. Gather the matching packed rows, - scales, and biases, then call `mx.dequantize` on only those selected rows. + scales, and biases with the provided `dequantize_linear` helper so only those + selected rows are dequantized. - `embedding.as_linear(h)` is the tied output projection. Implement this with `quantized_linear(h, embedding_weight)` so it uses your quantized matmul path instead of materializing the full `vocab_size x hidden_size` table. This path diff --git a/book/src/week3-03-moe.md b/book/src/week3-03-moe.md index 9938096c..4727cafc 100644 --- a/book/src/week3-03-moe.md +++ b/book/src/week3-03-moe.md @@ -89,11 +89,13 @@ The useful pieces are: There is no shared expert in the Qwen3-MoE block we are following. The sparse feed-forward output is just the weighted top-k expert mixture. -## The MLX Primitive +## Grouped Quantized Matmul -MLX does not give us a single high-level MoE block in `mlx.nn`. The relevant -primitive for this chapter is `mx.gather_qmm`: it performs quantized matrix -multiplication while selecting a different matrix for each row. +MLX does not give us a single high-level MoE block in `mlx.nn`. It does have a +lower-level primitive, `mx.gather_qmm`, that performs quantized matrix +multiplication while selecting a different matrix for each row. In this chapter, +we will build a narrow teaching version of that idea: +`grouped_quantized_matmul`. For MoE, that means: @@ -106,8 +108,9 @@ output: N, O The row with `expert_ids[i] = e` should multiply by `weights[e]`. -When the expert ids are sorted, pass `sorted_indices=True`. Keep the inverse -order from the sort so the result can be restored to the original token order. +Task 1 will assume the rows are already sorted by expert id. The MoE helper will +keep the inverse order from the sort so the result can be restored to the +original token order. ## Router Step @@ -147,7 +150,7 @@ expert(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)) ``` The implementation should build token-expert jobs, group them by expert, and run -the expert projections with `mx.gather_qmm`: +the expert projections with `grouped_quantized_matmul`: ```plain selected expert ids -> expanded token-expert rows @@ -161,53 +164,69 @@ The reorder is part of the model implementation. It keeps all token rows for the same expert contiguous so the expert bank can be applied with grouped matrix multiplication. -## Task 1: Grouped Expert Linear +## Task 1: Grouped Quantized Matmul ``` +src/extensions/src/quantized_matmul.cpp +src/extensions/src/quantized_matmul.metal +src/tiny_llm/quantize.py src/tiny_llm/moe.py ``` -Implement `grouped_expert_linear`. This is the MLX-shaped core of MoE. +Implement `grouped_quantized_matmul`, then use it from `grouped_expert_linear`. +This is the quantized grouped-matmul core of MoE. -The function accepts: +`grouped_quantized_matmul` accepts: ```plain -x: ..., D -w_experts: QuantizedWeights for num_experts, output_dim, D -expert_ids: ... +a: R, D +w_experts: packed QuantizedWeights for num_experts, output_dim, D +expert_ids: R, sorted by expert id ``` It returns: ```plain -out: ..., output_dim +out: R, output_dim +``` + +Each row uses the expert selected by the matching row in `expert_ids`: + +```plain +out[row] = a[row] @ dequantize(w_experts[expert_ids[row]]).T ``` The implementation should: +```plain +1. add a Python wrapper for grouped_quantized_matmul, +2. extend the quantized matmul extension with a grouped entrypoint, +3. read expert_ids[row] inside the kernel, +4. use that expert id to choose the expert weight, scale, and bias row. +``` + +After that, implement `grouped_expert_linear` in `src/tiny_llm/moe.py`: + ```plain 1. flatten token rows and expert ids, 2. sort rows by expert id, -3. call mx.gather_qmm with sorted_indices=True, +3. call grouped_quantized_matmul, 4. restore the original order. ``` -For the grouped matmul, the shape should look like: +The call should look like: ```python -out = mx.gather_qmm( - mx.expand_dims(grouped_rows, -2), - w_experts.weight, +out = grouped_quantized_matmul( w_experts.scales, w_experts.biases, - lhs_indices=mx.arange(grouped_rows.shape[0]), - rhs_indices=grouped_expert_ids, - transpose=True, group_size=w_experts.group_size, bits=w_experts.bits, - mode=w_experts.mode, - sorted_indices=True, -).squeeze(-2) + a=grouped_rows, + b=w_experts.weight, + expert_ids=grouped_expert_ids, + transpose_b=True, +) ``` This task maps to the same idea as `QuantizedSwitchLinear` in `mlx-lm`: each diff --git a/src/tiny_llm/quantize.py b/src/tiny_llm/quantize.py index 0605e6d5..b24ecb5c 100644 --- a/src/tiny_llm/quantize.py +++ b/src/tiny_llm/quantize.py @@ -2,11 +2,18 @@ from typing import Any -def dequantize_linear(mx_layer: Any) -> mx.array: +def dequantize_linear(mx_layer: Any, indices: mx.array | None = None) -> mx.array: + weight = mx_layer.weight + scales = mx_layer.scales + biases = mx_layer.biases + if indices is not None: + weight = weight[indices] + scales = scales[indices] + biases = biases[indices] if biases is not None else None w = mx.dequantize( - mx_layer.weight, - mx_layer.scales, - mx_layer.biases, + weight, + scales, + biases, mx_layer.group_size, mx_layer.bits, ) @@ -21,14 +28,12 @@ def __init__( group_size: int, bits: int, weight: mx.array, - mode: str = "affine", ): self.scales = scales self.biases = biases self.group_size = group_size self.bits = bits self.weight = weight - self.mode = mode @staticmethod def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": @@ -38,7 +43,6 @@ def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": group_size=mlx_layer.group_size, bits=mlx_layer.bits, weight=mlx_layer.weight, - mode=getattr(mlx_layer, "mode", "affine"), ) diff --git a/src/tiny_llm_ref/embedding.py b/src/tiny_llm_ref/embedding.py index 5a71d027..305d187e 100644 --- a/src/tiny_llm_ref/embedding.py +++ b/src/tiny_llm_ref/embedding.py @@ -1,6 +1,6 @@ import mlx.core as mx from .basics import linear -from .quantize import QuantizedWeights, quantized_linear +from .quantize import QuantizedWeights, dequantize_linear, quantized_linear class Embedding: @@ -33,15 +33,7 @@ def __init__( self.weight = weight def __call__(self, x: mx.array) -> mx.array: - biases = self.weight.biases[x] if self.weight.biases is not None else None - return mx.dequantize( - self.weight.weight[x], - self.weight.scales[x], - biases, - self.weight.group_size, - self.weight.bits, - mode=self.weight.mode, - ) + return dequantize_linear(self.weight, x) def as_linear(self, x: mx.array) -> mx.array: return quantized_linear(x, self.weight) diff --git a/src/tiny_llm_ref/moe.py b/src/tiny_llm_ref/moe.py index c74b8fe0..9e6d3572 100644 --- a/src/tiny_llm_ref/moe.py +++ b/src/tiny_llm_ref/moe.py @@ -27,7 +27,6 @@ def grouped_expert_linear( transpose=True, group_size=w_experts.group_size, bits=w_experts.bits, - mode=w_experts.mode, sorted_indices=True, ).squeeze(-2) out_dim = w_experts.weight.shape[-2] diff --git a/src/tiny_llm_ref/quantize.py b/src/tiny_llm_ref/quantize.py index 8ce5140c..d1ac7c34 100644 --- a/src/tiny_llm_ref/quantize.py +++ b/src/tiny_llm_ref/quantize.py @@ -11,14 +11,12 @@ def __init__( group_size: int, bits: int, weight: mx.array, - mode: str = "affine", ): self.scales = scales self.biases = biases self.group_size = group_size self.bits = bits self.weight = weight - self.mode = mode @staticmethod def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": @@ -28,7 +26,6 @@ def from_mlx_layer(mlx_layer: Any) -> "QuantizedWeights": group_size=mlx_layer.group_size, bits=mlx_layer.bits, weight=mlx_layer.weight, - mode=getattr(mlx_layer, "mode", "affine"), ) @@ -50,11 +47,18 @@ def quantized_linear( ) -def dequantize_linear(mx_layer: Any) -> mx.array: +def dequantize_linear(mx_layer: Any, indices: mx.array | None = None) -> mx.array: + weight = mx_layer.weight + scales = mx_layer.scales + biases = mx_layer.biases + if indices is not None: + weight = weight[indices] + scales = scales[indices] + biases = biases[indices] if biases is not None else None w = mx.dequantize( - mx_layer.weight, - mx_layer.scales, - mx_layer.biases, + weight, + scales, + biases, mx_layer.group_size, mx_layer.bits, ) From e33c18faa619fc3052599fda79845e1d009b9352 Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Tue, 9 Jun 2026 00:19:12 -0700 Subject: [PATCH 3/5] Move Qwen3 MoE MLP selection into block Signed-off-by: Connor1996 --- src/tiny_llm_ref/qwen3_week3.py | 82 +++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 30 deletions(-) diff --git a/src/tiny_llm_ref/qwen3_week3.py b/src/tiny_llm_ref/qwen3_week3.py index 29fb30e6..7a7dbc70 100644 --- a/src/tiny_llm_ref/qwen3_week3.py +++ b/src/tiny_llm_ref/qwen3_week3.py @@ -123,6 +123,7 @@ def __init__( num_kv_heads: int, hidden_size: int, head_dim: int, + intermediate_size: int, rms_norm_eps: float, wq: QuantizedWeights, wk: QuantizedWeights, @@ -130,15 +131,31 @@ def __init__( wo: QuantizedWeights, q_norm: mx.array, k_norm: mx.array, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, w_input_layernorm: mx.array, w_post_attention_layernorm: mx.array, - mlp: Qwen3MLP | Moe, + w_router: QuantizedWeights | None = None, + num_experts_per_tok: int | None = None, + norm_topk_prob: bool = False, max_seq_len: int = 32768, theta: int = 1000000, ): self.num_attention_heads = num_attention_heads self.hidden_size = hidden_size - self.mlp = mlp + if w_router is None: + self.mlp = Qwen3MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + else: + assert num_experts_per_tok is not None + self.mlp = Moe( + w_router=w_router, + w_gate=w_gate, + w_up=w_up, + w_down=w_down, + num_experts_per_tok=num_experts_per_tok, + norm_topk_prob=norm_topk_prob, + ) self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm( hidden_size, w_post_attention_layernorm, eps=rms_norm_eps @@ -218,42 +235,42 @@ def __init__( mlx_model.model.layers[i].self_attn.o_proj ) if is_qwen3_moe_sparse_layer(mlx_model.args, i): - mlp = Moe( - w_router=QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate - ), - w_gate=QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.gate_proj - ), - w_up=QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.up_proj - ), - w_down=QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.down_proj - ), - num_experts_per_tok=mlx_model.args.num_experts_per_tok, - norm_topk_prob=mlx_model.args.norm_topk_prob, + w_router = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate ) + w_gate = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.gate_proj + ) + w_up = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.up_proj + ) + w_down = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.down_proj + ) + intermediate_size = mlx_model.args.moe_intermediate_size + num_experts_per_tok = mlx_model.args.num_experts_per_tok + norm_topk_prob = mlx_model.args.norm_topk_prob else: - mlp = Qwen3MLP( - mlx_model.args.hidden_size, - mlx_model.args.intermediate_size, - QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate_proj - ), - QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.up_proj - ), - QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.down_proj - ), + w_router = None + w_gate = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ) + w_up = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.up_proj + ) + w_down = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj ) + intermediate_size = mlx_model.args.intermediate_size + num_experts_per_tok = None + norm_topk_prob = False layer = Qwen3TransformerBlock( num_attention_heads=mlx_model.args.num_attention_heads, num_kv_heads=mlx_model.args.num_key_value_heads, hidden_size=mlx_model.args.hidden_size, head_dim=mlx_model.args.head_dim, + intermediate_size=intermediate_size, rms_norm_eps=mlx_model.args.rms_norm_eps, wq=wq, wk=wk, @@ -261,11 +278,16 @@ def __init__( wo=wo, q_norm=mlx_model.model.layers[i].self_attn.q_norm.weight, k_norm=mlx_model.model.layers[i].self_attn.k_norm.weight, + w_gate=w_gate, + w_up=w_up, + w_down=w_down, w_input_layernorm=mlx_model.model.layers[i].input_layernorm.weight, w_post_attention_layernorm=mlx_model.model.layers[ i ].post_attention_layernorm.weight, - mlp=mlp, + w_router=w_router, + num_experts_per_tok=num_experts_per_tok, + norm_topk_prob=norm_topk_prob, max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, ) From 9d3e0858393b676382133f1fee395567b77dcd6b Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Tue, 9 Jun 2026 00:21:28 -0700 Subject: [PATCH 4/5] Revert "Move Qwen3 MoE MLP selection into block" This reverts commit e33c18faa619fc3052599fda79845e1d009b9352. Signed-off-by: Connor1996 --- src/tiny_llm_ref/qwen3_week3.py | 82 ++++++++++++--------------------- 1 file changed, 30 insertions(+), 52 deletions(-) diff --git a/src/tiny_llm_ref/qwen3_week3.py b/src/tiny_llm_ref/qwen3_week3.py index 7a7dbc70..29fb30e6 100644 --- a/src/tiny_llm_ref/qwen3_week3.py +++ b/src/tiny_llm_ref/qwen3_week3.py @@ -123,7 +123,6 @@ def __init__( num_kv_heads: int, hidden_size: int, head_dim: int, - intermediate_size: int, rms_norm_eps: float, wq: QuantizedWeights, wk: QuantizedWeights, @@ -131,31 +130,15 @@ def __init__( wo: QuantizedWeights, q_norm: mx.array, k_norm: mx.array, - w_gate: QuantizedWeights, - w_up: QuantizedWeights, - w_down: QuantizedWeights, w_input_layernorm: mx.array, w_post_attention_layernorm: mx.array, - w_router: QuantizedWeights | None = None, - num_experts_per_tok: int | None = None, - norm_topk_prob: bool = False, + mlp: Qwen3MLP | Moe, max_seq_len: int = 32768, theta: int = 1000000, ): self.num_attention_heads = num_attention_heads self.hidden_size = hidden_size - if w_router is None: - self.mlp = Qwen3MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) - else: - assert num_experts_per_tok is not None - self.mlp = Moe( - w_router=w_router, - w_gate=w_gate, - w_up=w_up, - w_down=w_down, - num_experts_per_tok=num_experts_per_tok, - norm_topk_prob=norm_topk_prob, - ) + self.mlp = mlp self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm( hidden_size, w_post_attention_layernorm, eps=rms_norm_eps @@ -235,42 +218,42 @@ def __init__( mlx_model.model.layers[i].self_attn.o_proj ) if is_qwen3_moe_sparse_layer(mlx_model.args, i): - w_router = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate + mlp = Moe( + w_router=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate + ), + w_gate=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.gate_proj + ), + w_up=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.up_proj + ), + w_down=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.down_proj + ), + num_experts_per_tok=mlx_model.args.num_experts_per_tok, + norm_topk_prob=mlx_model.args.norm_topk_prob, ) - w_gate = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.gate_proj - ) - w_up = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.up_proj - ) - w_down = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.switch_mlp.down_proj - ) - intermediate_size = mlx_model.args.moe_intermediate_size - num_experts_per_tok = mlx_model.args.num_experts_per_tok - norm_topk_prob = mlx_model.args.norm_topk_prob else: - w_router = None - w_gate = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate_proj - ) - w_up = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.up_proj - ) - w_down = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.down_proj + mlp = Qwen3MLP( + mlx_model.args.hidden_size, + mlx_model.args.intermediate_size, + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.up_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj + ), ) - intermediate_size = mlx_model.args.intermediate_size - num_experts_per_tok = None - norm_topk_prob = False layer = Qwen3TransformerBlock( num_attention_heads=mlx_model.args.num_attention_heads, num_kv_heads=mlx_model.args.num_key_value_heads, hidden_size=mlx_model.args.hidden_size, head_dim=mlx_model.args.head_dim, - intermediate_size=intermediate_size, rms_norm_eps=mlx_model.args.rms_norm_eps, wq=wq, wk=wk, @@ -278,16 +261,11 @@ def __init__( wo=wo, q_norm=mlx_model.model.layers[i].self_attn.q_norm.weight, k_norm=mlx_model.model.layers[i].self_attn.k_norm.weight, - w_gate=w_gate, - w_up=w_up, - w_down=w_down, w_input_layernorm=mlx_model.model.layers[i].input_layernorm.weight, w_post_attention_layernorm=mlx_model.model.layers[ i ].post_attention_layernorm.weight, - w_router=w_router, - num_experts_per_tok=num_experts_per_tok, - norm_topk_prob=norm_topk_prob, + mlp=mlp, max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, ) From dd42ca80a99318a3af70cd698e5b8d66d3d4b86d Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Tue, 9 Jun 2026 00:28:36 -0700 Subject: [PATCH 5/5] Restore direct embedding dequantize lookup Signed-off-by: Connor1996 --- book/src/week2-02-quantized-matmul.md | 3 +-- src/tiny_llm/quantize.py | 15 ++++----------- src/tiny_llm_ref/embedding.py | 11 +++++++++-- src/tiny_llm_ref/quantize.py | 15 ++++----------- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/book/src/week2-02-quantized-matmul.md b/book/src/week2-02-quantized-matmul.md index 7c0c16fc..0bcc22aa 100644 --- a/book/src/week2-02-quantized-matmul.md +++ b/book/src/week2-02-quantized-matmul.md @@ -243,8 +243,7 @@ The token embedding table should also stay quantized in Week 2. Add a `QuantizedEmbedding` wrapper with two call patterns: - `embedding(input_ids)` is a row lookup. Gather the matching packed rows, - scales, and biases with the provided `dequantize_linear` helper so only those - selected rows are dequantized. + scales, and biases, then call `mx.dequantize` on only those selected rows. - `embedding.as_linear(h)` is the tied output projection. Implement this with `quantized_linear(h, embedding_weight)` so it uses your quantized matmul path instead of materializing the full `vocab_size x hidden_size` table. This path diff --git a/src/tiny_llm/quantize.py b/src/tiny_llm/quantize.py index b24ecb5c..09ef1932 100644 --- a/src/tiny_llm/quantize.py +++ b/src/tiny_llm/quantize.py @@ -2,18 +2,11 @@ from typing import Any -def dequantize_linear(mx_layer: Any, indices: mx.array | None = None) -> mx.array: - weight = mx_layer.weight - scales = mx_layer.scales - biases = mx_layer.biases - if indices is not None: - weight = weight[indices] - scales = scales[indices] - biases = biases[indices] if biases is not None else None +def dequantize_linear(mx_layer: Any) -> mx.array: w = mx.dequantize( - weight, - scales, - biases, + mx_layer.weight, + mx_layer.scales, + mx_layer.biases, mx_layer.group_size, mx_layer.bits, ) diff --git a/src/tiny_llm_ref/embedding.py b/src/tiny_llm_ref/embedding.py index 305d187e..c8be2b65 100644 --- a/src/tiny_llm_ref/embedding.py +++ b/src/tiny_llm_ref/embedding.py @@ -1,6 +1,6 @@ import mlx.core as mx from .basics import linear -from .quantize import QuantizedWeights, dequantize_linear, quantized_linear +from .quantize import QuantizedWeights, quantized_linear class Embedding: @@ -33,7 +33,14 @@ def __init__( self.weight = weight def __call__(self, x: mx.array) -> mx.array: - return dequantize_linear(self.weight, x) + biases = self.weight.biases[x] if self.weight.biases is not None else None + return mx.dequantize( + self.weight.weight[x], + self.weight.scales[x], + biases, + self.weight.group_size, + self.weight.bits, + ) def as_linear(self, x: mx.array) -> mx.array: return quantized_linear(x, self.weight) diff --git a/src/tiny_llm_ref/quantize.py b/src/tiny_llm_ref/quantize.py index d1ac7c34..d98b3a22 100644 --- a/src/tiny_llm_ref/quantize.py +++ b/src/tiny_llm_ref/quantize.py @@ -47,18 +47,11 @@ def quantized_linear( ) -def dequantize_linear(mx_layer: Any, indices: mx.array | None = None) -> mx.array: - weight = mx_layer.weight - scales = mx_layer.scales - biases = mx_layer.biases - if indices is not None: - weight = weight[indices] - scales = scales[indices] - biases = biases[indices] if biases is not None else None +def dequantize_linear(mx_layer: Any) -> mx.array: w = mx.dequantize( - weight, - scales, - biases, + mx_layer.weight, + mx_layer.scales, + mx_layer.biases, mx_layer.group_size, mx_layer.bits, )