Add DiffusionGemma (text): block-diffusion model + diffusion_generate#1402

Draft
gilbert-barajas wants to merge 6 commits into
ml-explore:mainfrom
gilbert-barajas:diffusion-gemma-port

Conversation

@gilbert-barajas

DiffusionGemma (text) for mlx-lm

A text-only port of DiffusionGemma (google/diffusiongemma-26B-A4B-it) — Google's
encoder-decoder block-diffusion MoE built on the Gemma 4 stack. Addresses the request
in #1391 (text serving in mlx-lm; mlx-vlm already covers the multimodal path).

Instead of emitting tokens autoregressively, the decoder denoises a fixed 256-token
canvas over reverse-diffusion steps (entropy-bound acceptance + renoise, temperature
schedule, adaptive stop), with the prompt prefilled into an encoder KV cache.

Credit

This mirrors @Blaizzy's working mlx-vlm DiffusionGemma4Backbone 1:1 (so the existing
mlx-community/diffusiongemma-* conversions load directly, no reconvert) and uses the HF
transformers reference for the sampler semantics. Grateful for both as parity targets.

What works

  • Full encoder-decoder model + a top-level diffusion_generate (a sibling to the AR loop, since
    the canvas-denoising contract doesn't fit generate_step).
  • Loads the existing mlx-community/diffusiongemma-26B-A4B-it conversions (4-bit and bf16)
    with zero key/shape mismatches: sanitize drops the vision tower and handles the MoE
    weight names; from_dict flattens the nested text_config.
  • CLI: mlx_lm.generate dispatches to diffusion_generate when the model exposes it
    (generic hasattr check; no model-specific import in generate.py).
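
A minimal sketch of that dispatch, assuming nothing about the real generate.py beyond the
capability check described above. `run_generation`, `autoregressive_generate`, and the two toy
model classes are hypothetical names used only for illustration:

```python
from typing import List


def autoregressive_generate(model, prompt_ids: List[int]) -> List[int]:
    # Hypothetical stand-in for the existing token-by-token loop.
    return prompt_ids + [0]


def run_generation(model, prompt_ids: List[int], **kwargs) -> List[int]:
    # Generic capability check: no model-specific import is needed here.
    diffusion_fn = getattr(model, "diffusion_generate", None)
    if callable(diffusion_fn):
        return diffusion_fn(prompt_ids, **kwargs)
    return autoregressive_generate(model, prompt_ids)


class ToyDiffusionModel:
    # A canvas-denoising model opts in simply by exposing the method.
    def diffusion_generate(self, prompt_ids, canvas_length=4, **kw):
        return [7] * canvas_length


class ToyARModel:
    # No diffusion_generate attribute, so the autoregressive path is used.
    pass
```

The appeal of this shape is that a new canvas-denoising model needs no change to the
dispatcher; whether the maintainers prefer it over dispatch inside generate is still open.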

Verification

  • MoE numerically parity-checked vs transformers: Router top-k indices exact, Experts
    max abs diff ~1e-9 (machine precision).
  • Adversarial code review against the HF + mlx-vlm reference caught and fixed two sampler
    bugs (committing the renoised canvas instead of the argmax canvas; carrying raw instead
    of temperature-scaled logits for self-conditioning).
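
For reference, the kind of check meant by "parity" here can be sketched in plain Python. This
is an illustration, not the script that was run: `topk_indices`, `max_abs_diff`,
`check_parity`, and the default tolerance are all hypothetical.

```python
from typing import List, Sequence, Tuple


def topk_indices(scores: Sequence[float], k: int) -> List[int]:
    # Indices of the k largest scores, highest first (ties go to the lower index).
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    if len(a) != len(b):
        raise ValueError("outputs must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def check_parity(ref_scores, port_scores, ref_out, port_out,
                 k: int, tol: float = 1e-6) -> Tuple[bool, float]:
    # Router check: the selected expert indices must match exactly.
    indices_exact = topk_indices(ref_scores, k) == topk_indices(port_scores, k)
    # Experts check: outputs may differ only by floating-point noise.
    diff = max_abs_diff(ref_out, port_out)
    return indices_exact and diff <= tol, diff
```

Comparing indices exactly and values within a tolerance keeps the two failure modes apart: a
routing mismatch is a logic bug, while a small value difference is usually arithmetic noise.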

Receipts (Apple Silicon, M5 Max)

Running the real 26B-A4B through this port (256-token canvas, cold):

Build                  Size      Speed (tok/s)
4-bit                  16.5 GB   ~250
bf16 (full precision)  51.6 GB   ~121

mlx_lm.generate --model mlx-community/diffusiongemma-26B-A4B-it-4bit \
    --prompt "Why is the sky blue? Answer in one sentence."
# -> "The sky is blue because Earth's atmosphere scatters shorter wavelengths
#     of light (blue and violet) more effectively than other colors."

Coherent on prose and code (an is_prime with the 6k±1 rule).

Status / remaining (hence draft)

  • Single-canvas generation (≤ canvas_length); the >256-token block-autoregressive outer
    loop is a follow-up.
  • mlx-lm-format unit tests + a logits-parity test to add.
  • Open shape questions for maintainers (from my comment on #1391, still unanswered): is a separate
    diffusion_generate the seam you'd prefer, or dispatch inside generate? Appetite for
    mlx_lm.server support here vs a follow-up? Any preferred encoder-decoder cache convention?

Happy to iterate on the shape — opening this as a draft to move the conversation from the
issue to actual code. Closes #1391 when complete.

gilbert-barajas and others added 6 commits June 13, 2026 18:06
First Apple-repo contribution (ml-explore/mlx-lm ml-explore#1391, lane confirmed open
2026-06-13). Text-only port of DiffusionGemma, mirroring mlx-vlm's working
DiffusionGemma4Backbone (credit @Blaizzy) 1:1 so wire weights load directly.

- ModelArgs (mirrors HF DiffusionGemmaTextConfig) + the 5-step build map.
- Slice 1 building blocks: MLP, Router (scale + per_expert_scale), Experts
  (SwitchLinear), Attention (v_norm, k==v on full layers, global wide heads,
  per-layer-type RoPE, decoder cache-concat), SelfConditioning, DecoderLayer
  (4-norm summed MLP+MoE FF × layer_scalar).
- Smoke-green (sliding+full forward). Next: encoder/decoder models, the
  encoder-decoder cache, Model+sanitize, diffusion_generate sampler. Logits
  parity vs transformers pending (torch env).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Slices 2-3: DecoderModel (canvas denoiser + self-conditioning + per-layer-type
decoder masks), text-only EncoderModel (prompt prefill via the decoder's tied
layers in encoder mode + encoder layer-scalars), DiffusionGemma4Backbone, and the
top-level Model (mlx-lm conventions: __call__→softcapped logits, make_cache,
layers, sanitize, quant_predicate).
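
Sketch of the softcap step, assuming the tanh form used across the Gemma family
(cap * tanh(x / cap)). The function name is hypothetical and the cap value is left to the
caller, since the commit does not state it:

```python
import math
from typing import List, Sequence


def softcap(logits: Sequence[float], cap: float) -> List[float]:
    # Squash logits smoothly into [-cap, cap]; near zero this is close to identity.
    if cap <= 0:
        raise ValueError("cap must be positive")
    return [cap * math.tanh(x / cap) for x in logits]
```

With this form, "softcap respected" reduces to checking that no output logit exceeds the cap
in magnitude.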

KEY RESULT: the encoder→decoder cache-concat works with STANDARD mlx-lm
KVCache/RotatingKVCache (StaticPrefixKVCache only needed for the static fast-path)
— Fable's 'encoder-decoder breaks mlx-lm make_prompt_cache' risk resolved.

Smoke-green: prompt -> prefill -> canvas denoise -> logits, softcap respected,
all finite. Remaining: diffusion_generate sampler (slice 5), logits parity vs
transformers (torch env), and the encoder-bidirectional-'all' parity item (flagged
in code). Mirrors mlx-vlm (credit @Blaizzy).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…to-end

Slice 5: the block-diffusion sampler. Prefill prompt -> denoise a random canvas
over <=max_denoising_steps: decode (w/ self-conditioning) -> reverse-schedule
temperature (0.8->0.4) -> EntropyBound accept (the k lowest-entropy tokens whose
cumulative entropy minus the max is <= bound) -> renoise the rest -> adaptive stop
(argmax-canvas stable AND mean entropy < 0.005). Mirrors the HF generation
semantics (credit @Blaizzy / the transformers reference).
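
The acceptance rule and the stop test can be sketched in plain Python. This is one reading of
the sentence above, not the code in the PR: sort positions by entropy, then keep growing the
accepted prefix while its cumulative entropy minus its largest entropy stays within the
bound. All function names are hypothetical.

```python
import math
from typing import List, Sequence


def entropy(probs: Sequence[float]) -> float:
    # Shannon entropy in nats; zero-probability entries contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def entropy_bound_accept(entropies: Sequence[float], bound: float) -> List[int]:
    # Visit positions from lowest to highest entropy.
    order = sorted(range(len(entropies)), key=lambda i: (entropies[i], i))
    accepted: List[int] = []
    running = 0.0
    for i in order:
        running += entropies[i]
        # Sorted ascending, so entropies[i] is the max of the current prefix.
        if running - entropies[i] > bound:
            break
        accepted.append(i)
    return accepted


def should_stop(prev_argmax: Sequence[int], cur_argmax: Sequence[int],
                entropies: Sequence[float], threshold: float = 0.005) -> bool:
    # Adaptive stop: argmax canvas unchanged AND mean entropy below the threshold.
    stable = list(prev_argmax) == list(cur_argmax)
    mean_entropy = sum(entropies) / len(entropies)
    return stable and mean_entropy < threshold
```

Under this reading at least one position is accepted every step, because a single-token
prefix has cumulative entropy equal to its own maximum. Positions that are not accepted are
the ones that get renoised.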

Caught + fixed a schedule-direction bug: HF's cur_step counts DOWN (reverse
diffusion), ours counts up, so the naive port ran temperature 0.4->0.8 backwards;
inverted to the correct 0.8->0.4.
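
A small sketch of the direction fix. The linear interpolation and the exact range of the
countdown index are assumptions (the commit only gives the endpoints 0.8 and 0.4), and both
function names are hypothetical:

```python
def temperature_at(step: int, num_steps: int,
                   t_start: float = 0.8, t_end: float = 0.4) -> float:
    # step counts UP: 0 is the first (noisiest) step, num_steps - 1 the last.
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if not 0 <= step < num_steps:
        raise ValueError("step out of range")
    if num_steps == 1:
        return t_start
    frac = step / (num_steps - 1)
    return t_start + (t_end - t_start) * frac


def temperature_from_countdown(cur_step: int, num_steps: int) -> float:
    # Assumed countdown convention: cur_step runs num_steps - 1 down to 0.
    # Converting to a count-up index first avoids running the schedule backwards.
    return temperature_at(num_steps - 1 - cur_step, num_steps)
```

Feeding a count-up index into a formula written for a countdown index is exactly the kind of
slip that yields 0.4->0.8; an explicit conversion keeps the direction visible.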

The full DiffusionGemma text model + sampler now runs end-to-end in mlx-lm
(smoke-green: forward + generate + sanitize). 785 lines. Remaining (slice 6):
logits parity vs transformers (torch), encoder-bidi-'all', real-weights load,
mlx-lm registry/CLI wiring, and the block-autoregressive outer loop (5b).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…cation

A 12-agent adversarial review (5 slices reviewed vs HF + mlx-vlm, each finding
verified) surfaced 6 confirmed divergences; fixed the real ones:

- HIGH: diffusion_generate committed the RENOISED canvas (random tokens at
  unconverged positions) instead of HF argmax(logits). Now returns the
  per-position argmax canvas.
- HIGH: self-conditioning carried RAW logits; HF carries the temperature-SCALED
  (processed) logits (softmax differs under temperature), perturbing every step.
  Now carries the scaled logits.
- MEDIUM: the encoder silently dropped the padding attention_mask (PAD tokens
  attended + cached). Threaded attention_mask through Model->backbone->encoder,
  with an explicit causal+sliding mask AND key-mask when one is supplied (mirrors
  mlx-vlm); fast no-mask path kept.
- LOW: per-step renoise reused a fixed key (degenerate on the seeded path) — now
  splits the key each step.
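
For the attention_mask fix, the combined mask can be sketched with plain lists: a key is
visible to a query only if it is causal, inside the sliding window, and not padding.
`build_prefill_mask` is a hypothetical name, and a real implementation would build this as an
array rather than nested lists:

```python
from typing import List, Optional, Sequence


def build_prefill_mask(attention_mask: Sequence[int],
                       window: Optional[int] = None) -> List[List[bool]]:
    # attention_mask[k] is 1 for a real token and 0 for PAD.
    n = len(attention_mask)
    mask: List[List[bool]] = []
    for q in range(n):
        row = []
        for k in range(n):
            causal = k <= q                               # no attending to the future
            in_window = window is None or (q - k) < window  # sliding-window limit
            not_pad = bool(attention_mask[k])             # key-mask: never attend to PAD
            row.append(causal and in_window and not_pad)
        mask.append(row)
    return mask
```

Note that a PAD query at the left edge ends up with a fully masked row in this sketch; how
such rows are handled downstream is outside what the commit describes.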

Deferred LOW (documented in-file): sliding-window decode for >512 prompts
(inherited from mlx-vlm), no per-row B>1 freeze, the all sliding_window halve
(verify vs real config). MoE parity still exact; full forward+pad-mask+generate
smoke-green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Adds from_dict that flattens the conversion config (text fields are nested under
text_config; canvas_length is top-level) so mlx-lm's loader builds the right
text-only args. With this + the existing sanitize, the real mlx-community
diffusiongemma-26B-A4B-it conversions load with ZERO key/shape mismatches.
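
A minimal sketch of that flattening, assuming a toy args class with only a few fields.
`ToyModelArgs` and its field list are hypothetical; the real ModelArgs mirrors HF's
DiffusionGemmaTextConfig and has many more fields:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ToyModelArgs:
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    canvas_length: int = 256

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "ToyModelArgs":
        # Start from the nested text fields...
        flat = dict(config.get("text_config", {}))
        # ...then overlay the fields that live at the top level of the config.
        for key in ("model_type", "canvas_length"):
            if key in config:
                flat[key] = config[key]
        # Ignore anything the args class does not know about (e.g. vision settings).
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in flat.items() if k in known})
```

Usage, with the sizes reported for the 4-bit conversion:
`ToyModelArgs.from_dict({"model_type": "diffusion_gemma", "canvas_length": 256,
"text_config": {"hidden_size": 2816, "num_hidden_layers": 30}})`.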

PROVEN ON THE REAL MODEL: mlx-community/diffusiongemma-26B-A4B-it-4bit loads
(hidden=2816, 30 layers, 128 experts) and generates coherent text via
diffusion_generate — 'Why is the sky blue?' -> a correct one-sentence answer,
256-token canvas in ~1.0s = ~250 tok/s on an M5 Max. First DiffusionGemma text
port running the real 26B on Apple Silicon.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…_generate

Model exposes a diffusion_generate(prompt_ids, **kw) method; mlx_lm.generate's
main() dispatches to it when present (generic hasattr check — no model-specific
import in generate.py), instead of the autoregressive token loop. A model that
denoises a canvas opts in by exposing the method.

Verified end-to-end via the real conversion through mlx-lm's standard load()
(from_dict + sanitize + quantization all integrate):
  mlx_lm.generate --model mlx-community/diffusiongemma-26B-A4B-it-4bit \
      --prompt 'Why is the sky blue?' --verbose true
  -> the model's thinking channel + a correct one-sentence answer, ~87 tok/s.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

Model type diffusion_gemma not supported.

1 participant