diff --git a/batch-main.py b/batch-main.py index 2379c7f..62be461 100644 --- a/batch-main.py +++ b/batch-main.py @@ -35,9 +35,11 @@ random.shuffle(prompts) parser.add_argument("--solution", type=str, default="tiny_llm") +parser.add_argument("--loader", type=str, choices=["week2", "week3"], default="week2") parser.add_argument("--device", type=str, default="gpu") parser.add_argument("--batch-size", type=int, default=5) parser.add_argument("--prefill-step", type=int, default=128) +parser.add_argument("--max-seq-len", type=int, default=512) parser.add_argument("--enable-flash-attn", action="store_true") parser.add_argument("--enable-thinking", action="store_true") args = parser.parse_args() @@ -57,11 +59,20 @@ mlx_model, tokenizer = load(args.model) with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): + dispatch_kwargs = {} + if args.loader == "week2": + dispatch_kwargs["enable_flash_attn"] = args.enable_flash_attn + elif args.enable_flash_attn: + print("--enable-flash-attn is only used by the week2 loader; ignoring it") + print( - f"Using week2 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}" + f"Using {args.loader} loader with thinking={args.enable_thinking} for {args.model}" ) tiny_llm_model = models.dispatch_model( - args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + args.model, + mlx_model, + week=int(args.loader.removeprefix("week")), + **dispatch_kwargs, ) encoded_prompts = [] for idx, prompt in enumerate(prompts): @@ -81,6 +92,7 @@ tiny_llm_model, tokenizer, encoded_prompts, + max_seq_len=args.max_seq_len, batch_size=args.batch_size, prefill_step=args.prefill_step, ) diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index fb0ee4d..590dbc9 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -21,6 +21,7 @@ - [Week 3: Serving]() - [Paged Attention, Part 1](./week3-01-paged-attention-part1.md) - [Paged Attention, Part 2](./week3-02-paged-attention-part2.md) + - [Mixture of Experts](./week3-03-moe.md) --- diff --git a/book/src/week3-03-moe.md b/book/src/week3-03-moe.md new file mode 100644 index 0000000..4727caf --- /dev/null +++ b/book/src/week3-03-moe.md @@ -0,0 +1,313 @@ +# Week 3 Day 3: Mixture of Experts + +In this chapter, we will implement the feed-forward shape of **Mixture of +Experts**, or **MoE**, for the Qwen3 family. + +So far, every transformer block in tiny-llm has used the same dense Qwen3 MLP: + +```plain +x -> gate_proj +x -> up_proj +SiLU(gate_proj(x)) * up_proj(x) -> down_proj +``` + +That is a SwiGLU MLP. Every token visits the same weights. + +MoE changes only the feed-forward half of the transformer block. Instead of one +dense MLP, the model owns many expert MLPs. A small router chooses which experts +each token should use: + +```plain +token hidden state -> router -> top-k experts -> weighted expert outputs +``` + +The attention path does not change. KV cache does not change. The sparse work is +inside the MLP half of the block. + +**Readings** + +- [Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer](https://arxiv.org/abs/1701.06538) +- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668) +- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) + +## Dense MLP vs MoE MLP + +The dense Qwen3 MLP from Week 1 has one set of weights: + +```plain +w_gate: hidden_dim, dim +w_up: hidden_dim, dim +w_down: dim, hidden_dim +``` + +A Qwen3-MoE sparse block has a bank of those weights: + +```plain +expert_gate: num_experts, moe_hidden_dim, dim +expert_up: num_experts, moe_hidden_dim, dim +expert_down: num_experts, dim, moe_hidden_dim +``` + +The router produces one score per expert: + +```plain +router_logits: B, L, num_experts +router_probs: softmax(router_logits) +``` + +Then the model picks `num_experts_per_tok` experts for each token: + +```plain +expert_ids: B, L, num_experts_per_tok +expert_scores: B, L, num_experts_per_tok +``` + +For each token, only those selected experts run. Their outputs are weighted and +summed: + +```plain +output[token] = sum(score_i * expert_i(token)) +``` + +That is the central MoE idea: the model can contain many parameters, but each +token activates only a small subset of them. + +## Qwen3-MoE Shape + +Qwen3-MoE keeps the same attention structure as Qwen3, including QK norm, GQA, +RoPE, and the same KV cache interface. It replaces some dense MLP layers with a +sparse MoE block. + +The useful pieces are: + +- `gate`: a router linear layer from hidden size to `num_experts` +- `switch_mlp`: many SwiGLU experts with `moe_intermediate_size` +- `num_experts_per_tok`: how many experts a token uses +- `norm_topk_prob`: whether selected expert scores are renormalized +- `decoder_sparse_step` and `mlp_only_layers`: which layers are sparse vs dense + +There is no shared expert in the Qwen3-MoE block we are following. The sparse +feed-forward output is just the weighted top-k expert mixture. + +## Grouped Quantized Matmul + +MLX does not give us a single high-level MoE block in `mlx.nn`. It does have a +lower-level primitive, `mx.gather_qmm`, that performs quantized matrix +multiplication while selecting a different matrix for each row. In this chapter, +we will build a narrow teaching version of that idea: +`grouped_quantized_matmul`. + +For MoE, that means: + +```plain +token rows: N, D +expert ids: N +weights: E, O, D packed as 4-bit QuantizedWeights +output: N, O +``` + +The row with `expert_ids[i] = e` should multiply by `weights[e]`. + +Task 1 will assume the rows are already sorted by expert id. The MoE helper will +keep the inverse order from the sort so the result can be restored to the +original token order. + +## Router Step + +The router is just a quantized linear layer: + +```python +router_logits = quantized_linear(x, w_router) +router_probs = softmax(router_logits, axis=-1) +``` + +For a batch of tokens: + +```plain +x: B, L, D +router_logits: B, L, E +router_probs: B, L, E +``` + +where `E = num_experts`. + +Qwen3-MoE then uses top-k selection: + +```python +expert_ids = argpartition(-router_probs, k)[:k] +expert_scores = take_along_axis(router_probs, expert_ids) +``` + +If `norm_topk_prob` is true, renormalize `expert_scores` so the selected scores +sum to 1 for each token. + +## Expert Step + +Each expert is the same kind of SwiGLU MLP we already know: + +```plain +expert(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)) +``` + +The implementation should build token-expert jobs, group them by expert, and run +the expert projections with `grouped_quantized_matmul`: + +```plain +selected expert ids -> expanded token-expert rows +expanded rows -> sort/group by expert id +grouped expert rows -> grouped gate/up projection +SiLU(gate) * up -> grouped down projection +restore original token/top-k order -> weighted sum +``` + +The reorder is part of the model implementation. It keeps all token rows for the +same expert contiguous so the expert bank can be applied with grouped matrix +multiplication. + +## Task 1: Grouped Quantized Matmul + +``` +src/extensions/src/quantized_matmul.cpp +src/extensions/src/quantized_matmul.metal +src/tiny_llm/quantize.py +src/tiny_llm/moe.py +``` + +Implement `grouped_quantized_matmul`, then use it from `grouped_expert_linear`. +This is the quantized grouped-matmul core of MoE. + +`grouped_quantized_matmul` accepts: + +```plain +a: R, D +w_experts: packed QuantizedWeights for num_experts, output_dim, D +expert_ids: R, sorted by expert id +``` + +It returns: + +```plain +out: R, output_dim +``` + +Each row uses the expert selected by the matching row in `expert_ids`: + +```plain +out[row] = a[row] @ dequantize(w_experts[expert_ids[row]]).T +``` + +The implementation should: + +```plain +1. add a Python wrapper for grouped_quantized_matmul, +2. extend the quantized matmul extension with a grouped entrypoint, +3. read expert_ids[row] inside the kernel, +4. use that expert id to choose the expert weight, scale, and bias row. +``` + +After that, implement `grouped_expert_linear` in `src/tiny_llm/moe.py`: + +```plain +1. flatten token rows and expert ids, +2. sort rows by expert id, +3. call grouped_quantized_matmul, +4. restore the original order. +``` + +The call should look like: + +```python +out = grouped_quantized_matmul( + w_experts.scales, + w_experts.biases, + group_size=w_experts.group_size, + bits=w_experts.bits, + a=grouped_rows, + b=w_experts.weight, + expert_ids=grouped_expert_ids, + transpose_b=True, +) +``` + +This task maps to the same idea as `QuantizedSwitchLinear` in `mlx-lm`: each +token row uses a different packed expert matrix, and the expert ids choose the +right matrix. + +## Task 2: Router Top-k + +``` +src/tiny_llm/moe.py +``` + +Implement `route_topk`. It accepts hidden states and router weights, then +returns: + +- router probabilities +- selected expert ids +- selected expert scores + +Use `quantized_linear` and `softmax`. Use `mx.argpartition` to select the top +`num_experts_per_tok` experts, then `mx.take_along_axis` to gather their scores. + +Keep `norm_topk_prob` as an argument because Qwen3-MoE stores this behavior in +the model config. + +## Task 3: Qwen3 Sparse MoE Block + +``` +src/tiny_llm/moe.py +``` + +Implement `Moe` by composing Task 1 and Task 2: + +```plain +hidden states -> route_topk +hidden states + expert ids -> grouped gate projection +hidden states + expert ids -> grouped up projection +SiLU(gate) * up -> grouped down projection +weighted sum over num_experts_per_tok +``` + +This completes the Qwen3-MoE sparse feed-forward block. There is no shared expert +branch in this block. + +## Task 4: Integrate Qwen3-MoE Layers + +``` +src/tiny_llm/qwen3_week3.py +src/tiny_llm/models.py +``` + +Add a Qwen3-MoE loader path that reuses the Week 3 Qwen3 attention and paged KV +cache behavior, but swaps selected block MLPs for `Moe`. + +The model wrapper should: + +- keep Qwen3 attention unchanged, +- use regular `Qwen3MLP` for `mlp_only_layers`, +- use `Moe` for sparse layers selected by + `decoder_sparse_step`, +- load router and expert weights as `QuantizedWeights` from the Qwen3-MoE MLX + model, +- preserve the same decode call shape: + +```python +logits = model(tokens, offset, cache) +``` + +No scheduler API change in `src/tiny_llm/batch.py` is required for correctness. + +Run this task through the normal generation entrypoints instead of adding a +separate unit test. For example: + +```bash +hf download Qwen/Qwen3-30B-A3B-MLX-4bit + +pdm run main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ + --prompt "Give me a short introduction to mixture of experts." + +pdm run batch-main --solution tiny_llm --loader week3 --model qwen3-30b-a3b \ + --batch-size 2 --prefill-step 16 +``` + +{{#include copyright.md}} diff --git a/pyproject.toml b/pyproject.toml index f02f317..64079ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ clean-ext-ref.working_dir = "src/extensions_ref" main.cmd = "python main.py" main-week1.cmd = "python main.py --loader week1" main-week2.cmd = "python main.py --loader week2" +main-week3.cmd = "python main.py --loader week3" bench.cmd = "python bench.py" batch-main.cmd = "python batch-main.py" test.cmd = "python scripts/dev-tools.py test" diff --git a/src/tiny_llm/__init__.py b/src/tiny_llm/__init__.py index 17790ff..0fae713 100644 --- a/src/tiny_llm/__init__.py +++ b/src/tiny_llm/__init__.py @@ -15,3 +15,4 @@ from .paged_kv_cache import * from .batch import * from .models import * +from .moe import * diff --git a/src/tiny_llm/models.py b/src/tiny_llm/models.py index a9acc1b..a8568ea 100644 --- a/src/tiny_llm/models.py +++ b/src/tiny_llm/models.py @@ -13,6 +13,8 @@ def shortcut_name_to_full_name(shortcut_name: str): return "Qwen/Qwen3-1.7B-MLX-4bit" elif lower_shortcut_name == "qwen3-4b": return "Qwen/Qwen3-4B-MLX-4bit" + elif lower_shortcut_name in ("qwen3-30b-a3b", "qwen3-moe-30b-a3b"): + return "Qwen/Qwen3-30B-A3B-MLX-4bit" else: return shortcut_name diff --git a/src/tiny_llm/moe.py b/src/tiny_llm/moe.py new file mode 100644 index 0000000..7b5019b --- /dev/null +++ b/src/tiny_llm/moe.py @@ -0,0 +1,36 @@ +import mlx.core as mx + +from .quantize import QuantizedWeights + + +def grouped_expert_linear( + x: mx.array, + w_experts: QuantizedWeights, + expert_ids: mx.array, +) -> mx.array: + pass + + +def route_topk( + x: mx.array, + w_router: QuantizedWeights, + top_k: int, + norm_topk_prob: bool = False, +) -> tuple[mx.array, mx.array, mx.array]: + pass + + +class Moe: + def __init__( + self, + w_router: QuantizedWeights, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + num_experts_per_tok: int, + norm_topk_prob: bool = False, + ): + pass + + def __call__(self, x: mx.array) -> mx.array: + pass diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index 17790ff..0fae713 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -15,3 +15,4 @@ from .paged_kv_cache import * from .batch import * from .models import * +from .moe import * diff --git a/src/tiny_llm_ref/embedding.py b/src/tiny_llm_ref/embedding.py index 0ecf842..c8be2b6 100644 --- a/src/tiny_llm_ref/embedding.py +++ b/src/tiny_llm_ref/embedding.py @@ -4,7 +4,12 @@ class Embedding: - def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + weight: mx.array, + ): self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.weight = weight @@ -17,16 +22,22 @@ def as_linear(self, x: mx.array) -> mx.array: class QuantizedEmbedding: - def __init__(self, vocab_size: int, embedding_dim: int, weight: QuantizedWeights): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + weight: QuantizedWeights, + ): self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.weight = weight def __call__(self, x: mx.array) -> mx.array: + biases = self.weight.biases[x] if self.weight.biases is not None else None return mx.dequantize( self.weight.weight[x], self.weight.scales[x], - self.weight.biases[x], + biases, self.weight.group_size, self.weight.bits, ) diff --git a/src/tiny_llm_ref/moe.py b/src/tiny_llm_ref/moe.py new file mode 100644 index 0000000..9e6d357 --- /dev/null +++ b/src/tiny_llm_ref/moe.py @@ -0,0 +1,89 @@ +import mlx.core as mx + +from .basics import silu +from .quantize import QuantizedWeights, quantized_linear + + +def grouped_expert_linear( + x: mx.array, + w_experts: QuantizedWeights, + expert_ids: mx.array, +) -> mx.array: + *leading_dims, D = x.shape + flat_x = x.reshape(-1, D) + flat_expert_ids = expert_ids.reshape(-1) + sort_idx = mx.argsort(flat_expert_ids) + inv_sort_idx = mx.argsort(sort_idx) + + grouped_x = flat_x[sort_idx] + grouped_expert_ids = flat_expert_ids[sort_idx] + out = mx.gather_qmm( + mx.expand_dims(grouped_x, -2), + w_experts.weight, + w_experts.scales, + w_experts.biases, + lhs_indices=mx.arange(grouped_x.shape[0]), + rhs_indices=grouped_expert_ids, + transpose=True, + group_size=w_experts.group_size, + bits=w_experts.bits, + sorted_indices=True, + ).squeeze(-2) + out_dim = w_experts.weight.shape[-2] + return out[inv_sort_idx].reshape(*leading_dims, out_dim) + + +def route_topk( + x: mx.array, + w_router: QuantizedWeights, + top_k: int, + norm_topk_prob: bool = False, +) -> tuple[mx.array, mx.array, mx.array]: + router_logits = quantized_linear(x, w_router) + router_probs = mx.softmax(router_logits, axis=-1, precise=True) + expert_ids = mx.argpartition(-router_probs, kth=top_k - 1, axis=-1)[..., :top_k] + expert_scores = mx.take_along_axis(router_probs, expert_ids, axis=-1) + if norm_topk_prob: + expert_scores = expert_scores / mx.sum(expert_scores, axis=-1, keepdims=True) + return router_probs, expert_ids, expert_scores + + +class Moe: + def __init__( + self, + w_router: QuantizedWeights, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + num_experts_per_tok: int, + norm_topk_prob: bool = False, + ): + self.w_router = w_router + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + + def __call__(self, x: mx.array) -> mx.array: + B, L, D = x.shape + _, expert_ids, expert_scores = route_topk( + x, + self.w_router, + top_k=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + ) + expanded_x = mx.broadcast_to( + mx.expand_dims(x, -2), + (B, L, self.num_experts_per_tok, D), + ).reshape(-1, D) + flat_expert_ids = expert_ids.reshape(-1) + + gate = grouped_expert_linear(expanded_x, self.w_gate, flat_expert_ids) + up = grouped_expert_linear(expanded_x, self.w_up, flat_expert_ids) + expert_output = grouped_expert_linear( + silu(gate) * up, + self.w_down, + flat_expert_ids, + ).reshape(B, L, self.num_experts_per_tok, D) + return mx.sum(expert_output * mx.expand_dims(expert_scores, -1), axis=-2) diff --git a/src/tiny_llm_ref/qwen3_week3.py b/src/tiny_llm_ref/qwen3_week3.py index 614c4ae..29fb30e 100644 --- a/src/tiny_llm_ref/qwen3_week3.py +++ b/src/tiny_llm_ref/qwen3_week3.py @@ -5,9 +5,10 @@ from .layer_norm import RMSNorm from .positional_encoding import RoPE from typing import Any -from .embedding import Embedding -from .quantize import dequantize_linear, QuantizedWeights, quantized_linear +from .embedding import QuantizedEmbedding +from .quantize import QuantizedWeights, quantized_linear from .kv_cache import TinyKvCache +from .moe import Moe from .paged_kv_cache import TinyKvPagedCache, TinyKvPagedPool @@ -122,7 +123,6 @@ def __init__( num_kv_heads: int, hidden_size: int, head_dim: int, - intermediate_size: int, rms_norm_eps: float, wq: QuantizedWeights, wk: QuantizedWeights, @@ -130,17 +130,15 @@ def __init__( wo: QuantizedWeights, q_norm: mx.array, k_norm: mx.array, - w_gate: QuantizedWeights, - w_up: QuantizedWeights, - w_down: QuantizedWeights, w_input_layernorm: mx.array, w_post_attention_layernorm: mx.array, + mlp: Qwen3MLP | Moe, max_seq_len: int = 32768, theta: int = 1000000, ): self.num_attention_heads = num_attention_heads self.hidden_size = hidden_size - self.mlp = Qwen3MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + self.mlp = mlp self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm( hidden_size, w_post_attention_layernorm, eps=rms_norm_eps @@ -175,6 +173,14 @@ def __call__( return out +def is_qwen3_moe_sparse_layer(args: Any, layer_idx: int) -> bool: + return ( + getattr(args, "num_experts", 0) > 0 + and layer_idx not in getattr(args, "mlp_only_layers", []) + and (layer_idx + 1) % getattr(args, "decoder_sparse_step", 1) == 0 + ) + + class Qwen3ModelWeek3: def __init__( self, @@ -191,10 +197,10 @@ def __init__( precision = mx.bfloat16 self.precision = precision - self.embedding = Embedding( + self.embedding = QuantizedEmbedding( vocab_size=self.vocab_size, embedding_dim=self.hidden_size, - weight=dequantize_linear(mlx_model.model.embed_tokens), + weight=QuantizedWeights.from_mlx_layer(mlx_model.model.embed_tokens), ) self.layers_inner = [] @@ -211,22 +217,43 @@ def __init__( wo = QuantizedWeights.from_mlx_layer( mlx_model.model.layers[i].self_attn.o_proj ) - w_gate = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.gate_proj - ) - w_up = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.up_proj - ) - w_down = QuantizedWeights.from_mlx_layer( - mlx_model.model.layers[i].mlp.down_proj - ) + if is_qwen3_moe_sparse_layer(mlx_model.args, i): + mlp = Moe( + w_router=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate + ), + w_gate=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.gate_proj + ), + w_up=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.up_proj + ), + w_down=QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.switch_mlp.down_proj + ), + num_experts_per_tok=mlx_model.args.num_experts_per_tok, + norm_topk_prob=mlx_model.args.norm_topk_prob, + ) + else: + mlp = Qwen3MLP( + mlx_model.args.hidden_size, + mlx_model.args.intermediate_size, + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.up_proj + ), + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj + ), + ) layer = Qwen3TransformerBlock( num_attention_heads=mlx_model.args.num_attention_heads, num_kv_heads=mlx_model.args.num_key_value_heads, hidden_size=mlx_model.args.hidden_size, head_dim=mlx_model.args.head_dim, - intermediate_size=mlx_model.args.intermediate_size, rms_norm_eps=mlx_model.args.rms_norm_eps, wq=wq, wk=wk, @@ -234,13 +261,11 @@ def __init__( wo=wo, q_norm=mlx_model.model.layers[i].self_attn.q_norm.weight, k_norm=mlx_model.model.layers[i].self_attn.k_norm.weight, - w_gate=w_gate, - w_up=w_up, - w_down=w_down, w_input_layernorm=mlx_model.model.layers[i].input_layernorm.weight, w_post_attention_layernorm=mlx_model.model.layers[ i ].post_attention_layernorm.weight, + mlp=mlp, max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, ) diff --git a/tests_refsol/test_week_3_day_2.py b/tests_refsol/test_week_3_day_2.py index d86ef2e..efcd2e7 100644 --- a/tests_refsol/test_week_3_day_2.py +++ b/tests_refsol/test_week_3_day_2.py @@ -203,5 +203,9 @@ def test_task_3_incremental_decode_matches_week2_with_paged_attention(): week2_out = week2_out - mx.logsumexp(week2_out, keepdims=True) week3_out = week3_out - mx.logsumexp(week3_out, keepdims=True) assert_allclose( - week3_out, week2_out, precision=mx.bfloat16, rtol=1e-3, atol=1e-3 + week3_out, + week2_out, + precision=mx.bfloat16, + rtol=1e-3, + atol=1e-3, ) diff --git a/tests_refsol/test_week_3_day_3.py b/tests_refsol/test_week_3_day_3.py new file mode 100644 index 0000000..7bf82ab --- /dev/null +++ b/tests_refsol/test_week_3_day_3.py @@ -0,0 +1,131 @@ +from types import SimpleNamespace + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.qwen3_moe import ( + Qwen3MoeSparseMoeBlock as MlxQwen3MoeSparseMoeBlock, +) +from mlx_lm.models.switch_layers import SwitchLinear + +from .tiny_llm_base import ( + Moe, + QuantizedWeights, + grouped_expert_linear, + route_topk, +) +from .utils import assert_allclose + + +def test_task_1_grouped_expert_linear(): + mx.random.seed(1) + scale = 0.25 + x = mx.random.normal(shape=(2, 3, 128), dtype=mx.float16) * scale + w_experts = mx.random.normal(shape=(3, 64, 128), dtype=mx.float16) * scale + expert_ids = mx.array( + [ + [2, 0, 1], + [1, 2, 0], + ], + dtype=mx.uint32, + ) + + ref = SwitchLinear( + input_dims=w_experts.shape[-1], + output_dims=w_experts.shape[-2], + num_experts=w_experts.shape[0], + bias=False, + ) + ref.weight = w_experts + ref = ref.to_quantized(group_size=128, bits=4) + + out = grouped_expert_linear( + x, + QuantizedWeights.from_mlx_layer(ref), + expert_ids, + ) + expected = ref(mx.expand_dims(x, -2), expert_ids).squeeze(-2) + + assert out.shape == (2, 3, 64) + assert_allclose(out, expected, precision=mx.float16) + + +def test_task_2_router_topk(): + mx.random.seed(2) + scale = 0.25 + x = mx.random.normal(shape=(2, 2, 128), dtype=mx.float16) * scale + ref = nn.Linear(128, 4, bias=False) + ref.weight = mx.random.normal(shape=(4, 128), dtype=mx.float16) * scale + ref = ref.to_quantized(group_size=128, bits=4) + + router_probs, expert_ids, expert_scores = route_topk( + x, + QuantizedWeights.from_mlx_layer(ref), + top_k=2, + ) + _, _, normalized_scores = route_topk( + x, + QuantizedWeights.from_mlx_layer(ref), + top_k=2, + norm_topk_prob=True, + ) + + expected_probs = mx.softmax(ref(x), axis=-1, precise=True) + expected_ids = mx.argpartition(-expected_probs, kth=1, axis=-1)[..., :2] + expected_scores = mx.take_along_axis(expected_probs, expected_ids, axis=-1) + expected_normalized_scores = expected_scores / expected_scores.sum( + axis=-1, + keepdims=True, + ) + + assert router_probs.shape == (2, 2, 4) + assert expert_ids.shape == (2, 2, 2) + assert expert_scores.shape == (2, 2, 2) + assert expert_ids.tolist() == expected_ids.tolist() + assert_allclose(router_probs, expected_probs, precision=mx.float16) + assert_allclose(expert_scores, expected_scores, precision=mx.float16) + assert_allclose( + normalized_scores, + expected_normalized_scores, + precision=mx.float16, + ) + + +def test_task_3_moe(): + mx.random.seed(3) + scale = 0.25 + x = mx.random.normal(shape=(2, 3, 128), dtype=mx.float16) * scale + ref = MlxQwen3MoeSparseMoeBlock( + SimpleNamespace( + hidden_size=128, + moe_intermediate_size=128, + num_experts=3, + num_experts_per_tok=2, + norm_topk_prob=True, + ) + ) + ref.gate.weight = mx.random.normal(shape=(3, 128), dtype=mx.float16) * scale + ref.switch_mlp.gate_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + ref.switch_mlp.up_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + ref.switch_mlp.down_proj.weight = ( + mx.random.normal(shape=(3, 128, 128), dtype=mx.float16) * scale + ) + nn.quantize(ref, group_size=128, bits=4) + + moe = Moe( + w_router=QuantizedWeights.from_mlx_layer(ref.gate), + w_gate=QuantizedWeights.from_mlx_layer(ref.switch_mlp.gate_proj), + w_up=QuantizedWeights.from_mlx_layer(ref.switch_mlp.up_proj), + w_down=QuantizedWeights.from_mlx_layer(ref.switch_mlp.down_proj), + num_experts_per_tok=2, + norm_topk_prob=True, + ) + + out = moe(x) + expected = ref(x) + + assert out.shape == x.shape + assert_allclose(out, expected, precision=mx.float16)