Add DiffusionGemma (text): block-diffusion model + diffusion_generate #1402
Draft
gilbert-barajas wants to merge 6 commits into
First Apple-repo contribution (ml-explore/mlx-lm ml-explore#1391, lane confirmed open 2026-06-13). Text-only port of DiffusionGemma, mirroring mlx-vlm's working DiffusionGemma4Backbone (credit @Blaizzy) 1:1 so wire weights load directly.
- ModelArgs (mirrors HF DiffusionGemmaTextConfig) + the 5-step build map.
- Slice 1 building blocks: MLP, Router (scale + per_expert_scale), Experts (SwitchLinear), Attention (v_norm, k==v on full layers, global wide heads, per-layer-type RoPE, decoder cache-concat), SelfConditioning, DecoderLayer (4-norm summed MLP+MoE FF × layer_scalar).
- Smoke-green (sliding+full forward).
Next: encoder/decoder models, the encoder-decoder cache, Model+sanitize, diffusion_generate sampler. Logits parity vs transformers pending (torch env).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Slices 2-3: DecoderModel (canvas denoiser + self-conditioning + per-layer-type decoder masks), text-only EncoderModel (prompt prefill via the decoder's tied layers in encoder mode + encoder layer-scalars), DiffusionGemma4Backbone, and the top-level Model (mlx-lm conventions: __call__ -> softcapped logits, make_cache, layers, sanitize, quant_predicate).
KEY RESULT: the encoder->decoder cache-concat works with STANDARD mlx-lm KVCache/RotatingKVCache (StaticPrefixKVCache only needed for the static fast-path); Fable's 'encoder-decoder breaks mlx-lm make_prompt_cache' risk resolved.
Smoke-green: prompt -> prefill -> canvas denoise -> logits, softcap respected, all finite.
Remaining: diffusion_generate sampler (slice 5), logits parity vs transformers (torch env), and the encoder-bidirectional-'all' parity item (flagged in code).
Mirrors mlx-vlm (credit @Blaizzy).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
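For readers unfamiliar with the "softcapped logits" mentioned in this commit, here is a minimal sketch of tanh soft-capping as used in the Gemma family. The cap value of 30.0 and the helper name `softcap` are assumptions for illustration; the value this port actually reads from the model config is not stated here.

```python
import math


def softcap(logit: float, cap: float) -> float:
    """Smoothly squash a logit into [-cap, cap] using tanh."""
    return cap * math.tanh(logit / cap)


CAP = 30.0  # hypothetical value, not read from the real model config
raw_logits = [-1000.0, -1.0, 0.0, 1.0, 1000.0]
capped = [softcap(x, CAP) for x in raw_logits]

# Small logits pass through almost unchanged; large ones saturate at +/-CAP,
# so every output stays finite.
print(capped)
```

This is the property the smoke test above checks informally ("softcap respected, all finite").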
…to-end
Slice 5: the block-diffusion sampler. Prefill prompt -> denoise a random canvas over <=max_denoising_steps: decode (w/ self-conditioning) -> reverse-schedule temperature (0.8->0.4) -> EntropyBound accept (the k lowest-entropy tokens whose cumulative entropy minus the max is <= bound) -> renoise the rest -> adaptive stop (argmax-canvas stable AND mean entropy < 0.005). Mirrors the HF generation semantics (credit @Blaizzy / the transformers reference).
Caught + fixed a schedule-direction bug: HF's cur_step counts DOWN (reverse diffusion), ours counts up, so the naive port ran temperature 0.4->0.8 backwards; inverted to the correct 0.8->0.4.
The full DiffusionGemma text model + sampler now runs end-to-end in mlx-lm (smoke-green: forward + generate + sanitize). 785 lines.
Remaining (slice 6): logits parity vs transformers (torch), encoder-bidi-'all', real-weights load, mlx-lm registry/CLI wiring, and the block-autoregressive outer loop (5b).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
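The schedule-direction bug and the adaptive stop described in this commit can be illustrated with a small pure-Python sketch. The linear interpolation between 0.8 and 0.4, the exact range of the down-counting index, and all function names are assumptions made for illustration; only the endpoints (0.8 -> 0.4) and the stop rule (stable argmax canvas AND mean entropy < 0.005) come from the commit message.

```python
def temperature_at(step, num_steps, t_start=0.8, t_end=0.4):
    """Temperature for an UP-counting step index in [0, num_steps - 1].

    Linear interpolation is an assumption; the commit only gives the endpoints.
    """
    if num_steps <= 1:
        return t_start
    frac = step / (num_steps - 1)
    return t_start + (t_end - t_start) * frac


def temperature_from_countdown(cur_step, num_steps):
    """Same schedule for a DOWN-counting index (num_steps - 1 ... 0).

    Feeding a down-counting index straight into temperature_at would run
    the schedule backwards (0.4 -> 0.8), which is the bug described above.
    """
    return temperature_at(num_steps - 1 - cur_step, num_steps)


def should_stop(prev_argmax, cur_argmax, entropies, threshold=0.005):
    """Stop when the argmax canvas is unchanged AND mean entropy is tiny."""
    stable = prev_argmax is not None and list(prev_argmax) == list(cur_argmax)
    mean_entropy = sum(entropies) / len(entropies)
    return stable and mean_entropy < threshold


schedule = [round(temperature_at(s, 5), 2) for s in range(5)]
print(schedule)
```

Keeping the index conversion in one helper makes the direction explicit at the call site, which is where the naive port went wrong.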
…cation
A 12-agent adversarial review (5 slices reviewed vs HF + mlx-vlm, each finding verified) surfaced 6 confirmed divergences; fixed the real ones:
- HIGH: diffusion_generate committed the RENOISED canvas (random tokens at unconverged positions) instead of HF argmax(logits). Now returns the per-position argmax canvas.
- HIGH: self-conditioning carried RAW logits; HF carries the temperature-SCALED (processed) logits (softmax differs under temperature), perturbing every step. Now carries the scaled logits.
- MEDIUM: the encoder silently dropped the padding attention_mask (PAD tokens attended + cached). Threaded attention_mask through Model->backbone->encoder, with an explicit causal+sliding mask AND key-mask when one is supplied (mirrors mlx-vlm); fast no-mask path kept.
- LOW: per-step renoise reused a fixed key (degenerate on the seeded path); now splits the key each step.
Deferred LOW (documented in-file): sliding-window decode for >512 prompts (inherited from mlx-vlm), no per-row B>1 freeze, the all sliding_window halve (verify vs real config).
MoE parity still exact; full forward+pad-mask+generate smoke-green.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
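The second HIGH finding rests on the fact that a softmax over temperature-scaled logits is a different distribution from a softmax over the raw logits. A minimal pure-Python sketch of that difference follows; the `softmax` helper and the example logits are illustrative and are not taken from the port.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
raw_probs = softmax(logits)                       # what the buggy path implied
scaled_probs = softmax(logits, temperature=0.5)   # what the reference carries

# A temperature below 1 sharpens the distribution, so the top token gets
# more mass than it would from the raw logits.
print(raw_probs[0], scaled_probs[0])
```

Because the carried logits feed the next denoising step, this mismatch would compound across steps rather than stay local to one.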
Adds from_dict that flattens the conversion config (text fields are nested under text_config; canvas_length is top-level) so mlx-lm's loader builds the right text-only args. With this + the existing sanitize, the real mlx-community diffusiongemma-26B-A4B-it conversions load with ZERO key/shape mismatches.
PROVEN ON THE REAL MODEL: mlx-community/diffusiongemma-26B-A4B-it-4bit loads (hidden=2816, 30 layers, 128 experts) and generates coherent text via diffusion_generate: 'Why is the sky blue?' -> a correct one-sentence answer, 256-token canvas in ~1.0s = ~250 tok/s on an M5 Max.
First DiffusionGemma text port running the real 26B on Apple Silicon.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
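The flattening that from_dict performs can be sketched roughly as below. This is not the port's code: the function name `flatten_config`, the HF-style key names (`hidden_size`, `num_hidden_layers`, `num_experts`), the `model_type` value, and the rule that top-level keys win on a name clash are all assumptions for illustration. The numeric values are the ones reported in this commit.

```python
def flatten_config(config: dict) -> dict:
    """Lift fields nested under "text_config" to the top level.

    Top-level keys such as canvas_length are kept as they are.
    """
    flat = {k: v for k, v in config.items() if k != "text_config"}
    # Assumption: on a name clash the top-level value takes precedence.
    for key, value in config.get("text_config", {}).items():
        flat.setdefault(key, value)
    return flat


conversion_config = {
    "model_type": "diffusiongemma",  # hypothetical value
    "canvas_length": 256,
    "text_config": {
        "hidden_size": 2816,
        "num_hidden_layers": 30,
        "num_experts": 128,
    },
}

flat = flatten_config(conversion_config)
print(flat)
```

A flat dict like this can then be passed to a dataclass-style args constructor that only knows about text fields.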
…_generate
Model exposes a diffusion_generate(prompt_ids, **kw) method; mlx_lm.generate's
main() dispatches to it when present (generic hasattr check — no model-specific
import in generate.py), instead of the autoregressive token loop. A model that
denoises a canvas opts in by exposing the method.
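A minimal sketch of the opt-in dispatch described above, in plain Python. Only the method name `diffusion_generate` and the `hasattr` check come from this commit; the classes, `autoregressive_generate`, and `run_generation` are hypothetical stand-ins and do not reflect the real structure of generate.py.

```python
class AutoregressiveModel:
    """Stand-in for an ordinary model with no diffusion_generate method."""


class CanvasModel:
    """Stand-in for a model that opts in by exposing diffusion_generate."""

    def diffusion_generate(self, prompt_ids, **kw):
        # Placeholder: pretend the denoised canvas is all zeros.
        return [0] * kw.get("canvas_length", 4)


def autoregressive_generate(model, prompt_ids, **kw):
    # Placeholder for the usual token-by-token loop.
    return [1]


def run_generation(model, prompt_ids, **kw):
    """Dispatch on capability, not on model type."""
    if hasattr(model, "diffusion_generate"):
        return model.diffusion_generate(prompt_ids, **kw)
    return autoregressive_generate(model, prompt_ids, **kw)


print(run_generation(CanvasModel(), [5, 6], canvas_length=3))
print(run_generation(AutoregressiveModel(), [5, 6]))
```

The capability check keeps the caller free of model-specific imports, at the cost of an implicit contract on the method name.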
Verified end-to-end via the real conversion through mlx-lm's standard load()
(from_dict + sanitize + quantization all integrate):
    mlx_lm.generate --model mlx-community/diffusiongemma-26B-A4B-it-4bit \
        --prompt 'Why is the sky blue?' --verbose true
-> the model's thinking channel + a correct one-sentence answer, ~87 tok/s.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
DiffusionGemma (text) for mlx-lm
A text-only port of DiffusionGemma (google/diffusiongemma-26B-A4B-it), Google's encoder-decoder block-diffusion MoE built on the Gemma 4 stack. Addresses the request in #1391 (text serving in mlx-lm; mlx-vlm already covers the multimodal path).
Instead of emitting tokens autoregressively, the decoder denoises a fixed 256-token canvas over reverse-diffusion steps (entropy-bound acceptance + renoise, temperature schedule, adaptive stop), with the prompt prefilled into an encoder KV cache.
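To make the entropy-bound acceptance step concrete, here is a small pure-Python sketch based on the rule as summarized in the commit history (accept the k lowest-entropy positions whose cumulative entropy minus the largest entropy in the set is within a bound). The function names, the example entropies, and the bound of 0.2 are assumptions for illustration, not values from the port or the reference implementation.

```python
import math


def token_entropy(probs):
    """Shannon entropy (in nats) of one categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def entropy_bound_accept(entropies, bound):
    """Return the canvas positions to accept this step.

    Positions are visited from lowest to highest entropy; a position is
    accepted while (sum of accepted entropies) - (largest accepted entropy)
    stays <= bound. The rest would be renoised.
    """
    order = sorted(range(len(entropies)), key=lambda i: entropies[i])
    accepted = []
    cumulative = 0.0
    for i in order:
        cumulative += entropies[i]
        # Sorted ascending, so the current entropy is the max of the prefix.
        if cumulative - entropies[i] > bound:
            break
        accepted.append(i)
    return sorted(accepted)


entropies = [0.02, 1.50, 0.10, 0.40, 0.01]  # one value per canvas position
accepted = entropy_bound_accept(entropies, bound=0.2)
print(accepted)
```

With any non-negative bound the single most confident position is always accepted, so each step makes progress.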
Credit
This mirrors @Blaizzy's working mlx-vlm DiffusionGemma4Backbone 1:1 (so the existing mlx-community/diffusiongemma-* conversions load directly, no reconvert) and uses the HF transformers reference for the sampler semantics. Grateful for both as parity targets.
What works
- … diffusion_generate (a sibling to the AR loop, since the canvas-denoising contract doesn't fit generate_step).
- … mlx-community/diffusiongemma-26B-A4B-it conversions (4-bit and bf16) with zero key/shape mismatches: sanitize drops the vision tower + handles the MoE weight names; from_dict flattens the nested text_config.
- mlx_lm.generate dispatches to diffusion_generate when the model exposes it (generic hasattr check; no model-specific import in generate.py).
Verification
- … transformers: Router top-k indices exact, Experts max abs diff ~1e-9 (machine precision).
- … bugs (committing the argmax canvas vs the renoised one; carrying temperature-scaled vs raw logits for self-conditioning).
Receipts (Apple Silicon, M5 Max)
Running the real 26B-A4B through this port (256-token canvas, cold):
Coherent on prose and code (an is_prime with the 6k±1 rule).
Status / remaining (hence draft)
- … canvas_length); the >256-token block-autoregressive outer loop is a follow-up.
- … diffusion_generate the seam you'd prefer, or dispatch inside generate? Appetite for mlx_lm.server support here vs a follow-up? Any preferred encoder-decoder cache convention?
Happy to iterate on the shape; opening this as a draft to move the conversation from the issue to actual code. Closes #1391 when complete.