diff --git a/.gitignore b/.gitignore index ceee5a944..69496541c 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,6 @@ fix-plan.md .env.local *.pem *.key + +# Sisyphus working dossier (preserve outside repo) +.sisyphus/ diff --git a/.gitmodules b/.gitmodules index d664da54e..bc501bf80 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "dflash/deps/llama.cpp"] path = dflash/deps/llama.cpp - url = https://github.com/Luce-Org/llama.cpp-dflash-ggml.git - branch = luce-dflash + url = https://github.com/dusterbloom/llama-cpp-turboquant-cuda.git + branch = feature/tq3-kv-cache-clean [submodule "dflash/deps/Block-Sparse-Attention"] path = dflash/deps/Block-Sparse-Attention url = https://github.com/mit-han-lab/Block-Sparse-Attention.git diff --git a/.sisyphus/plans/20260428-1430-path-b-deltanet-wmma-scope.md b/.sisyphus/plans/20260428-1430-path-b-deltanet-wmma-scope.md deleted file mode 100644 index 00210ce28..000000000 --- a/.sisyphus/plans/20260428-1430-path-b-deltanet-wmma-scope.md +++ /dev/null @@ -1,276 +0,0 @@ -# Scope: Path B — WMMA tensor-core rewrite of `pf_dn_chunk_phase2` (Ampere sm_86) - -**Date**: 2026-04-28 -**Target**: `megakernel/prefill.cu` — `pf_dn_chunk_phase2` only (phase1 unchanged for v1) -**Why**: Profiler shows this kernel = 47.6% of prefill CUDA time (13.27 ms / 27.88 ms total) on RTX 3090 / 512-token prompt. Inner products are scalar FP32 in shared memory; tensor cores are unused. Existing GEMMs (Q/K/V proj, MLP) already go to cuBLAS tensor cores (`ampere_bf16_s16816gemm…`). - ---- - -## Goal & non-goals - -**Goal**: Replace the four matmul-shaped inner loops in phase2 with WMMA bf16 / f32-accum matmuls, with f32 accumulator semantics so numerical drift in the recurrence stays bounded. - -**Non-goals**: -- Phase1 rewrite (it has different math: cumsum, sigmoid, softplus, triangular forward-substitute — only one matmul-shaped piece, marginal win). 
-- Algorithmic changes to the recurrence (correctness reference is the existing scalar `pf_dn_chunk_phase2`). -- Backporting Blackwell `prefill_megakernel.cu` patterns (different launch model, persistent kernel, `cg::this_grid().sync()`). -- Multi-stream / overlap with FA layers. -- New CLI surface; this is a drop-in kernel replacement gated by a build-time flag for A/B. - -**Out-of-scope but worth noting**: state update is the only operation where reducing precision could plausibly compound across chunks. The other three ops are read-only against state; their bf16 truncation only affects per-chunk outputs. - ---- - -## Math: what's actually matmul-shaped - -Per chunk `n`, phase2 currently runs five distinct compute steps. Four are matmul-shaped with f32 inputs in shared memory: - -| Op | Formula | M × N × K | FMAs/chunk | WMMA fit | -|---|---|---:|---:|---| -| **d compute** | `d[c,j] = u[c,j] - Σ_d w[c,d]·state[j,d]` | 8 × 32 × 128 | 32K | m8n32k16, K-loops=8 | -| **QKt compute** | `QKt[c,s] = Σ_d Q[c,d]·K[s,d]` | 8 × 8 × 128 | 8K | M,N<16; **keep scalar** (only 7% of compute) | -| **o_inter compute** | `tmp[c,j] = Σ_d Q[c,d]·state[j,d]` | 8 × 32 × 128 | 32K | m8n32k16, K-loops=8 | -| **state update** | `state[j,i] = γ·state[j,i] + Σ_c d_scaled[c,j]·K[c,i]` | 32 × 128 × 8 | 32K | m16n16k16, K=8→pad 16 | -| **o_intra** | `o_intra[c,j] = Σ_{s≤c} (QKt[c,s]·exp(cs[c]-cs[s]))·d[s,j]` | 8 × 32 × 8 | 2K | tiny + triangular mask; **keep scalar** | - -WMMA-targeted ops total ~96K FMAs per chunk per block (out of ~106K total matmul-shaped). With N=64 chunks × 72 blocks = ~441M MAC over 13.27 ms today = **~33 GFLOPS realized**. Ampere bf16 tensor-core peak is 142 TFLOPS. Headroom is enormous, but most of the 13.27 ms is launch/sync/smem-load — not pure compute. 
- ## Constants we're working with - -``` -DN_HEADS = 16 DN_KEY = 128 DN_VAL = 128 -DN_CHUNK_C = 8 DN_PHASE2_J_SPLITS = 4 -DN_PHASE2_J_PER_BLOCK = 32 DN_PHASE2_BLOCK = 128 (4 warps) -Launch grid: (DN_HEADS, J_SPLITS) = (16, 4) = 64 blocks -__launch_bounds__(128, 1) → 1 block / SM = 8% theoretical occupancy -P2 dynamic smem: ~31 KB -``` - ---- - -## WMMA design - -### Fragment shapes - -Ampere sm_86 supports the following bf16 WMMA fragment shapes (`<mma.h>`, `nvcuda::wmma`): - -| Shape | A | B | C | -|---|---|---|---| -| **m16n16k16** | matrix_a 16×16 bf16 | matrix_b 16×16 bf16 | accumulator 16×16 f32 | -| m8n32k16 | matrix_a 8×16 bf16 | matrix_b 16×32 bf16 | accumulator 8×32 f32 | -| m32n8k16 | matrix_a 32×16 bf16 | matrix_b 16×8 bf16 | accumulator 32×8 f32 | - -**Choice**: stick to **m16n16k16** for simplicity and uniform fragment lifetime across the three target ops. Tile mapping: - -- **d**, **o_inter** (M=8, N=32, K=128): one warp per op-instance. Fragment M=16 covers M=8 (with padding); N=32 needs 2 N-tiles. K=128 / 16 = 8 K-iters. So per warp: `2 × 8 = 16` mma.sync calls per op. -- **state update** (M=32, N=128, K=8): two warps cooperate. M=32 / 16 = 2 M-tiles; N=128 / 16 = 8 N-tiles. K=8 → pad to 16 = 1 K-iter. 16 mma.sync calls split across 2 warps = 8 per warp. - -Block has 4 warps.
Mapping per chunk: -- Warps 0-1: state update (cooperating on M-tiles 0 and 1) -- Warps 2-3: d compute and o_inter (each does half the N-tiles) -- QKt and o_intra stay scalar (handled by all 4 warps before/after the WMMA region) - -### Data types - -| Tensor | Current | Proposed | -|---|---|---| -| `s_state` | f32 | f32 + bf16 staging (load f32 from global, write bf16 mirror for WMMA reads) | -| `s_w` | f32 | bf16 (load from f32 source, downcast on store) | -| `s_Q`, `s_K` | f32 (loaded from bf16 qkv_pre) | bf16 (load directly without f32 round-trip) | -| `s_u`, `s_d` | f32 | f32 (these are accumulator-side; written from WMMA C fragment via `store_matrix_sync`) | -| `state` (global) | f32 | **f32** — unchanged, decode kernel reads it. No format change at the boundary. | - -Accumulator stays f32 inside WMMA fragments. Down-conversion to bf16 only happens when feeding the next op's inputs, not in the persistent state. - -### Numerical risk - -The state update uses `state_new = γ·state_old + Σ d·K` with γ ≤ 1 (decay). Per-chunk error from bf16 truncation of `state_old` reads is ~2⁻⁸ relative. Over N=64 chunks, the recurrence damps errors via γ at each step rather than amplifying. Realistic worst-case: ~2⁻⁵ relative drift at chunk 64. This is well within bf16 inference tolerance for LLM logits (where end-of-prompt position tolerance is typically O(2⁻⁴) vs f32 reference). - -**Verification gate**: bench_pp_tg.py's existing correctness section runs "The capital of France is" and asserts the first generated token. New kernel must produce the same token. Additional guard: compare full pp520 output token IDs against the scalar reference for 100+ prompts. 
- ---- - -## Layout & smem budget - -New smem layout (sized for 1 block, dynamic smem stays under 100 KB): - -``` -s_state_f32 [J_per × DK_S] 32×129×4 = 16.5 KB -s_state_bf16 [J_per × DK_S] 32×129×2 = 8.3 KB // bf16 mirror, refreshed per chunk -s_u [DN_CHUNK_C × J_per] 8×32×4 = 1.0 KB // f32 accumulator -s_w_bf16 [DN_CHUNK_C × DK_S] 8×129×2 = 2.1 KB -s_Q_bf16 [DN_CHUNK_C × DK_S] 8×129×2 = 2.1 KB -s_K_bf16 [DN_CHUNK_C × DK_S] 8×129×2 = 2.1 KB -s_d [DN_CHUNK_C × J_per] 8×32×4 = 1.0 KB -s_qkt [DN_CHUNK_C × DN_CHUNK_C] 8×8×4 = 0.25 KB -s_cs, s_decay_rem [DN_CHUNK_C] 0.06 KB -───────────────────────────────────────────────── - TOTAL: ~33.4 KB -``` - -Slightly bigger than today's ~31 KB but well under Ampere's 100 KB-per-block ceiling. Refreshing the `s_state_bf16` mirror after each chunk's state update costs 32×129 = 4128 cvts per block, fully parallelizable across 128 threads. - ---- - -## Implementation plan - -### Phase 0 — preconditions (1 hr) - -1. Confirm RTX 3090's actual `cudaDeviceProp.sharedMemPerBlockOptin` (should be 100 KB, but verify). -2. Add a build-time flag in `setup.py`: `MEGAKERNEL_DN_PHASE2_WMMA=on|off` (default off). Plumbs through to a `#define DN_PHASE2_WMMA` in `prefill.cu` so the new kernel sits next to the existing scalar version under `#ifdef`. -3. Wire a Python env-var override (`MEGAKERNEL_DN_PHASE2_WMMA=1`) so we can A/B without rebuilding. - -### Phase 1 — instrumentation (1-2 hr) - -1. Rerun `diag_prefill_kernels.py` with `record_shapes=True` and CUPTI metrics (`profile_memory=True, with_stack=True`). Get per-launch SM occupancy and stall reasons for `pf_dn_chunk_phase2`. -2. Compute the actual mix: how much of the 13.27 ms is compute vs smem-load vs sync. (Ratio determines whether WMMA alone or WMMA+`cp.async` is the right swing.) -3. Save the baseline output token IDs for the correctness corpus (50 fixed prompts × 32-token completions). - -### Phase 2 — WMMA "d compute" only (3-5 hr) - -Smallest demonstrable win. 
Replace the d-compute loop (lines 580-589) with: - -```cpp -using namespace nvcuda::wmma; -constexpr int WM=16, WN=16, WK=16; - -// Convert s_state f32 → bf16 mirror once per chunk before WMMA region. -// (One-time at chunk start; refresh after state update at chunk end.) - -fragment a_w; -fragment b_state; -fragment c_d; - -int warp_id = tid / 32; -if (warp_id < 2) { // 2 warps cover N-tile 0 and 1 - int n_tile = warp_id; - fill_fragment(c_d, 0.f); - #pragma unroll - for (int kk = 0; kk < DN_KEY; kk += WK) { - load_matrix_sync(a_w, s_w_bf16 + kk, DK_S); // [C(8 padded to 16)][WK] - load_matrix_sync(b_state, s_state_bf16 + n_tile*WN*DK_S + kk, DK_S); // [J(WN)][WK] col-major - mma_sync(c_d, a_w, b_state, c_d); - } - // Subtract u and write to s_d. Layout depends on fragment storage convention; use - // store_matrix_sync into a temp f32 tile then subtract per-thread. - float tmp[16][16]; // (per-warp scratchpad in registers — actually use a smem tile) - store_matrix_sync(/*ptr=*/..., c_d, /*ldc=*/..., mem_row_major); - // ...subtract s_u in-place, write s_d -} -__syncthreads(); -``` - -Compare scalar vs WMMA d-output bit-for-bit difference on a fixed 16-token chunk. Acceptable: relative max diff < 2⁻⁶. Run `bench_pp_tg.py` correctness section; must pass. - -### Phase 3 — extend to o_inter (2 hr) - -Same pattern as d-compute with different operand sources (`s_Q` instead of `s_w`, output multiplied by `expf(cs[c])`). Re-verify. - -### Phase 4 — state update (4-6 hr) - -The trickiest one. Two warps cooperate on M=32 split as two M=16 tiles. Each warp produces 8 N-tiles of f32 accumulator, then writes back to `s_state_f32` after multiplying by `s_decay_total` and adding the existing state value. K=8 → pad to 16 with zero-fill in `s_K_bf16` rows 8-15 (caller of WMMA must zero those rows). Then refresh `s_state_bf16` mirror. - -Pseudocode: -```cpp -// d_scaled is column of M (J_per=32). K is column of K (Dk=128). 
-// We want: state[j, i] = γ·state[j, i] + Σ_c d_scaled[c, j] * K[c, i] -// -// Reframe as GEMM: state(M=32, N=128) += d_scaled.T(M=32, K=8) @ K(K=8, N=128) -// d_scaled.T means we transpose-on-load: WMMA matrix_a with col_major. - -fragment a_dT; // [J_per_tile=16][K_pad=16] -fragment b_K; // [K_pad=16][N_tile=16] -fragment c_st; - -if (warp_id < 2) { - int j_tile = warp_id; // 0 or 1, covers j_start+[0..16) or +[16..32) - for (int n_tile = 0; n_tile < 8; n_tile++) { - // Load existing state slice into accumulator fragment, scaled by γ - load_matrix_sync(c_st, s_state_f32 + j_tile*16*DK_S + n_tile*16, DK_S, mem_row_major); - scale_fragment(c_st, s_decay_total); // helper - load_matrix_sync(a_dT, s_d_bf16 + j_tile*16, J_per); // C-major as transpose - load_matrix_sync(b_K, s_K_bf16 + n_tile*16, DK_S); - mma_sync(c_st, a_dT, b_K, c_st); - store_matrix_sync(s_state_f32 + j_tile*16*DK_S + n_tile*16, c_st, DK_S, mem_row_major); - } -} -__syncthreads(); -// Refresh s_state_bf16 mirror for next chunk. -``` - -`d_scaled` needs to be available as bf16 too (it's currently f32 from the d-compute store). So phase 2's d-compute should write *both* an f32 copy (for o_intra still scalar) and a bf16 copy. Tradeoff: ~256 extra cvts per chunk per block, negligible. - -### Phase 5 — relax `__launch_bounds__` (1 hr) - -After the kernel rewrite, register pressure should drop (matmul work is in fragments, not in 32×scalar accumulators). Try `__launch_bounds__(128, 2)` then `(128, 4)`. Re-profile occupancy. - -### Phase 6 — `cp.async` pipelining (4-6 hr, optional) - -If profiler shows smem-load stalls dominate after Phases 2-5, overlap the next-chunk loads with the current-chunk WMMA via `__pipeline_memcpy_async` + double-buffered smem. Doubles the staging buffer cost (`s_w_bf16`, `s_Q_bf16`, `s_K_bf16` × 2) but should hide ~80% of load latency. - -This is the highest-risk phase — easy to get pipeline barriers wrong. Skip if Phases 2-5 already get us to ~30% improvement target. 
- -### Phase 7 — extend to phase1 (deferred) - -Phase1 is 2.0% of CUDA time per the profile. Even a 5x speedup buys 1.6%. Defer. - ---- - -## Verification harness - -1. **Bit-level**: compile both `pf_dn_chunk_phase2_scalar` (existing) and `pf_dn_chunk_phase2_wmma` (new) into the same .so. Add a debug Python entrypoint that runs both on identical inputs and reports `max_abs(diff)` and `max_rel(diff)` for every output element. Tolerance: 2⁻⁶ relative. - -2. **Token-level**: `bench_pp_tg.py` correctness section already runs an end-to-end prompt and asserts the predicted token. Extend with a fixed corpus of 50 prompts × 32 generated tokens each. Build the baseline corpus from current main; assert byte-equal token IDs from the new kernel. (Allow optional `--tolerate=N` mismatches as a release gate if precision drift turns out to be model-dependent.) - -3. **Performance**: rerun `diag_prefill_kernels.py` after each phase. Track `pf_dn_chunk_phase2_*` self-CUDA time in a CSV. Phase 2 alone target: -25% on this kernel. After Phase 4: -50%. After Phase 5+6: -65%. - -4. **Bench matrix**: pp520 across 5 fixed prompt lengths {64, 128, 256, 512} × 3 batches each. Avoid n_gen variability by measuring prefill only. 
- ---- - -## Risks - -| Risk | Likelihood | Impact | Mitigation | -|---|---|---|---| -| bf16 precision drift breaks logits parity | Med | High — visible token-level differences | Bit-level tolerance harness in Phase 2 catches early; if intolerable, try TF32 m16n16k8 fragments at 0.5x throughput | -| Fragment storage convention mismatches (row vs col) cause silent wrong results | High | High | Phase 2 has explicit bit comparison against scalar; resolve before extending | -| Register spills under WMMA + 4 warps × 2 blocks/SM | Med | Med | `nvcc -Xptxas -v` after each phase; if spills appear, fall back to (128, 1) | -| smem mirror refresh introduces a new sync that hurts more than it saves | Low | Med | Bench Phase 2 in isolation; if regression, restructure to write bf16 directly from the cvt | -| Bank conflicts on bf16 stride 129 (different than f32 stride 129) | Low | Med | bf16 is 2-byte; 129 elements = 258 bytes. 32 banks × 4 bytes = 128 bytes = 64 bf16. Stride 129 ≠ multiple of 64 — already conflict-free. Verify via `__profile_*` smem counters. | -| Hardware variation: maybe RTX 3090 isn't the right test bed (downclocking, thermals) | Low | Low | Lock clocks via `nvidia-smi -lgc 2100` during bench; rerun before/after each phase | -| Real bottleneck turns out to be elsewhere (sync overhead, kernel launch cost) | Med | Med — caps the speedup | Phase 1's CUPTI metrics should reveal this before we commit to Phase 2-6. If launch overhead dominates, the answer is fewer launches (graph capture or megakernel-on-Ampere), not WMMA | - ---- - -## Estimated timeline - -| Phase | Optimistic | Realistic | Pessimistic | Cumulative | -|---|---|---|---|---| -| 0. Preconditions | 1 h | 2 h | 4 h | 4 h | -| 1. Instrumentation + corpus | 2 h | 4 h | 6 h | 10 h | -| 2. d compute WMMA | 3 h | 6 h | 12 h | 22 h | -| 3. o_inter WMMA | 2 h | 3 h | 6 h | 28 h | -| 4. state update WMMA | 4 h | 8 h | 14 h | 42 h | -| 5. launch_bounds tuning | 1 h | 2 h | 4 h | 46 h | -| 6. 
cp.async (optional) | 4 h | 8 h | 16 h | 62 h | -| Bench/QA/PR | 2 h | 4 h | 8 h | 70 h | - -Realistic: **~7 working days** (1.5 weeks at 5 h/day). Optimistic 19 h / 2.5 days, pessimistic 70 h / 2 weeks. - -## Expected payoff - -Bottoming-up from the 13.27 ms phase2 budget: -- Pure compute fraction (rough): if 60% of phase2 is matmul work and we get 4-8x on it via WMMA, that's `13.27 × 0.6 × (1 - 1/6) = 6.6 ms saved` → phase2 → ~6.7 ms, total prefill 21 ms → **~24,400 tok/s for 512 tokens**. -- After cp.async overlap of remaining smem loads: phase2 → ~4.5 ms, total 19 ms → **~27,000 tok/s**. -- Both well below the README's 37,800 tok/s claim. Confirms that WMMA alone won't close the gap; further work would be on phase1 (2%), kernel launch fusion (graph capture / persistent kernel — back to mega), or revisiting whether the README figure is accurate for stock sm_86. - -**Conservative success bar**: ≥30% prefill speedup (target: 23k tok/s) with token-level parity to ≤1 mismatch per 50-prompt corpus. - ---- - -## Open questions to resolve in Phase 1 - -1. What's the exact CUPTI breakdown of phase2 (compute / memory / sync)? Determines whether WMMA or `cp.async` is the primary lever. -2. Is the global memory load of `state` (32×129 f32 = 16.5 KB / chunk start, only once) actually a hot path, or is most of the 13.27 ms in the inter-chunk smem operations? -3. Does the existing kernel benefit from `__launch_bounds__(128, 2)` even without the WMMA rewrite? Quick experiment in Phase 5; if it does, that's a free 5-10% before any rewrite. -4. Is the 37,800 README figure reproducible at all on this 3090, or was it from a non-stock setup (overclocked, different chip bin, different driver)? Worth one ping to davide221 before committing two weeks of work. 
diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index 0e951857f..965770aa8 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -205,6 +205,10 @@ add_library(dflash27b STATIC src/qwen3/qwen3_drafter.cpp src/qwen3/qwen3_loader.cpp src/qwen3/qwen3_graph.cpp + src/gemma4_target_loader.cpp + src/gemma4_target_graph.cpp + src/gemma4_mtp_graph.cpp + src/gemma4_dflash_graph.cpp src/flashprefill_q8.cpp src/kv_cache.cpp src/kv_quant.cpp @@ -246,6 +250,11 @@ elseif(DFLASH27B_GPU_BACKEND STREQUAL "hip") target_compile_definitions(dflash27b PRIVATE DFLASH27B_BACKEND_HIP=1 GGML_USE_HIP) endif() +# Backward-compat alias for our gemma4 graph code that uses DFLASH27B_MIN_SM. +# origin/main renamed the variable to _dflash27b_cuda_min_sm; expose both names +# so dflash/src/gemma4_dflash_graph.cpp keeps building unchanged. +target_compile_definitions(dflash27b PRIVATE DFLASH27B_MIN_SM=${_dflash27b_cuda_min_sm}) + # FlashPrefill custom kernels. # CUDA: BF16 WMMA needs sm_80+; on sm_75 we fall back to ggml flash_attn_ext. # HIP Phase 1 (default): ggml q8 fallback, no custom kernels. 
@@ -283,7 +292,8 @@ elseif(DFLASH27B_GPU_BACKEND STREQUAL "cuda" AND _dflash27b_cuda_min_sm GREATER_ target_sources(dflash27b PRIVATE src/flashprefill_kernels.cu src/flashprefill_select.cpp - src/flashprefill.cpp) + src/flashprefill.cpp + src/pflash_ggml_adapter.cpp) target_compile_definitions(dflash27b PRIVATE DFLASH27B_HAVE_CUDA_WMMA_FLASHPREFILL=1) endif() @@ -525,5 +535,78 @@ if(DFLASH27B_TESTS) target_link_libraries(${_t} PRIVATE CUDA::cudart) endif() endforeach() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/test_gemma4_dflash.cpp") + add_executable(test_gemma4_dflash test/gemma4/test_gemma4_dflash.cpp) + target_include_directories(test_gemma4_dflash PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_gemma4_dflash PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(test_gemma4_dflash PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/smoke_load_gemma4_target.cpp") + add_executable(smoke_load_gemma4_target test/gemma4/smoke_load_gemma4_target.cpp) + target_include_directories(smoke_load_gemma4_target PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_load_gemma4_target PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(smoke_load_gemma4_target PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/smoke_gemma4_target_forward.cpp") + add_executable(smoke_gemma4_target_forward test/gemma4/smoke_gemma4_target_forward.cpp) + target_include_directories(smoke_gemma4_target_forward PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_gemma4_target_forward PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(smoke_gemma4_target_forward PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/smoke_load_gemma4_draft.cpp") + add_executable(smoke_load_gemma4_draft 
test/gemma4/smoke_load_gemma4_draft.cpp) + target_include_directories(smoke_load_gemma4_draft PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_load_gemma4_draft PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(smoke_load_gemma4_draft PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/smoke_gemma4_draft_forward.cpp") + add_executable(smoke_gemma4_draft_forward test/gemma4/smoke_gemma4_draft_forward.cpp) + target_include_directories(smoke_gemma4_draft_forward PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_gemma4_draft_forward PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(smoke_gemma4_draft_forward PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/test_gemma4_kv_tq3.cpp") + add_executable(test_gemma4_kv_tq3 test/gemma4/test_gemma4_kv_tq3.cpp) + target_include_directories(test_gemma4_kv_tq3 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_gemma4_kv_tq3 PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(test_gemma4_kv_tq3 PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_flash_attn_sparse.cpp") + add_executable(test_flash_attn_sparse test/test_flash_attn_sparse.cpp) + target_link_libraries(test_flash_attn_sparse PRIVATE dflash27b ggml ggml-cuda ggml-base) + target_include_directories(test_flash_attn_sparse PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/deps/llama.cpp/ggml/include + ${CMAKE_CURRENT_SOURCE_DIR}/deps/llama.cpp/ggml/src + ${CMAKE_CURRENT_SOURCE_DIR}/src) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/test_mtp_loader.cpp") + add_executable(test_mtp_loader test/gemma4/test_mtp_loader.cpp) + target_include_directories(test_mtp_loader PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_mtp_loader PRIVATE dflash27b ggml 
ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(test_mtp_loader PRIVATE CUDA::cudart) + endif() + + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/gemma4/test_mtp_graph_shapes.cpp") + add_executable(test_mtp_graph_shapes test/gemma4/test_mtp_graph_shapes.cpp) + target_include_directories(test_mtp_graph_shapes PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_mtp_graph_shapes PRIVATE dflash27b ggml ggml-cuda) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(test_mtp_graph_shapes PRIVATE CUDA::cudart) endif() + endif() # DFLASH27B_GPU_BACKEND STREQUAL "cuda" endif() diff --git a/dflash/README.md b/dflash/README.md index 63c6122f1..67118dcda 100644 --- a/dflash/README.md +++ b/dflash/README.md @@ -328,6 +328,14 @@ DFLASH27B_KV_TQ3=1 DFLASH27B_PREFILL_UBATCH=16 \ **Requirements:** NVIDIA sm_75+ GPU (2080 Ti, 3090, A10, A40, 4090) or Jetson AGX Thor sm_110, CUDA 12+ (CUDA 13+ required for Thor), 22+ GB VRAM, ~80 GB disk. On Turing (SM 7.5), BF16 draft weights are auto-converted to FP16 at load time for tensor core acceleration. +### Small-VRAM cards (<=24 GiB) + +VMM-backed pools waste VRAM on cards with ~24 GiB or less. The 32 GB VMM pool reservation fragments badly on a 24 GB card and causes prefill+verify cliffs (measured ~50% throughput loss at ctx=64K). Build with: + + cmake -DGGML_CUDA_NO_VMM=ON .. + +`GGML_CUDA_NO_VMM` is a **compile-time** CMake option — it cannot be set at runtime via environment variable. The dflash test binary prints a runtime warning if it detects <=24 GiB VRAM and the binary was built without this flag. + ## How it works **Block-diffusion draft.** Each step, the draft sees `[last_target_token, MASK×15]` plus the last 5 captured target hidden states. It denoises the masks in a single forward, producing 16 candidate tokens conditioned on real target features. Structurally stronger than chain EAGLE: every position conditions on the same captured context, not its own noisy predictions.
diff --git a/dflash/deps/llama.cpp b/dflash/deps/llama.cpp index c79573c9b..ecb832bbe 160000 --- a/dflash/deps/llama.cpp +++ b/dflash/deps/llama.cpp @@ -1 +1 @@ -Subproject commit c79573c9b23980181c186b70812799f51e94fb50 +Subproject commit ecb832bbea1a489c08e53800b50910c420bb33b6 diff --git a/dflash/include/gemma4.h b/dflash/include/gemma4.h new file mode 100644 index 000000000..c82687fb0 --- /dev/null +++ b/dflash/include/gemma4.h @@ -0,0 +1,62 @@ +// gemma4 — standalone CUDA library for DFlash speculative decoding of +// Gemma4 models (31B Dense and 26B-A4B MoE) with a DFlash draft model. + +#ifndef GEMMA4_H +#define GEMMA4_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// ─── Gemma4-31B Dense config ─────────────────────────────────────── + +#define GEMMA4_31B_HIDDEN 4096 +#define GEMMA4_31B_LAYERS 60 +#define GEMMA4_31B_N_HEADS 32 +#define GEMMA4_31B_N_KV_HEADS 8 +#define GEMMA4_31B_HEAD_DIM 128 +#define GEMMA4_31B_INTERMEDIATE 16384 +#define GEMMA4_31B_VOCAB 262144 +#define GEMMA4_31B_SWA_WINDOW 1024 + +// ─── Gemma4-26B-A4B MoE config ──────────────────────────────────── + +#define GEMMA4_26B_HIDDEN 4096 +#define GEMMA4_26B_LAYERS 30 +#define GEMMA4_26B_N_HEADS 32 +#define GEMMA4_26B_N_KV_HEADS 8 +#define GEMMA4_26B_HEAD_DIM 128 +#define GEMMA4_26B_INTERMEDIATE 16384 +#define GEMMA4_26B_EXPERT_INTERMEDIATE 2048 +#define GEMMA4_26B_N_EXPERTS 128 +#define GEMMA4_26B_N_EXPERTS_USED 8 +#define GEMMA4_26B_VOCAB 262144 +#define GEMMA4_26B_SWA_WINDOW 1024 + +// ─── Shared constants ───────────────────────────────────────────── + +#define GEMMA4_ROPE_THETA 1000000.0f +#define GEMMA4_RMS_EPS 1e-6f +#define GEMMA4_LOGIT_SOFTCAP 30.0f +#define GEMMA4_ATTN_SCALE 1.0f + +// ─── Draft model config ─────────────────────────────────────────── + +#define GEMMA4_DRAFT_LAYERS 5 +#define GEMMA4_DRAFT_BLOCK_SIZE 16 +#define GEMMA4_DRAFT_N_TARGET_LAYERS 6 +#define GEMMA4_31B_DRAFT_MASK_TOKEN_ID 4 +#define GEMMA4_26B_DRAFT_MASK_TOKEN_ID 4 + +// ─── 
Diagnostics ────────────────────────────────────────────────── + +const char * gemma4_last_error(void); + +#ifdef __cplusplus +} +#endif + +#endif // GEMMA4_H diff --git a/dflash/scripts/quantize_draft_q8.py b/dflash/scripts/quantize_draft_q8.py index 6ad8533d9..98b317c73 100644 --- a/dflash/scripts/quantize_draft_q8.py +++ b/dflash/scripts/quantize_draft_q8.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 """ -Quantize the z-lab DFlash draft (safetensors, bf16) to a Q8_0 GGUF. +Quantize a z-lab DFlash draft (safetensors, bf16) to a Q8_0 GGUF. + +Supports both Qwen and Gemma4 draft architectures via --arch. +When config.json is present alongside the safetensors file, dimensions are +auto-detected from it; hardcoded defaults are used as fallback. Projection weights (fc, wq, wk, wv, wo, gate, up, down) are quantized to Q8_0 (~50% size reduction vs BF16). Norm weights stay F32 @@ -10,6 +14,17 @@ convert_dflash_to_gguf.py so draft_gguf_loader.cpp can load it. Usage: + # Qwen3.5 draft (explicit --arch; dimensions are read from config.json when present) + python3 scripts/quantize_draft_q8.py --arch qwen \ + models/draft/model.safetensors \ + models/draft/draft-q8_0.gguf + + # Gemma4 draft + python3 scripts/quantize_draft_q8.py --arch gemma4 \ + models/draft-gemma4-31b/model.safetensors \ + models/draft-gemma4-31b/draft-q8_0.gguf + + # Auto-detect arch from config.json (requires model_type field) python3 scripts/quantize_draft_q8.py \ models/draft/model.safetensors \ models/draft/draft-q8_0.gguf @@ -17,6 +32,7 @@ import argparse import json +import re import struct import sys from pathlib import Path @@ -24,31 +40,150 @@ import numpy as np import gguf +Q8_0_BLOCK_SIZE = 32 # elements per Q8_0 block + # ────────────────────────────────────────────────────────────────────── -# DFlash 27B draft architecture constants (must match dflash27b.h) +# Per-arch defaults (used when config.json is absent or incomplete) # ────────────────────────────────────────────────────────────────────── -ARCH =
"qwen35-dflash-draft" -HIDDEN = 5120 -N_LAYER = 5 -N_HEAD = 32 -N_HEAD_KV = 8 -HEAD_DIM = 128 -INTERMEDIATE = 17408 -VOCAB = 248320 -N_TARGET_LAYERS = 5 -ROPE_THETA = 1_000_000.0 -RMS_EPS = 1e-6 -MASK_TOKEN_ID = 248070 -BLOCK_SIZE = 16 -CTX_LEN = 32768 +_QWEN_DEFAULTS = dict( + ARCH = "qwen35-dflash-draft", + HIDDEN = 5120, + N_LAYER = 5, + N_HEAD = 32, + N_HEAD_KV = 8, + HEAD_DIM = 128, + INTERMEDIATE = 17408, + VOCAB = 248320, + ROPE_THETA = 1_000_000.0, + RMS_EPS = 1e-6, + MASK_TOKEN_ID = 248070, + BLOCK_SIZE = 16, + CTX_LEN = 32768, + N_TARGET_LAYERS = 5, + MODEL_SIZE_TAG = "27B", + # Qwen-specific (no sliding window or logit softcap) + LOGIT_SOFTCAP = None, + SLIDING_WINDOW = None, + TARGET_LAYER_IDS = None, +) + +_GEMMA4_DEFAULTS = dict( + ARCH = "gemma4-dflash-draft", + HIDDEN = 2816, + N_LAYER = 5, + N_HEAD = 32, + N_HEAD_KV = 8, + HEAD_DIM = 128, + INTERMEDIATE = 5632, + VOCAB = 262144, + ROPE_THETA = 1_000_000.0, + RMS_EPS = 1e-6, + MASK_TOKEN_ID = 4, + BLOCK_SIZE = 16, + CTX_LEN = 262144, + LOGIT_SOFTCAP = 30.0, + SLIDING_WINDOW = 2048, + TARGET_LAYER_IDS = [1, 6, 11, 17, 22, 27], + MODEL_SIZE_TAG = "26B", +) + +_ARCH_DEFAULTS = { + "qwen": _QWEN_DEFAULTS, + "gemma4": _GEMMA4_DEFAULTS, +} + +# config.json model_type -> arch key +_MODEL_TYPE_MAP = { + "qwen3": "qwen", + "gemma4": "gemma4", +} + + +# ────────────────────────────────────────────────────────────────────── +# Config loading +# ────────────────────────────────────────────────────────────────────── + +def detect_arch_from_config(cfg_path: Path) -> str | None: + """Return 'qwen' or 'gemma4' by reading model_type from config.json.""" + if not cfg_path.exists(): + return None + with open(cfg_path) as f: + raw = json.load(f) + model_type = raw.get("model_type", "").lower() + for prefix, arch in _MODEL_TYPE_MAP.items(): + if model_type.startswith(prefix): + return arch + architectures = raw.get("architectures", []) + for a in architectures: + a_lower = a.lower() + for prefix, arch in 
def load_config(safetensors_path: Path, arch: str) -> dict:
    """
    Load model dimensions from the config.json next to the safetensors file.

    Args:
        safetensors_path: Path to the draft safetensors file; config.json is
            looked up in the same directory.
        arch: Key into _ARCH_DEFAULTS ("qwen" or "gemma4").

    Returns:
        A new cfg dict seeded from the per-arch defaults and overridden by the
        values found in config.json. A key that is absent OR explicitly set to
        JSON null falls back to the per-arch default, so downstream code never
        sees None for a dimension (and float(None) cannot be raised here).
        When config.json does not exist the defaults are returned unchanged.
    """
    defaults = dict(_ARCH_DEFAULTS[arch])
    cfg_path = safetensors_path.parent / "config.json"

    if not cfg_path.exists():
        print(f"[info] no config.json found at {cfg_path}, using {arch} hardcoded defaults")
        return defaults

    print(f"[info] reading config from {cfg_path}")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(cfg_path, encoding="utf-8") as f:
        raw = json.load(f)

    def pick(section: dict, json_key: str, default_key: str):
        # dict.get(key, default) only covers *missing* keys; HF configs often
        # carry explicit nulls (e.g. "head_dim": null), which must also fall
        # back to the default.
        value = section.get(json_key)
        return defaults[default_key] if value is None else value

    # "dflash_config": null must behave like an absent section.
    dflash_cfg = raw.get("dflash_config") or {}

    # Derive model size tag from directory name (e.g. "draft-gemma4-31b" -> "31B")
    dir_name = safetensors_path.parent.name
    m = re.search(r"(\d+[bBmM])", dir_name)
    model_size_tag = m.group(1).upper() if m else defaults["MODEL_SIZE_TAG"]

    cfg = dict(defaults)
    cfg.update(dict(
        HIDDEN         = pick(raw, "hidden_size", "HIDDEN"),
        N_LAYER        = pick(raw, "num_hidden_layers", "N_LAYER"),
        N_HEAD         = pick(raw, "num_attention_heads", "N_HEAD"),
        N_HEAD_KV      = pick(raw, "num_key_value_heads", "N_HEAD_KV"),
        HEAD_DIM       = pick(raw, "head_dim", "HEAD_DIM"),
        INTERMEDIATE   = pick(raw, "intermediate_size", "INTERMEDIATE"),
        VOCAB          = pick(raw, "vocab_size", "VOCAB"),
        ROPE_THETA     = float(pick(raw, "rope_theta", "ROPE_THETA")),
        RMS_EPS        = float(pick(raw, "rms_norm_eps", "RMS_EPS")),
        MASK_TOKEN_ID  = pick(dflash_cfg, "mask_token_id", "MASK_TOKEN_ID"),
        BLOCK_SIZE     = pick(raw, "block_size", "BLOCK_SIZE"),
        CTX_LEN        = pick(raw, "max_position_embeddings", "CTX_LEN"),
        MODEL_SIZE_TAG = model_size_tag,
    ))

    if arch == "gemma4":
        cfg.update(dict(
            LOGIT_SOFTCAP    = float(pick(raw, "final_logit_softcapping", "LOGIT_SOFTCAP")),
            SLIDING_WINDOW   = pick(raw, "sliding_window", "SLIDING_WINDOW"),
            TARGET_LAYER_IDS = pick(dflash_cfg, "target_layer_ids", "TARGET_LAYER_IDS"),
        ))

    print(f"[info] detected model size tag: {model_size_tag}")
    print(f"[info] hidden={cfg['HIDDEN']} n_layers={cfg['N_LAYER']} "
          f"n_head={cfg['N_HEAD']} n_head_kv={cfg['N_HEAD_KV']} "
          f"head_dim={cfg['HEAD_DIM']}")
    print(f"[info] intermediate={cfg['INTERMEDIATE']} vocab={cfg['VOCAB']}")
    if arch == "gemma4":
        print(f"[info] target_layer_ids={cfg['TARGET_LAYER_IDS']}")
    return cfg
Auto-detected from config.json " + "model_type when omitted.") args = ap.parse_args() if not args.safetensors.exists(): print(f"[error] safetensors not found: {args.safetensors}", file=sys.stderr) sys.exit(1) + # Resolve arch: explicit flag > auto-detect from config.json + arch = args.arch + cfg_path = args.safetensors.parent / "config.json" + if arch is None: + arch = detect_arch_from_config(cfg_path) + if arch is None: + print( + "[error] --arch not specified and could not auto-detect from " + f"config.json (model_type not in {list(_MODEL_TYPE_MAP)}).\n" + " Pass --arch qwen or --arch gemma4 explicitly.", + file=sys.stderr, + ) + sys.exit(1) + print(f"[info] auto-detected arch: {arch}") + else: + print(f"[info] arch: {arch}") + + cfg = load_config(args.safetensors, arch) + ARCH = cfg["ARCH"] + HIDDEN = cfg["HIDDEN"] + N_LAYER = cfg["N_LAYER"] + N_HEAD = cfg["N_HEAD"] + N_HEAD_KV = cfg["N_HEAD_KV"] + HEAD_DIM = cfg["HEAD_DIM"] + INTERMEDIATE = cfg["INTERMEDIATE"] + VOCAB = cfg["VOCAB"] + ROPE_THETA = cfg["ROPE_THETA"] + RMS_EPS = cfg["RMS_EPS"] + MASK_TOKEN_ID = cfg["MASK_TOKEN_ID"] + BLOCK_SIZE = cfg["BLOCK_SIZE"] + CTX_LEN = cfg["CTX_LEN"] + MODEL_SIZE_TAG = cfg["MODEL_SIZE_TAG"] + print(f"[info] reading safetensors header from {args.safetensors}") header_size, header = load_safetensors_header(args.safetensors) n_entries = sum(1 for k in header if k != "__metadata__") print(f"[info] {n_entries} tensor entries") + # Compute N_TARGET_LAYERS / TARGET_HIDDEN from fc.weight shape + fc_info = header.get("fc.weight") + if fc_info is None: + print("[error] fc.weight not found in safetensors", file=sys.stderr) + sys.exit(1) + fc_shape = fc_info["shape"] # [hidden, n_target_layers * target_hidden] + + if arch == "qwen": + N_TARGET_LAYERS = cfg["N_TARGET_LAYERS"] + if fc_shape[1] % N_TARGET_LAYERS != 0: + print(f"[error] fc.weight columns ({fc_shape[1]}) not divisible by " + f"N_TARGET_LAYERS ({N_TARGET_LAYERS})", file=sys.stderr) + sys.exit(1) + else: # gemma4 + 
TARGET_LAYER_IDS = cfg["TARGET_LAYER_IDS"] + if not TARGET_LAYER_IDS: + print("[error] target_layer_ids is empty; cannot compute N_TARGET_LAYERS " + "(check config.json or _DEFAULTS)", file=sys.stderr) + sys.exit(1) + N_TARGET_LAYERS = len(TARGET_LAYER_IDS) + if fc_shape[1] % N_TARGET_LAYERS != 0: + print(f"[error] fc.weight columns ({fc_shape[1]}) not divisible by " + f"N_TARGET_LAYERS ({N_TARGET_LAYERS})", file=sys.stderr) + sys.exit(1) + + TARGET_HIDDEN = fc_shape[1] // N_TARGET_LAYERS + print(f"[info] fc.weight shape {fc_shape} -> " + f"N_TARGET_LAYERS={N_TARGET_LAYERS} TARGET_HIDDEN={TARGET_HIDDEN}") + writer = gguf.GGUFWriter(args.out_gguf, ARCH) # Architecture metadata (identical to convert_dflash_to_gguf.py) - writer.add_string("general.name", "Qwen3.5-27B-DFlash-Draft-Q8_0") + if arch == "qwen": + model_name = f"Qwen3.5-{MODEL_SIZE_TAG}-DFlash-Draft-Q8_0" + else: + model_name = f"Gemma4-{MODEL_SIZE_TAG}-DFlash-Draft-Q8_0" + writer.add_string("general.name", model_name) + print(f"[info] general.name = {model_name}") writer.add_quantization_version(gguf.GGML_QUANT_VERSION) writer.add_uint32(f"{ARCH}.context_length", CTX_LEN) writer.add_uint32(f"{ARCH}.embedding_length", HIDDEN) @@ -148,11 +353,18 @@ def main(): writer.add_float32(f"{ARCH}.attention.layer_norm_rms_epsilon", RMS_EPS) writer.add_float32(f"{ARCH}.rope.freq_base", ROPE_THETA) - # DFlash-specific hyperparameters + # DFlash-specific hyperparameters (shared) writer.add_uint32(f"{ARCH}.dflash.n_target_layers", N_TARGET_LAYERS) writer.add_uint32(f"{ARCH}.dflash.block_size", BLOCK_SIZE) writer.add_uint32(f"{ARCH}.dflash.mask_token_id", MASK_TOKEN_ID) + # Gemma4-specific hyperparameters + if arch == "gemma4": + writer.add_uint32(f"{ARCH}.dflash.sliding_window", cfg["SLIDING_WINDOW"]) + writer.add_float32(f"{ARCH}.dflash.logit_softcap", cfg["LOGIT_SOFTCAP"]) + writer.add_uint32(f"{ARCH}.dflash.target_hidden", TARGET_HIDDEN) + writer.add_array(f"{ARCH}.dflash.target_layer_ids", cfg["TARGET_LAYER_IDS"]) + 
# Collect and sort tensors (same order as convert_dflash_to_gguf.py) pending = [] for st_name, info in header.items(): diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 0d05f45a0..f431d880f 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -93,6 +93,32 @@ def resolve_draft(root: Path) -> Path: raise FileNotFoundError(f"no model.safetensors under {root}") +def _read_gguf_architecture(gguf_path: Path) -> str: + """Return the 'general.architecture' string from a GGUF file, or '' on error. + + Logs a warning to stderr if the GGUF read fails, since that means the + server will silently pick the non-Gemma4 daemon path and use the wrong + argv shape. Caller should treat empty string as 'detection failed' and + decide accordingly. + """ + try: + from gguf import GGUFReader # type: ignore + import numpy as np + r = GGUFReader(str(gguf_path)) + f = r.fields.get("general.architecture") + if f is None or not f.data: + return "" + p = f.parts[f.data[0]] + if not isinstance(p, np.ndarray): + return "" + return bytes(p).decode("utf-8", errors="replace").strip() + except Exception as e: + import sys + print(f"[server] WARNING: failed to read general.architecture from {gguf_path}: {e}", + file=sys.stderr) + return "" + + _QWEN35_FAMILY_TOKENIZERS = { "Qwen3.5-27B": "Qwen/Qwen3.5-27B", "Qwen3.6-27B": "Qwen/Qwen3.6-27B", @@ -594,6 +620,7 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max prefill_cache_slots: int = 4, prefill_cache_bytes: int = 0, arch: str = "qwen35", + use_pflash: bool = False, extra_daemon_args: list[str] | None = None, lazy_draft: bool = False, verbose_daemon: bool = False) -> FastAPI: @@ -644,7 +671,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError if sys.platform == "win32": env["PATH"] = dll_dir + os.pathsep + str(Path(bin_abs).parent) + os.pathsep + env.get("PATH", "") - if arch in _LAGUNA_ARCHES: + if arch == "gemma4": + # Gemma4 binary uses named 
flags (--model, --draft) instead of positional args. + # draft is the safetensors directory, not a resolved file. + cmd = [bin_abs, + "--model", str(target), + "--draft", str(draft), + "--daemon", + "--fast-rollback", "--ddtree", f"--ddtree-budget={budget}", + f"--max-ctx={max_ctx}", + f"--stream-fd={stream_fd_val}"] + if use_pflash: + cmd.append("--pflash") + elif arch in _LAGUNA_ARCHES: # test_dflash detects arch=laguna from the GGUF and dispatches # internally to dflash27b::run_laguna_daemon(). No --draft, no # --fast-rollback, no --ddtree (no Laguna spec-decode draft yet). @@ -2306,6 +2345,9 @@ def main(): help="Disk budget in bytes for persisted full-cache artifacts. " "0 disables budget trimming.") ap.add_argument("--daemon", action="store_true") + ap.add_argument("--pflash", action="store_true", + help="Enable pFlash sparse-attention prefill in the daemon binary " + "(Gemma4 only; no-op for Qwen3).") ap.add_argument("--target-gpu", type=int, default=None, help="Visible CUDA device id for test_dflash (sets DFLASH_TARGET_GPU)") ap.add_argument("--draft-gpu", type=int, default=None, @@ -2352,22 +2394,43 @@ def main(): if not args.target.is_file(): raise SystemExit(f"target GGUF not found at {args.target}") - # Architecture detection. test_dflash itself dispatches by GGUF arch at - # main() entry, so server.py just needs to know enough to omit --draft + - # DFlash/DDTree flags on archs that lack a spec-decode draft. Same - # binary serves every arch. - arch = _arch_from_gguf(args.target) - - if not args.bin.is_file(): - raise SystemExit(f"binary not found at {args.bin} (arch={arch})") - - if arch in _LAGUNA_ARCHES: - # No DFlash draft model exists for laguna yet; test_dflash'́s - # internal arch dispatch reads general.architecture, accepts the - # no-draft argv layout, and routes to run_laguna_daemon(). 
PFlash - # compression and prefix-cache SNAPSHOT/RESTORE are both wired - # through the laguna daemon now, so --prefill-compression and - # --prefix-cache-slots behave the same as on the qwen35 path. + # Detect architecture and select the right binary. + # test_dflash itself dispatches qwen35/laguna by GGUF arch at main() entry, + # but Gemma4 lives in a separate binary (test_gemma4_dflash) so we route + # explicitly here. + arch = _read_gguf_architecture(args.target) + is_gemma4 = (arch == "gemma4") + + if args.bin != DEFAULT_BIN: + # User explicitly specified a binary — use it as-is. + bin_path = args.bin + elif is_gemma4: + bin_path = ROOT / "build" / ("test_gemma4_dflash" + (".exe" if sys.platform == "win32" else "")) + print(f"[server] detected architecture=gemma4, using binary: {bin_path}") + else: + bin_path = DEFAULT_BIN + + if not bin_path.is_file(): + raise SystemExit(f"binary not found at {bin_path} (arch={arch})") + + if is_gemma4: + # Gemma4 draft is a directory (safetensors dir), not a resolved file. + if args.draft.is_dir(): + draft = args.draft + elif args.draft.is_file(): + # User passed a file path inside the draft directory; use its parent. + draft = args.draft.parent + print(f"[server] note: --draft {args.draft} is a file; using parent {draft}", file=sys.stderr) + else: + raise SystemExit(f"draft path not found or not a directory: {args.draft}") + if not draft.is_dir(): + raise SystemExit(f"draft directory not found: {draft} (from {args.draft})") + elif arch in _LAGUNA_ARCHES: + # No DFlash draft model exists for laguna yet; test_dflash's internal + # arch dispatch reads general.architecture, accepts the no-draft argv + # layout, and routes to run_laguna_daemon(). PFlash compression and + # prefix-cache SNAPSHOT/RESTORE are both wired through the laguna + # daemon now. 
draft = None else: draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft @@ -2412,6 +2475,7 @@ def main(): prefill_cache_slots=args.prefill_cache_slots, prefill_cache_bytes=args.prefill_cache_bytes, arch=arch, + use_pflash=getattr(args, "pflash", False), extra_daemon_args=extra_daemon or None, lazy_draft=args.lazy_draft, verbose_daemon=args.verbose_daemon) @@ -2425,8 +2489,9 @@ def main(): print(f"Luce DFlash OpenAI server on http://{args.host}:{args.port}") print(f" arch = {arch}") print(f" target = {args.target}") + print(f" arch = {arch or '(unknown)'}") print(f" draft = {draft}") - print(f" bin = {args.bin}") + print(f" bin = {bin_path}") print(f" budget = {args.budget}") print(f" max_ctx = {args.max_ctx}") print(f" tokenizer = {tokenizer_id}") diff --git a/dflash/scripts/tokenize_prompt.py b/dflash/scripts/tokenize_prompt.py index c2721838c..95dfc42b9 100644 --- a/dflash/scripts/tokenize_prompt.py +++ b/dflash/scripts/tokenize_prompt.py @@ -1,40 +1,92 @@ """ -Tokenize a prompt string using the Qwen3.5 HF tokenizer (via transformers) -and emit the token IDs as a flat int32 binary file. +Tokenize a prompt string using a HuggingFace tokenizer (via transformers). -We depend on Python only for the tokenizer — the C++ library consumes the -int32 file directly. This keeps the standalone lib free of a BPE impl. +Two output modes: + --out FILE Write token IDs as a flat int32 little-endian binary file + (consumed by the C++ library directly). + --csv Print comma-separated token IDs to stdout + (for use with the --tokens flag of test_gemma4_dflash). Usage: - python tokenize_prompt.py --out /tmp/prompt.bin --prompt "The capital of France is" + # Binary output (backward-compatible): + python tokenize_prompt.py --out /tmp/prompt.bin --prompt "Hello, world!" + + # CSV output for --tokens flag: + python tokenize_prompt.py --csv --prompt "Hello, world!" 
def build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the tokenizer script.

    Options: --prompt (required), --model, --add-bos, --verbose, plus a
    required, mutually exclusive output mode: --out FILE or --csv.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--prompt", required=True, help="Text to tokenize")
    parser.add_argument(
        "--model",
        default="google/gemma-4-26b-a4b-it",
        help="HF repo id whose tokenizer to use (default: google/gemma-4-26b-a4b-it)",
    )

    # Plain boolean switches, registered in the order they appear in --help.
    switches = (
        ("--add-bos", "Prepend BOS token (add_special_tokens=True)"),
        ("--verbose", "Print token count and first/last tokens to stderr"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    # Output mode: exactly one of --out / --csv must be given.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--out", metavar="FILE",
                      help="Write int32 binary token ID file")
    mode.add_argument("--csv", action="store_true",
                      help="Print comma-separated token IDs to stdout")
    return parser
trust_remote_code=True) + try: + return AutoTokenizer.from_pretrained( + model, trust_remote_code=True, local_files_only=True + ) + except Exception: + # Fall back to network if not cached + return AutoTokenizer.from_pretrained(model, trust_remote_code=True) + + +def tokenize(prompt: str, model: str, add_bos: bool) -> list[int]: + tok = load_tokenizer(model) + return tok.encode(prompt, add_special_tokens=add_bos) + - ids = tok.encode(args.prompt, add_special_tokens=args.add_bos) - print(f"tokenized {len(ids)} tokens: {ids}") +def main() -> None: + args = build_parser().parse_args() + ids = tokenize(args.prompt, args.model, args.add_bos) - with open(args.out, "wb") as f: - for t in ids: - f.write(struct.pack(" 10 else []) + ids[-5:] if len(ids) > 10 else ids + print(f"tokenized {len(ids)} tokens; first/last: {preview}", file=sys.stderr) - print(f"wrote {args.out} ({len(ids) * 4} bytes)") + if args.csv: + print(",".join(str(i) for i in ids)) + else: + with open(args.out, "wb") as f: + for t in ids: + f.write(struct.pack(" @@ -12,6 +13,7 @@ namespace dflash27b { namespace { std::mutex g_err_mu; std::string g_last_error; +thread_local std::string t_err_buf; // per-thread snapshot for safe c_str return } void set_last_error(std::string msg) { @@ -23,5 +25,12 @@ void set_last_error(std::string msg) { extern "C" const char * dflash27b_last_error(void) { std::lock_guard lk(dflash27b::g_err_mu); - return dflash27b::g_last_error.c_str(); + dflash27b::t_err_buf = dflash27b::g_last_error; // copy under lock + return dflash27b::t_err_buf.c_str(); // safe: thread-local +} + +extern "C" const char * gemma4_last_error(void) { + std::lock_guard lk(dflash27b::g_err_mu); + dflash27b::t_err_buf = dflash27b::g_last_error; + return dflash27b::t_err_buf.c_str(); } diff --git a/dflash/src/gemma4_dflash_graph.cpp b/dflash/src/gemma4_dflash_graph.cpp new file mode 100644 index 000000000..5764f7237 --- /dev/null +++ b/dflash/src/gemma4_dflash_graph.cpp @@ -0,0 +1,1112 @@ +// Builds ggml 
compute graphs for the Gemma4 DFlash draft model +// (5-layer block-diffusion model with KV cache and logit softcapping). +// +// Architecture: +// - 6 captured target layers (Qwen3 used 5) +// - FC input = 6 * target_hidden, where target_hidden = 4096 for all Gemma4 +// variants (31B dense and 26B-A4B MoE), giving FC width = 24576 +// - Logit softcapping: tanh(logits / cap) * cap, cap = 30.0 +// - Tied lm_head: uses tok_embd transposed (or a provided lm_head weight) +// - Vocab = 262144 +// - Draft has its own lm_head + softcap — it does NOT rely on the target's +// lm_head (unlike the Qwen3 draft which shares the target's projection) +// - KV cache (prefix-direct): target features are projected into per-layer +// K/V entries and stored in GemmaTargetCache::draft_k/draft_v. +// build_draft_kv_prefill_graph materializes the context K/V; +// build_gemma4_draft_graph writes block K/V and attends over the full cache. +// - Layer types: 4 SWA (sliding_attention) + 1 full attention +// The attention kernel itself is the same ggml_flash_attn_ext call in both +// cases; the caller controls the mask to implement the sliding window. +// +// Two-step per-decode: +// 1. build_draft_kv_prefill_graph: project new committed context tokens into +// draft KV cache (side-effect only; nullptr returned). +// 2. build_gemma4_draft_graph: attend over context+block K/V and return logits. +// +// build_gemma4_draft_graph takes: +// - draft_embed [draft_hidden, n_tokens] f32 (MASK token embeddings) +// - positions [n_tokens] i32 (absolute token positions) +// - attn_mask [kv_pad, q_pad] f16 (causal over context+block) +// - kv_start = cache.draft_kv_pos (context length before this block) +// and returns: +// - logits [n_vocab, n_tokens] f32 (after softcapping) +// +// Safetensors tensor naming (actual file, no model. 
// ─── Draft SWA truncation toggle ──────────────────────────────────────────
// Set DFLASH_DRAFT_SWA_TRUNC=1 to enable per-layer K/V truncation in the
// draft graph for SWA layers (all layers except the final one, which uses
// full attention).
// Mirrors PR #129 for the qwen3 drafter, ported to gemma4's cached layout.
//
// The environment variable is read exactly once, on the first call; later
// changes to the environment have no effect. Initialisation uses a
// function-local static initialiser, which C++11 guarantees runs once even
// when several threads make the first call concurrently.
//
// Returns true when the variable is set to a value std::atoi parses as
// non-zero.
static inline bool draft_swa_trunc_enabled() {
    static const bool enabled = [] {
        const char * v = std::getenv("DFLASH_DRAFT_SWA_TRUNC");
        const bool on = (v != nullptr) && (std::atoi(v) != 0);
        if (on) {
            std::fprintf(stderr, "[draft-swa-trunc] enabled\n");
        }
        return on;
    }();
    return enabled;
}
+static inline ggml_tensor * draft_rope(ggml_context * ctx, ggml_tensor * x, + ggml_tensor * positions, int head_dim, + float rope_base) { + static struct { + int nctx; + float ext; + float bf; + float bs; + bool init; + } p = {0, 0.0f, 0.0f, 0.0f, false}; + if (!p.init) { + const char * en = std::getenv("DFLASH_DRAFT_YARN"); + if (en && std::atoi(en) != 0) { + const char * nc = std::getenv("DFLASH_DRAFT_YARN_NCTX_ORIG"); + p.nctx = nc ? std::atoi(nc) : 32768; + p.ext = 1.0f; + p.bf = 32.0f; + p.bs = 1.0f; + std::fprintf(stderr, + "[draft-yarn] enabled: n_ctx_orig=%d ext_factor=%.2f beta_fast=%.1f beta_slow=%.1f\n", + p.nctx, p.ext, p.bf, p.bs); + } + p.init = true; + } + return ggml_rope_ext(ctx, x, positions, /*freq_factors=*/nullptr, + head_dim, GGML_ROPE_TYPE_NEOX, p.nctx, + rope_base, /*freq_scale=*/1.0f, + p.ext, /*attn_factor=*/1.0f, p.bf, p.bs); +} + +// ─── Graph builders ─────────────────────────────────────────────────────── + +// build_draft_kv_prefill_graph — prefix-direct KV materialisation (SGLang style). +// +// Projects n_tokens new context positions through the draft model's Wk / Wv +// (after FC → ctx_hidden) and writes the resulting K, V tensors into +// cache.draft_k[il] / cache.draft_v[il] starting at offset cache.draft_kv_pos. +// +// The function is side-effect only: it expands ggml_cpy ops into gf and +// returns nullptr. The caller must ggml_graph_compute(gf) to materialise +// the cache entries, then increment cache.draft_kv_pos by n_tokens. +// +// target_feat [6*target_hidden, n_tokens] f32 +// positions [n_tokens] i32 (absolute positions for RoPE) +ggml_tensor * build_draft_kv_prefill_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaDraftWeights & w, + GemmaTargetCache & cache, + ggml_tensor * target_feat, + ggml_tensor * positions, + int n_tokens) +{ + // Guard: writing cache.draft_kv_pos..cache.draft_kv_pos+n_tokens-1 must fit. 
+ if (cache.draft_k.empty() || + cache.draft_kv_pos < 0 || + cache.draft_kv_pos + n_tokens > (int)cache.draft_k[0]->ne[2]) { + const int tensor_cap = cache.draft_k.empty() ? -1 : (int)cache.draft_k[0]->ne[2]; + GGML_ABORT("draft KV prefill out of bounds: draft_kv_pos=%d n_tokens=%d cap=%d tensor_cap=%d", + cache.draft_kv_pos, n_tokens, cache.draft_kv_cap, tensor_cap); + } + + const int n_kv = w.n_head_kv; + const int head_dim = w.head_dim; + const float eps = GEMMA4_RMS_EPS; + const float rope_base = w.rope_theta; + + // ── 1. FC projection: ctx_hidden = fc @ target_feat → [n_embd, n_tokens] + ggml_tensor * ctx_hidden = ggml_mul_mat(ctx, w.fc, target_feat); + // hidden_norm: RMSNorm applied right after the fc projection + // (matches qwen3_dflash_graph.cpp:57-59) + ctx_hidden = ggml_rms_norm(ctx, ctx_hidden, eps); + ctx_hidden = ggml_mul(ctx, ctx_hidden, w.hidden_norm); + ggml_set_name(ctx_hidden, "draft_kv_prefill_ctx_hidden"); + + // ── 2. Per-layer K / V projection, normalisation, RoPE, cache write + for (int il = 0; il < w.n_layer; il++) { + const GemmaDraftLayer & L = w.layers[il]; + + // K = Wk @ ctx_hidden → [kv_dim, n_tokens] → [head_dim, n_kv, n_tokens] + ggml_tensor * Kb = ggml_mul_mat(ctx, L.wk, ctx_hidden); + Kb = ggml_reshape_3d(ctx, Kb, head_dim, n_kv, n_tokens); + Kb = ggml_rms_norm(ctx, Kb, eps); + Kb = ggml_mul(ctx, Kb, L.k_norm); + Kb = draft_rope(ctx, Kb, positions, head_dim, rope_base); + + // V = Wv @ ctx_hidden → [kv_dim, n_tokens] → [head_dim, n_kv, n_tokens] + ggml_tensor * Vb = ggml_mul_mat(ctx, L.wv, ctx_hidden); + Vb = ggml_reshape_3d(ctx, Vb, head_dim, n_kv, n_tokens); + + // Write K into cache.draft_k[il] at offset cache.draft_kv_pos + ggml_tensor * k_dst = ggml_view_3d(ctx, cache.draft_k[il], + head_dim, n_kv, n_tokens, + cache.draft_k[il]->nb[1], cache.draft_k[il]->nb[2], + (size_t)cache.draft_kv_pos * cache.draft_k[il]->nb[2]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Kb, k_dst)); + + // Write V into cache.draft_v[il] at offset 
cache.draft_kv_pos + ggml_tensor * v_dst = ggml_view_3d(ctx, cache.draft_v[il], + head_dim, n_kv, n_tokens, + cache.draft_v[il]->nb[1], cache.draft_v[il]->nb[2], + (size_t)cache.draft_kv_pos * cache.draft_v[il]->nb[2]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Vb, v_dst)); + } + + return nullptr; +} + +// build_gemma4_draft_graph — KV-cached draft forward. +// +// Attends over the full draft KV cache (context K/V already materialised by +// build_draft_kv_prefill_graph, plus newly written block K/V) and returns +// logits for the n_tokens block positions. +// +// draft_embed [n_embd, n_tokens] f32 (MASK token embeddings) +// positions [n_tokens] i32 (absolute token positions) +// attn_mask [kv_pad, q_pad] f16 (causal over context+block) +// kv_start context length before this block (= cache.draft_kv_pos) +// +// Returns logits [n_vocab, n_tokens] f32 (softcapped). +ggml_tensor * build_gemma4_draft_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaDraftWeights & w, + GemmaTargetCache & cache, + ggml_tensor * draft_embed, + ggml_tensor * positions, + ggml_tensor * attn_mask, + int n_tokens, + int kv_start) +{ + // Validate KV cache write range before any graph nodes touch it. + if (kv_start < 0 || kv_start + n_tokens > cache.draft_kv_cap) { + GGML_ABORT("draft KV write out of bounds: kv_start=%d n_tokens=%d cap=%d", + kv_start, n_tokens, cache.draft_kv_cap); + } + + const int n_head = w.n_head; + const int n_kv = w.n_head_kv; + const int head_dim = w.head_dim; + const float eps = GEMMA4_RMS_EPS; + const float rope_base = w.rope_theta; + const int kv_len = kv_start + n_tokens; + + // Gemma4 scales embeddings by sqrt(hidden_size) — the draft shares the + // target's tok_embd, so it must apply the same scaling. Reference: + // vLLM qwen3_dflash.py embed_normalizer = target_config.hidden_size**0.5 + ggml_tensor * hidden = ggml_scale(ctx, draft_embed, std::sqrt((float)w.n_embd)); + ggml_set_name(hidden, "gemma4_draft_scaled_embed"); + + // ── 2. 
Transformer layers ───────────────────────────────────────── + for (int il = 0; il < w.n_layer; il++) { + const GemmaDraftLayer & L = w.layers[il]; + + // ── 2a. Attention pre-norm + ggml_tensor * cur = ggml_rms_norm(ctx, hidden, eps); + cur = ggml_mul(ctx, cur, L.attn_norm); + + // ── 2b. Q / K / V projections from block hidden state + ggml_tensor * Q = ggml_mul_mat(ctx, L.wq, cur); // [q_dim, n_tokens] + ggml_tensor * Kb = ggml_mul_mat(ctx, L.wk, cur); // [kv_dim, n_tokens] + ggml_tensor * Vb = ggml_mul_mat(ctx, L.wv, cur); // [kv_dim, n_tokens] + + // ── 2c. Reshape + per-head RMSNorm for Q and block K + Q = ggml_reshape_3d(ctx, Q, head_dim, n_head, n_tokens); + Q = ggml_rms_norm(ctx, Q, eps); + Q = ggml_mul(ctx, Q, L.q_norm); + + Kb = ggml_reshape_3d(ctx, Kb, head_dim, n_kv, n_tokens); + Kb = ggml_rms_norm(ctx, Kb, eps); + Kb = ggml_mul(ctx, Kb, L.k_norm); + + Vb = ggml_reshape_3d(ctx, Vb, head_dim, n_kv, n_tokens); + + // ── 2d. RoPE on Q and block K + Q = draft_rope(ctx, Q, positions, head_dim, rope_base); + Kb = draft_rope(ctx, Kb, positions, head_dim, rope_base); + + // ── 2e. Write block K / V into draft KV cache at [kv_start..kv_start+n_tokens) + ggml_tensor * k_dst = ggml_view_3d(ctx, cache.draft_k[il], + head_dim, n_kv, n_tokens, + cache.draft_k[il]->nb[1], cache.draft_k[il]->nb[2], + (size_t)kv_start * cache.draft_k[il]->nb[2]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Kb, k_dst)); + + ggml_tensor * v_dst = ggml_view_3d(ctx, cache.draft_v[il], + head_dim, n_kv, n_tokens, + cache.draft_v[il]->nb[1], cache.draft_v[il]->nb[2], + (size_t)kv_start * cache.draft_v[il]->nb[2]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Vb, v_dst)); + + // ── 2f. Full K / V view (context + block) from draft KV cache + // Optional SWA truncation: when enabled and this is an SWA layer + // with kv_len exceeding sliding_window, restrict K/V (and the mask) + // to the last (sliding_window + n_tokens) slots. Matches the draft + // model's training-time SWA pattern. 
+ const bool layer_is_swa = (il < (int)w.layer_is_swa.size()) + ? w.layer_is_swa[il] : false; + const bool use_swa_trunc = draft_swa_trunc_enabled() + && layer_is_swa + && w.sliding_window > 0 + && kv_len > (w.sliding_window + n_tokens); + const int eff_kv_len = use_swa_trunc + ? (w.sliding_window + n_tokens) + : kv_len; + const int kv_offset = kv_len - eff_kv_len; // 0 if no truncation + + ggml_tensor * K_full = ggml_view_3d(ctx, cache.draft_k[il], + head_dim, n_kv, eff_kv_len, + cache.draft_k[il]->nb[1], cache.draft_k[il]->nb[2], + (size_t)kv_offset * cache.draft_k[il]->nb[2]); + ggml_tensor * V_full = ggml_view_3d(ctx, cache.draft_v[il], + head_dim, n_kv, eff_kv_len, + cache.draft_v[il]->nb[1], cache.draft_v[il]->nb[2], + (size_t)kv_offset * cache.draft_v[il]->nb[2]); + + // ── 2g. Permute into flash_attn_ext layout + // Q: [head_dim, n_tokens, n_head, 1] + // K_full: [head_dim, eff_kv_len, n_head_kv, 1] + // V_full: [head_dim, eff_kv_len, n_head_kv, 1] + Q = ggml_cont(ctx, ggml_permute(ctx, Q, 0, 2, 1, 3)); + K_full = ggml_cont(ctx, ggml_permute(ctx, K_full, 0, 2, 1, 3)); + V_full = ggml_cont(ctx, ggml_permute(ctx, V_full, 0, 2, 1, 3)); + + // SWA-truncated mask view: take the last eff_kv_len rows along the + // kv axis (axis 0). Mask shape is [kv_pad, q_pad] with kv_pad >= kv_len, + // so the slice [kv_offset .. kv_offset+eff_kv_len) gives the same + // causal pattern for the surviving K positions. + ggml_tensor * eff_mask = attn_mask; + if (use_swa_trunc && kv_offset > 0) { + // ggml_view_2d would produce a non-contiguous tensor (row stride is + // unchanged at kv_pad * elt). FA requires contiguous mask, so we + // copy the slice into a fresh tensor. + ggml_tensor * mask_view = ggml_view_2d(ctx, attn_mask, + eff_kv_len, attn_mask->ne[1], + attn_mask->nb[1], + (size_t)kv_offset * ggml_element_size(attn_mask)); + eff_mask = ggml_cont(ctx, mask_view); + } + + // ── 2h. 
Flash attention over full context+block KV + // scale = 1 / sqrt(head_dim); no logit softcap at attention level + const float scale = 1.0f / std::sqrt((float)head_dim); + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K_full, V_full, eff_mask, + scale, /*max_bias=*/0.0f, + /*logit_softcap=*/0.0f); + // attn: [head_dim, n_head, n_tokens, 1] + attn = ggml_reshape_2d(ctx, attn, head_dim * n_head, n_tokens); + + // ── 2i. Output projection + residual + ggml_tensor * attn_out = ggml_mul_mat(ctx, L.wo, attn); + hidden = ggml_add(ctx, hidden, attn_out); + + // ── 2j. FFN pre-norm + ggml_tensor * hf = ggml_rms_norm(ctx, hidden, eps); + hf = ggml_mul(ctx, hf, L.ffn_norm); + + // ── 2k. SwiGLU FFN: down(silu(gate(x)) * up(x)) + ggml_tensor * g = ggml_mul_mat(ctx, L.w_gate, hf); + g = ggml_silu(ctx, g); + ggml_tensor * u = ggml_mul_mat(ctx, L.w_up, hf); + ggml_tensor * gu = ggml_mul(ctx, g, u); + ggml_tensor * ffn_out = ggml_mul_mat(ctx, L.w_down, gu); + + hidden = ggml_add(ctx, hidden, ffn_out); + } + + // ── 3. Final output norm + ggml_tensor * out = ggml_rms_norm(ctx, hidden, eps); + out = ggml_mul(ctx, out, w.out_norm); + ggml_set_name(out, "gemma4_draft_hidden_out"); + + // ── 4. LM head (tied: transpose of tok_embd) + // tok_embd: [draft_hidden, n_vocab] ggml ne[0]=draft_hidden, ne[1]=n_vocab + // out: [draft_hidden, n_tokens] + // logits: [n_vocab, n_tokens] + ggml_tensor * logits = ggml_mul_mat(ctx, w.tok_embd, out); + ggml_set_name(logits, "gemma4_draft_logits_pre_cap"); + + // ── 5. 
Logit softcapping: logits = cap * tanh(logits / cap) + const float cap = w.logit_softcap; + logits = ggml_scale(ctx, logits, 1.0f / cap); + logits = ggml_tanh(ctx, logits); + logits = ggml_scale(ctx, logits, cap); + ggml_set_name(logits, "gemma4_draft_logits"); + + return logits; +} + +// ─── Safetensors loader ─────────────────────────────────────────────────── + +namespace { + +struct GStEntry { + std::string dtype; + std::vector shape; + uint64_t data_start; + uint64_t data_end; +}; + +using GStMap = std::unordered_map; + +// Minimal safetensors JSON header parser (same algorithm as safetensors_draft.cpp). +static bool parse_gst_header(const char * h, size_t hlen, GStMap & out) { + auto skip_ws = [&](size_t & i) { + while (i < hlen && (h[i] == ' ' || h[i] == '\t' || + h[i] == '\n' || h[i] == '\r')) i++; + }; + size_t i = 0; + skip_ws(i); + if (i >= hlen || h[i] != '{') return false; + i++; + while (i < hlen) { + skip_ws(i); + if (i >= hlen) return false; + if (h[i] == '}') { i++; break; } + if (h[i] == ',') { i++; skip_ws(i); } + if (i >= hlen || h[i] != '"') return false; + i++; + size_t name_start = i; + while (i < hlen && h[i] != '"') i++; + if (i >= hlen) return false; + std::string name(h + name_start, i - name_start); + i++; + skip_ws(i); + if (i >= hlen || h[i] != ':') return false; + i++; + skip_ws(i); + if (i >= hlen || h[i] != '{') return false; + size_t obj_start = i; + int depth = 0; + size_t obj_end = i; + for (; obj_end < hlen; obj_end++) { + if (h[obj_end] == '{') depth++; + else if (h[obj_end] == '}') { if (--depth == 0) { obj_end++; break; } } + } + if (depth != 0) return false; + if (name == "__metadata__") { i = obj_end; continue; } + + std::string obj(h + obj_start, obj_end - obj_start); + GStEntry e; + { + auto k = obj.find("\"dtype\":\""); + if (k == std::string::npos) return false; + auto vs = k + 9; + auto ve = obj.find('"', vs); + if (ve == std::string::npos) return false; + e.dtype = obj.substr(vs, ve - vs); + } + { + auto k = 
obj.find("\"shape\":["); + if (k == std::string::npos) return false; + auto vs = k + 9; + auto ve = obj.find(']', vs); + if (ve == std::string::npos) return false; + const char * p = obj.c_str() + vs; + const char * pe = obj.c_str() + ve; + while (p < pe) { + char * end = nullptr; + long long v = std::strtoll(p, &end, 10); + if (end == p) break; + e.shape.push_back((int64_t)v); + p = end; + while (p < pe && (*p == ',' || *p == ' ')) p++; + } + } + { + auto k = obj.find("\"data_offsets\":["); + if (k == std::string::npos) return false; + auto vs = k + 16; + auto ve = obj.find(']', vs); + if (ve == std::string::npos) return false; + unsigned long long s = 0, ed = 0; + if (std::sscanf(obj.c_str() + vs, "%llu , %llu", &s, &ed) != 2) + if (std::sscanf(obj.c_str() + vs, "%llu,%llu", &s, &ed) != 2) return false; + e.data_start = s; + e.data_end = ed; + } + out.emplace(std::move(name), std::move(e)); + i = obj_end; + } + return true; +} + +static ggml_type gst_dtype_to_ggml(const std::string & dt) { + if (dt == "BF16") return GGML_TYPE_BF16; + if (dt == "F16") return GGML_TYPE_F16; + if (dt == "F32") return GGML_TYPE_F32; + return GGML_TYPE_COUNT; +} + +struct GMmap { + void * addr = nullptr; + size_t len = 0; +#if defined(_WIN32) + HANDLE hFile = INVALID_HANDLE_VALUE; + HANDLE hMap = nullptr; +#else + int fd = -1; +#endif + + bool open_ro(const std::string & path, std::string & err) { +#if defined(_WIN32) + hFile = CreateFileA(path.c_str(), GENERIC_READ, FILE_SHARE_READ, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + if (hFile == INVALID_HANDLE_VALUE) { + err = "CreateFileA: " + path + ": error " + std::to_string(GetLastError()); + return false; + } + LARGE_INTEGER sz; + if (!GetFileSizeEx(hFile, &sz)) { + err = "GetFileSizeEx failed"; return false; + } + len = (size_t)sz.QuadPart; + hMap = CreateFileMappingA(hFile, nullptr, PAGE_READONLY, 0, 0, nullptr); + if (!hMap) { err = "CreateFileMappingA failed"; return false; } + addr = MapViewOfFile(hMap, 
FILE_MAP_READ, 0, 0, 0); + if (!addr) { err = "MapViewOfFile failed"; return false; } +#else + fd = ::open(path.c_str(), O_RDONLY); + if (fd < 0) { err = "open: " + path + ": " + std::strerror(errno); return false; } + struct stat st; + if (::fstat(fd, &st) < 0) { err = "fstat: " + std::string(std::strerror(errno)); return false; } + len = (size_t)st.st_size; + addr = ::mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0); + if (addr == MAP_FAILED) { + err = "mmap: " + std::string(std::strerror(errno)); + addr = nullptr; return false; + } +#endif + return true; + } + + ~GMmap() { +#if defined(_WIN32) + if (addr) UnmapViewOfFile(addr); + if (hMap) CloseHandle(hMap); + if (hFile != INVALID_HANDLE_VALUE) CloseHandle(hFile); +#else + if (addr) ::munmap(addr, len); + if (fd >= 0) ::close(fd); +#endif + } +}; + +// Allocate one ggml tensor for a safetensors entry. +// HF row-major [out, in] → ggml ne[0]=in, ne[1]=out (byte layout identical). +// norm weights are kept as F32 (ggml CUDA elementwise ops require non-BF16 src1). +// Projection weights stay BF16 (Ampere+) or are converted to F16 (Turing). 
+static ggml_tensor * galloc_tensor( + ggml_context * gctx, + const GStMap & st, + const std::string & name, + const std::vector & expected_shape, + ggml_type gt_override = GGML_TYPE_COUNT) +{ + auto it = st.find(name); + if (it == st.end()) { + set_last_error("gemma4 safetensors: missing tensor '" + name + "'"); + return nullptr; + } + const GStEntry & e = it->second; + if (e.dtype != "BF16") { + set_last_error("gemma4 safetensors: '" + name + "' dtype=" + e.dtype + + " expected BF16"); + return nullptr; + } + if (e.shape.size() != expected_shape.size()) { + set_last_error("gemma4 safetensors: '" + name + "' ndim mismatch"); + return nullptr; + } + for (size_t k = 0; k < expected_shape.size(); k++) { + if (e.shape[k] != expected_shape[k]) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "gemma4 safetensors: '%s' shape[%zu]=%lld expected %lld", + name.c_str(), k, (long long)e.shape[k], (long long)expected_shape[k]); + set_last_error(buf); + return nullptr; + } + } + ggml_type gt = (gt_override == GGML_TYPE_COUNT) ? 
GGML_TYPE_BF16 : gt_override; + ggml_tensor * t = nullptr; + if (expected_shape.size() == 1) { + t = ggml_new_tensor_1d(gctx, gt, expected_shape[0]); + } else if (expected_shape.size() == 2) { + // [out, in] → ne[0]=in, ne[1]=out + t = ggml_new_tensor_2d(gctx, gt, expected_shape[1], expected_shape[0]); + } else { + set_last_error("gemma4 safetensors: unexpected ndim > 2 for '" + name + "'"); + return nullptr; + } + ggml_set_name(t, name.c_str()); + return t; +} + +static void g_bf16_to_f32(const uint16_t * src, float * dst, size_t n) { + for (size_t i = 0; i < n; i++) { + uint32_t bits = ((uint32_t)src[i]) << 16; + std::memcpy(&dst[i], &bits, 4); + } +} + +static void g_bf16_to_f16(const uint16_t * src, uint16_t * dst, size_t n) { + for (size_t i = 0; i < n; i++) { + uint32_t bits = ((uint32_t)src[i]) << 16; + float f; + std::memcpy(&f, &bits, 4); + uint32_t u; + std::memcpy(&u, &f, 4); + uint32_t sign = (u >> 16) & 0x8000; + int32_t exp = ((u >> 23) & 0xFF) - 127 + 15; + uint32_t mant = (u >> 13) & 0x03FF; + if (exp <= 0) dst[i] = (uint16_t)sign; + else if (exp >= 31) dst[i] = (uint16_t)(sign | 0x7C00); + else dst[i] = (uint16_t)(sign | (exp << 10) | mant); + } +} + +static bool g_cuda_has_native_bf16() { + const char * env = std::getenv("DFLASH27B_DRAFT_FP16"); + if (env && std::atoi(env) != 0) return false; +#if defined(DFLASH27B_MIN_SM) && DFLASH27B_MIN_SM < 80 + return false; +#else + return true; +#endif +} + +static uint32_t get_u32_or(const gguf_context * g, const char * key, uint32_t fallback) { + int64_t id = gguf_find_key(g, key); + if (id < 0) return fallback; + return gguf_get_val_u32(g, id); +} + +static float get_f32_or(const gguf_context * g, const char * key, float fallback) { + int64_t id = gguf_find_key(g, key); + if (id < 0) return fallback; + return gguf_get_val_f32(g, id); +} + +} // anonymous namespace + +// ─── Public loader ──────────────────────────────────────────────────────── + +// Load Gemma4 DFlash draft weights from a directory 
containing one or more +// safetensors shards. We look for files named: +// model.safetensors (single-shard) +// model-00001-of-NNNNN.safetensors (multi-shard, first shard only for now) +// +// In practice the z-lab Gemma4 draft is small enough to fit in a single shard. +bool load_gemma4_draft_safetensors(const std::string & dir_path, + ggml_backend_t backend, + GemmaDraftWeights & out) +{ + // ── 1. Find the shard file ──────────────────────────────────────── + // Try the canonical single-shard name first. + std::string path = dir_path + "/model.safetensors"; + { + // Quick existence check without mmap + int fd_check = ::open(path.c_str(), O_RDONLY); + if (fd_check < 0) { + // Fall back to first numbered shard + path = dir_path + "/model-00001-of-00001.safetensors"; + fd_check = ::open(path.c_str(), O_RDONLY); + if (fd_check < 0) { + set_last_error("gemma4 draft: no safetensors file found in " + dir_path); + return false; + } + } + ::close(fd_check); + } + + // ── 2. Open + mmap ─────────────────────────────────────────────── + GMmap mm; + std::string err; + if (!mm.open_ro(path, err)) { set_last_error(err); return false; } + if (mm.len < 8) { set_last_error("gemma4 draft: safetensors file too small"); return false; } + + // ── 3. Parse header ────────────────────────────────────────────── + uint64_t header_len = 0; + std::memcpy(&header_len, mm.addr, 8); + if (header_len == 0 || 8 + header_len > mm.len) { + set_last_error("gemma4 draft: bad safetensors header length"); + return false; + } + const char * header_ptr = (const char *)mm.addr + 8; + GStMap st; + if (!parse_gst_header(header_ptr, (size_t)header_len, st)) { + set_last_error("gemma4 draft: safetensors JSON parse failed"); + return false; + } + const uint8_t * blob = (const uint8_t *)mm.addr + 8 + header_len; + const size_t blob_size = mm.len - 8 - (size_t)header_len; + + // ── 4. 
Infer draft dimensions from FC weight shape ─────────────── + // fc: [n_vocab_or_target_feat_in, draft_hidden] + // The FC input is 6*target_hidden; FC output is draft_hidden. + // HF shape in safetensors: [draft_hidden, 6*target_hidden] + { + auto it = st.find("fc.weight"); + if (it == st.end()) { + set_last_error("gemma4 draft: fc.weight not found"); + return false; + } + const GStEntry & e = it->second; + if (e.shape.size() != 2) { + set_last_error("gemma4 draft: model.fc.weight expected 2D"); + return false; + } + // HF stores as [out_features, in_features] = [draft_hidden, 6*target_hidden] + out.n_embd = (int)e.shape[0]; + int fc_in = (int)e.shape[1]; + out.target_hidden = fc_in / GEMMA4_DRAFT_N_TARGET_LAYERS; + if (fc_in % GEMMA4_DRAFT_N_TARGET_LAYERS != 0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "gemma4 draft: FC input %d not divisible by n_target_layers %d", + fc_in, GEMMA4_DRAFT_N_TARGET_LAYERS); + set_last_error(buf); + return false; + } + } + + // Infer n_head / n_head_kv / n_ff from layer 0 weight shapes + { + auto iq = st.find("layers.0.self_attn.q_proj.weight"); + auto ik = st.find("layers.0.self_attn.k_proj.weight"); + auto ig = st.find("layers.0.mlp.gate_proj.weight"); + if (iq == st.end() || ik == st.end() || ig == st.end()) { + set_last_error("gemma4 draft: missing required layer-0 weight tensors"); + return false; + } + // q_proj HF shape: [q_dim, n_embd] where q_dim = n_head * head_dim + int q_dim = (int)iq->second.shape[0]; + int kv_dim = (int)ik->second.shape[0]; + out.n_head = q_dim / out.head_dim; + out.n_head_kv = kv_dim / out.head_dim; + out.n_ff = (int)ig->second.shape[0]; + // Also set layer_is_swa: layers [0..n_layer-2] are SWA, last is full + out.layer_is_swa.assign((size_t)out.n_layer, true); + out.layer_is_swa[(size_t)(out.n_layer - 1)] = false; + } + + const int64_t HIDDEN = out.n_embd; + const int64_t Q_DIM = (int64_t)out.n_head * out.head_dim; + const int64_t KV_DIM = (int64_t)out.n_head_kv * out.head_dim; + const 
int64_t INTER = out.n_ff; + const int64_t HD = out.head_dim; + const int64_t FC_IN = (int64_t)GEMMA4_DRAFT_N_TARGET_LAYERS * out.target_hidden; + // VOCAB not used here; tok_embd is injected at runtime from the target model. + + // ── 5. Allocate ggml context ───────────────────────────────────── + // tensors: fc, hidden_norm, out_norm = 3 top-level (tok_embd injected at runtime) + // 11 tensors × 5 layers = 55 + // total = 58 + headroom + const int n_tensors = 3 + 11 * out.n_layer + 8; + ggml_init_params ip{}; + ip.mem_size = (size_t)n_tensors * ggml_tensor_overhead(); + ip.mem_buffer = nullptr; + ip.no_alloc = true; + out.ctx = ggml_init(ip); + if (!out.ctx) { set_last_error("gemma4 draft: ggml_init failed"); return false; } + out.backend = backend; + out.layers.assign((size_t)out.n_layer, GemmaDraftLayer{}); + + const ggml_type NORM_GT = GGML_TYPE_F32; + const bool nbf16 = g_cuda_has_native_bf16(); + const ggml_type PROJ_GT = nbf16 ? GGML_TYPE_COUNT : GGML_TYPE_F16; + + // ── 6. Create named tensors ────────────────────────────────────── + out.fc = galloc_tensor(out.ctx, st, "fc.weight", {HIDDEN, FC_IN}, PROJ_GT); + out.hidden_norm = galloc_tensor(out.ctx, st, "hidden_norm.weight", {HIDDEN}, NORM_GT); + out.out_norm = galloc_tensor(out.ctx, st, "norm.weight", {HIDDEN}, NORM_GT); + // tok_embd is not present in the draft safetensors; the draft shares + // the target model's token embedding which is injected at runtime. 
+ out.tok_embd = nullptr; + if (!out.fc || !out.hidden_norm || !out.out_norm) return false; + + for (int il = 0; il < out.n_layer; il++) { + char pfx[64]; + std::snprintf(pfx, sizeof(pfx), "layers.%d.", il); + std::string p = pfx; + GemmaDraftLayer & L = out.layers[(size_t)il]; + + L.attn_norm = galloc_tensor(out.ctx, st, p + "input_layernorm.weight", {HIDDEN}, NORM_GT); + L.ffn_norm = galloc_tensor(out.ctx, st, p + "post_attention_layernorm.weight", {HIDDEN}, NORM_GT); + L.wq = galloc_tensor(out.ctx, st, p + "self_attn.q_proj.weight", {Q_DIM, HIDDEN}, PROJ_GT); + L.wk = galloc_tensor(out.ctx, st, p + "self_attn.k_proj.weight", {KV_DIM, HIDDEN}, PROJ_GT); + L.wv = galloc_tensor(out.ctx, st, p + "self_attn.v_proj.weight", {KV_DIM, HIDDEN}, PROJ_GT); + L.wo = galloc_tensor(out.ctx, st, p + "self_attn.o_proj.weight", {HIDDEN, Q_DIM}, PROJ_GT); + L.q_norm = galloc_tensor(out.ctx, st, p + "self_attn.q_norm.weight", {HD}, NORM_GT); + L.k_norm = galloc_tensor(out.ctx, st, p + "self_attn.k_norm.weight", {HD}, NORM_GT); + L.w_gate = galloc_tensor(out.ctx, st, p + "mlp.gate_proj.weight", {INTER, HIDDEN}, PROJ_GT); + L.w_up = galloc_tensor(out.ctx, st, p + "mlp.up_proj.weight", {INTER, HIDDEN}, PROJ_GT); + L.w_down = galloc_tensor(out.ctx, st, p + "mlp.down_proj.weight", {HIDDEN, INTER}, PROJ_GT); + + if (!L.attn_norm || !L.ffn_norm || !L.wq || !L.wk || !L.wv || !L.wo || + !L.q_norm || !L.k_norm || !L.w_gate || !L.w_up || !L.w_down) { + return false; + } + } + + // ── 7. 
Allocate backend buffer and upload bytes ────────────────── + out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend); + if (!out.buf) { + set_last_error("gemma4 draft: ggml_backend_alloc_ctx_tensors failed"); + return false; + } + + std::vector scratch_f32; + std::vector scratch_f16; + + for (ggml_tensor * t = ggml_get_first_tensor(out.ctx); t != nullptr; + t = ggml_get_next_tensor(out.ctx, t)) + { + const char * name = ggml_get_name(t); + auto it = st.find(name); + if (it == st.end()) { + set_last_error(std::string("gemma4 draft post-alloc: '") + + name + "' vanished from header"); + return false; + } + const GStEntry & e = it->second; + if (e.data_end > (uint64_t)blob_size) { + set_last_error(std::string("gemma4 draft: data offset out of bounds for '") + + name + "'"); + return false; + } + const size_t src_bytes = (size_t)(e.data_end - e.data_start); + const size_t dst_bytes = ggml_nbytes(t); + const bool same = (t->type == gst_dtype_to_ggml(e.dtype)); + + if (same) { + if (src_bytes != dst_bytes) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "gemma4 draft: byte mismatch for '%s': blob=%zu ggml=%zu", + name, src_bytes, dst_bytes); + set_last_error(buf); + return false; + } + ggml_backend_tensor_set(t, blob + e.data_start, 0, dst_bytes); + } else if (e.dtype == "BF16" && t->type == GGML_TYPE_F32) { + const size_t n = ggml_nelements(t); + if (src_bytes != n * 2 || dst_bytes != n * 4) { + set_last_error(std::string("gemma4 draft: BF16->F32 size mismatch for '") + name + "'"); + return false; + } + scratch_f32.resize(n); + g_bf16_to_f32((const uint16_t *)(blob + e.data_start), + scratch_f32.data(), n); + ggml_backend_tensor_set(t, scratch_f32.data(), 0, dst_bytes); + } else if (e.dtype == "BF16" && t->type == GGML_TYPE_F16) { + const size_t n = ggml_nelements(t); + if (src_bytes != n * 2 || dst_bytes != n * 2) { + set_last_error(std::string("gemma4 draft: BF16->F16 size mismatch for '") + name + "'"); + return false; + } + scratch_f16.resize(n); + 
g_bf16_to_f16((const uint16_t *)(blob + e.data_start), + scratch_f16.data(), n); + ggml_backend_tensor_set(t, scratch_f16.data(), 0, dst_bytes); + } else { + set_last_error(std::string("gemma4 draft: unsupported dtype conversion for '") + + name + "': " + e.dtype + " -> " + ggml_type_name(t->type)); + return false; + } + } + + std::fprintf(stderr, + "[gemma4 draft] loaded: n_layer=%d n_head=%d n_kv=%d " + "n_embd=%d n_ff=%d head_dim=%d target_hidden=%d vocab=%d\n", + out.n_layer, out.n_head, out.n_head_kv, + out.n_embd, out.n_ff, out.head_dim, out.target_hidden, out.n_vocab); + std::fflush(stderr); + + return true; +} + +bool load_gemma4_draft_gguf(const std::string & path, + ggml_backend_t backend, + GemmaDraftWeights & out) +{ + // ── 1. Parse metadata + create ggml_context with tensor descriptors ── + ggml_context * meta_ctx = nullptr; + gguf_init_params gip{}; + gip.no_alloc = true; + gip.ctx = &meta_ctx; + gguf_context * gctx = gguf_init_from_file(path.c_str(), gip); + if (!gctx) { + set_last_error("gguf_init_from_file failed: " + path); + return false; + } + + // Validate arch + { + int64_t arch_id = gguf_find_key(gctx, "general.architecture"); + if (arch_id < 0) { + set_last_error("gemma4 draft GGUF: missing general.architecture"); + gguf_free(gctx); + return false; + } + const char * arch = gguf_get_val_str(gctx, arch_id); + if (std::string(arch) != "gemma4-dflash-draft") { + set_last_error(std::string("gemma4 draft GGUF: unexpected arch: ") + arch + + " (expected gemma4-dflash-draft)"); + gguf_free(gctx); + return false; + } + } + + // Read dimensions from GGUF metadata + int64_t arch_id2 = gguf_find_key(gctx, "general.architecture"); + const char * A = gguf_get_val_str(gctx, arch_id2); + char key[256]; + + auto read_u32 = [&](const char * suffix, uint32_t fallback) -> uint32_t { + std::snprintf(key, sizeof(key), "%s.%s", A, suffix); + return get_u32_or(gctx, key, fallback); + }; + auto read_f32 = [&](const char * suffix, float fallback) -> float { + 
std::snprintf(key, sizeof(key), "%s.%s", A, suffix); + return get_f32_or(gctx, key, fallback); + }; + + const uint32_t n_embd = read_u32("embedding_length", 0); + const uint32_t n_layer = read_u32("block_count", 0); + const uint32_t n_ff = read_u32("feed_forward_length", 0); + const uint32_t n_head = read_u32("attention.head_count", 0); + const uint32_t n_head_kv = read_u32("attention.head_count_kv", 0); + const uint32_t head_dim = read_u32("attention.key_length", 0); + const uint32_t block_sz = read_u32("dflash.block_size", 0); + const uint32_t n_tgt_lay = read_u32("dflash.n_target_layers", 0); + const uint32_t target_hid = read_u32("dflash.target_hidden", 0); + const uint32_t mask_tok_id = read_u32("dflash.mask_token_id", GEMMA4_31B_DRAFT_MASK_TOKEN_ID); + const uint32_t sliding_win = read_u32("dflash.sliding_window", 2048); + const float logit_cap = read_f32("dflash.logit_softcap", GEMMA4_LOGIT_SOFTCAP); + const float rope_theta = read_f32("rope.freq_base", GEMMA4_ROPE_THETA); + + if (n_embd == 0 || n_layer == 0 || n_ff == 0 || n_head == 0 || + n_head_kv == 0 || head_dim == 0) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "gemma4 draft GGUF: missing hparams: n_embd=%u n_layer=%u n_ff=%u " + "n_head=%u n_head_kv=%u head_dim=%u", + n_embd, n_layer, n_ff, n_head, n_head_kv, head_dim); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + // Validate block_size and n_target_layers match compiled constants + if (block_sz != (uint32_t)GEMMA4_DRAFT_BLOCK_SIZE || + n_tgt_lay != (uint32_t)GEMMA4_DRAFT_N_TARGET_LAYERS) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "gemma4 draft GGUF: dflash.block_size=%u (expected %d), " + "dflash.n_target_layers=%u (expected %d)", + block_sz, GEMMA4_DRAFT_BLOCK_SIZE, + n_tgt_lay, GEMMA4_DRAFT_N_TARGET_LAYERS); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + // Sanity-check upper bounds + constexpr uint32_t MAX_LAYERS = 1024; + constexpr uint32_t MAX_EMBD = 1u << 17; + constexpr uint32_t 
MAX_FF = 1u << 19; + constexpr uint32_t MAX_HEADS = 1024; + constexpr uint32_t MAX_HEADDIM = 1024; + if (n_layer > MAX_LAYERS || n_embd > MAX_EMBD || + n_ff > MAX_FF || n_head > MAX_HEADS || + n_head_kv > MAX_HEADS || head_dim > MAX_HEADDIM || + n_head_kv > n_head || (n_head % n_head_kv) != 0) { + char buf[320]; + std::snprintf(buf, sizeof(buf), + "gemma4 draft GGUF: hparams out of range: n_embd=%u n_layer=%u n_ff=%u " + "n_head=%u n_head_kv=%u head_dim=%u", + n_embd, n_layer, n_ff, n_head, n_head_kv, head_dim); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + // ── 2. Populate GemmaDraftWeights scalars ──────────────────────────── + out.ctx = meta_ctx; + out.backend = backend; + out.n_layer = (int)n_layer; + out.n_head = (int)n_head; + out.n_head_kv = (int)n_head_kv; + out.head_dim = (int)head_dim; + out.n_embd = (int)n_embd; + out.n_ff = (int)n_ff; + out.block_size = (int)block_sz; + out.n_target_layers = (int)n_tgt_lay; + out.target_hidden = (int)target_hid; + out.mask_token_id = (int)mask_tok_id; + out.sliding_window = (int)sliding_win; + out.logit_softcap = logit_cap; + out.rope_theta = rope_theta; + + // layers [0..n_layer-2] are SWA, last layer is full attention + out.layer_is_swa.assign((size_t)n_layer, true); + out.layer_is_swa[(size_t)(n_layer - 1)] = false; + + out.layers.assign((size_t)n_layer, GemmaDraftLayer{}); + + // tok_embd is injected at runtime from the target model (same as safetensors path) + out.tok_embd = nullptr; + + // ── 3. 
Wire tensor pointers ────────────────────────────────────────── + auto g = [&](const char * name) -> ggml_tensor * { + return ggml_get_tensor(meta_ctx, name); + }; + + out.fc = g("dflash.fc.weight"); + out.hidden_norm = g("dflash.hidden_norm.weight"); + out.out_norm = g("output_norm.weight"); + + if (!out.fc || !out.hidden_norm || !out.out_norm) { + set_last_error("gemma4 draft GGUF: missing top-level tensors " + "(dflash.fc.weight / dflash.hidden_norm.weight / output_norm.weight)"); + gguf_free(gctx); + return false; + } + + for (int il = 0; il < out.n_layer; il++) { + char name[128]; + auto fnd = [&](const char * suffix) -> ggml_tensor * { + std::snprintf(name, sizeof(name), "blk.%d.%s", il, suffix); + return ggml_get_tensor(meta_ctx, name); + }; + GemmaDraftLayer & L = out.layers[il]; + L.attn_norm = fnd("attn_norm.weight"); + L.ffn_norm = fnd("ffn_norm.weight"); + L.wq = fnd("attn_q.weight"); + L.wk = fnd("attn_k.weight"); + L.wv = fnd("attn_v.weight"); + L.wo = fnd("attn_output.weight"); + L.q_norm = fnd("attn_q_norm.weight"); + L.k_norm = fnd("attn_k_norm.weight"); + L.w_gate = fnd("ffn_gate.weight"); + L.w_up = fnd("ffn_up.weight"); + L.w_down = fnd("ffn_down.weight"); + if (!L.attn_norm || !L.ffn_norm || !L.wq || !L.wk || !L.wv || !L.wo || + !L.q_norm || !L.k_norm || !L.w_gate || !L.w_up || !L.w_down) { + char b[128]; + std::snprintf(b, sizeof(b), + "gemma4 draft GGUF: layer %d missing tensors", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + } + + // ── 4. Allocate backend buffer for all tensors ─────────────────────── + out.buf = ggml_backend_alloc_ctx_tensors(meta_ctx, backend); + if (!out.buf) { + set_last_error("gemma4 draft GGUF: ggml_backend_alloc_ctx_tensors failed"); + gguf_free(gctx); + return false; + } + + // ── 5. 
mmap file and copy tensor bytes to backend ──────────────────── + std::string err; + GMmap mm; + if (!mm.open_ro(path, err)) { set_last_error(err); gguf_free(gctx); return false; } + const size_t data_start = gguf_get_data_offset(gctx); + const int64_t n_tensors = gguf_get_n_tensors(gctx); + + size_t total = 0; + for (int64_t tid = 0; tid < n_tensors; tid++) { + const char * tname = gguf_get_tensor_name(gctx, tid); + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t) continue; + const size_t off = data_start + gguf_get_tensor_offset(gctx, tid); + const size_t sz = gguf_get_tensor_size(gctx, tid); + if (off + sz > mm.len) { + set_last_error(std::string("gemma4 draft GGUF: tensor '") + + tname + "' overflows file"); + gguf_free(gctx); + return false; + } + ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); + total += sz; + } + + gguf_free(gctx); + + std::fprintf(stderr, + "[gemma4 draft GGUF] loaded: n_layer=%d n_head=%d n_kv=%d " + "n_embd=%d n_ff=%d head_dim=%d target_hidden=%d (%.2f GiB on GPU)\n", + out.n_layer, out.n_head, out.n_head_kv, + out.n_embd, out.n_ff, out.head_dim, out.target_hidden, + total / (1024.0 * 1024.0 * 1024.0)); + std::fflush(stderr); + + return true; +} + +void free_gemma4_draft_weights(GemmaDraftWeights & w) { + if (w.buf) { ggml_backend_buffer_free(w.buf); w.buf = nullptr; } + if (w.ctx) { ggml_free(w.ctx); w.ctx = nullptr; } + w.layers.clear(); + w.layer_is_swa.clear(); + w.fc = nullptr; + w.hidden_norm = nullptr; + w.out_norm = nullptr; + w.tok_embd = nullptr; +} + +} // namespace dflash27b diff --git a/dflash/src/gemma4_mtp_graph.cpp b/dflash/src/gemma4_mtp_graph.cpp new file mode 100644 index 000000000..d6e95724b --- /dev/null +++ b/dflash/src/gemma4_mtp_graph.cpp @@ -0,0 +1,760 @@ +// Single-step MTP (Multi-Token Prediction) graph builder for Gemma4. 
+// +// Builds a ggml compute graph that, given one token id and the target's last +// full-attention hidden state h_prev, produces: +// - out_logits : F32 [n_vocab, 1] full vocabulary row +// - out_h_post : F32 [n_embd_backbone, 1] next h_prev for the γ chain +// - out_argmax : I32 [1] greedy draft token (4-byte host pull per step) +// +// Architecture (mirrors atomicbot's gemma4-assistant.cpp lines 28-256): +// 1. Token embedding from target.tok_embd, scaled by sqrt(n_embd_backbone). +// 2. Concat [tok_emb, h_prev] → pre_projection → [n_embd, 1]. +// 3. 4 transformer blocks (cross-attention into target KV): +// RMSNorm → Q proj → Q-norm → RoPE → cross-attn (reads donor K/V) → +// wo → post_attn_norm → residual → ffn_norm → GELU FFN → post_ffn_norm → +// residual → optional out_scale. +// 4. output_norm → post_projection → h_post [n_embd_backbone, 1]. +// 5. LM head: dense (tied tok_embd) or centroid-routed for ordered embeddings. +// 6. In-graph argmax. +// +// Cross-attention contract: +// - Each MTP layer reads K/V from w.layers[il].donor_target_layer in the +// target KV cache (resolved at load time as the LAST target layer whose +// SWA type matches this MTP layer). +// - V is ALWAYS read from the cache (use_k_as_v=false): per HF Gemma4 the +// V slot stores rms-normed non-rotated vectors, distinct from post-RoPE K. +// - The K/V view covers [0, attn_pos) = all committed target positions. +// attn_pos is passed in via the in_pos tensor (caller sets it to +// cache.cur_pos before each step). +// - KV mask is not needed: all committed positions ≤ attn_pos are uniformly +// admitted (step position > attn_pos, so every cell is in the causal cone). +// We pass nullptr to ggml_flash_attn_ext for the mask argument. 
+// +// Centroid LM head (use_ordered_embeddings=true, always active for Dense 31B): +// cent_logits = mul_mat(mtp_centroids, h_inner) +// top_k_ids = ggml_top_k(cent_logits, centroid_top_k) +// sel_ids = get_rows(token_ordering_view, top_k_ids) +// sel_logits = mul_mat(get_rows(tok_embd, flat_sel_ids), h_inner) +// full_row = scatter sel_logits into [-1e30 fill] via ggml_set_rows +// +// When use_ordered_embeddings is false (fallback, unlikely for 31B assistant): +// out_logits = mul_mat(tok_embd, h_inner) — dense tied head. + +#include "internal.h" + +#include +#include +#include +#include + +namespace dflash27b { + +static constexpr float MTP_RMS_EPS = GEMMA4_RMS_EPS; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +static ggml_tensor * mtp_rms_norm_mul(ggml_context * ctx, + ggml_tensor * x, + ggml_tensor * weight) { + ggml_tensor * n = ggml_rms_norm(ctx, x, MTP_RMS_EPS); + return ggml_mul(ctx, n, weight); +} + +// GELU FFN with SwiGLU-like gate: w_down @ (gelu(w_gate @ x) * (w_up @ x)) +static ggml_tensor * mtp_gelu_ffn(ggml_context * ctx, + ggml_tensor * cur, + const MtpLayerWeights & L) { + ggml_tensor * gate = ggml_mul_mat(ctx, L.ffn_gate, cur); + ggml_tensor * up = ggml_mul_mat(ctx, L.ffn_up, cur); + ggml_tensor * gu = ggml_geglu_split(ctx, gate, up); + return ggml_mul_mat(ctx, L.ffn_down, gu); +} + +// ─── Public graph builder ───────────────────────────────────────────────────── + +bool build_mtp_step_graph(const MtpDrafterWeights & w, + const GemmaTargetCache & target_cache, + const GemmaTargetWeights & target, + MtpStepGraph & out, + int attn_pos) { + // ── Validate prerequisites ──────────────────────────────────────────────── + if (!w.pre_projection || !w.post_projection || !w.output_norm) { + set_last_error("build_mtp_step_graph: MtpDrafterWeights missing pre/post projection or output_norm"); + return false; + } + if ((int)w.layers.size() == 0) { + set_last_error("build_mtp_step_graph: no MTP layers"); + return 
false; + } + if (!target.tok_embd) { + set_last_error("build_mtp_step_graph: target.tok_embd is null"); + return false; + } + if (w.n_embd == 0 || w.n_embd_backbone == 0) { + set_last_error("build_mtp_step_graph: n_embd or n_embd_backbone is 0"); + return false; + } + + const int n_embd_backbone = w.n_embd_backbone; + const int n_layer = (int)w.layers.size(); + const int n_vocab = (int)target.tok_embd->ne[1]; + + // Validate layer 0 donor KV slot (each layer validates its own in the loop). + { + const int32_t donor_il_0 = w.layers[0].donor_target_layer; + if (donor_il_0 < 0 || donor_il_0 >= (int)target_cache.layer_to_kv_idx.size()) { + set_last_error("build_mtp_step_graph: invalid donor_target_layer for MTP layer 0"); + return false; + } + const int kv_slot_0 = target_cache.layer_to_kv_idx[donor_il_0]; + const int kv_read_slot_0 = (kv_slot_0 >= 0) ? kv_slot_0 + : ((donor_il_0 < (int)target_cache.layer_to_donor_kv.size()) + ? target_cache.layer_to_donor_kv[donor_il_0] : -1); + if (kv_read_slot_0 < 0 || kv_read_slot_0 >= (int)target_cache.attn_k.size()) { + set_last_error("build_mtp_step_graph: donor KV slot unresolvable for MTP layer 0"); + return false; + } + } + + // ── Allocate ggml context ───────────────────────────────────────────────── + // Conservative tensor overhead: 3 inputs + ~80 ops per layer + outputs. + // Extras vs original: K/V casts, GQA block-broadcast views/materialization, + // Q permute/cont, explicit KQ mask, Vt materialization. 
+ const size_t n_tensors_est = (size_t)(3 + n_layer * 80 + 20); + ggml_init_params ip{}; + ip.mem_size = n_tensors_est * ggml_tensor_overhead() + 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + set_last_error("build_mtp_step_graph: ggml_init failed"); + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false); + + // ── Input tensors ───────────────────────────────────────────────────────── + ggml_tensor * in_tok = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + ggml_set_input(in_tok); + ggml_set_name(in_tok, "mtp_in_tok"); + + // in_tok_embd: pre-dequantised token embedding supplied by caller. + // Caller must call target.embedder.embed(&tok, 1, buf) and tensor_set before compute. + // This avoids ggml_get_rows on a k-quant (Q4_K) source which the CUDA backend + // does not support in this llama.cpp revision. + ggml_tensor * in_tok_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_backbone, 1); + ggml_set_input(in_tok_embd); + ggml_set_name(in_tok_embd, "mtp_in_tok_embd"); + + ggml_tensor * in_h_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_backbone, 1); + ggml_set_input(in_h_prev); + ggml_set_name(in_h_prev, "mtp_in_h_prev"); + + // in_pos: absolute target position for this draft step's RoPE. + // Caller sets this to (cache.cur_pos + step_offset). + ggml_tensor * in_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + ggml_set_input(in_pos); + ggml_set_name(in_pos, "mtp_in_pos"); + + // ── 1. Token embedding from target (shared weight) ──────────────────────── + // Embedding is passed in pre-dequantised by the caller via in_tok_embd. + // This bypasses ggml_get_rows on a potentially quantised target.tok_embd + // (CUDA backend in this revision only supports F16/F32/Q8_0 for get_rows; + // Q4_K targets would abort at compute time). 
+ ggml_tensor * tok_e = in_tok_embd; + ggml_set_name(tok_e, "mtp_tok_embd"); + + // Gemma4 scales token embeddings by sqrt(n_embd_backbone) at input pipeline + const float tok_scale = std::sqrt((float)n_embd_backbone); + tok_e = ggml_scale(ctx, tok_e, tok_scale); + ggml_set_name(tok_e, "mtp_tok_embd_scaled"); + + // ── 2. Concat [tok_e, h_prev] and project to n_embd ────────────────────── + // Both are [n_embd_backbone, 1]; concat on axis 0 → [2*n_embd_backbone, 1] + ggml_tensor * inp_cat = ggml_concat(ctx, tok_e, in_h_prev, 0); + ggml_set_name(inp_cat, "mtp_concat"); + + // pre_projection: [2*n_embd_backbone, n_embd] (ggml ne[0]=2*n_bb, ne[1]=n_embd) + // mul_mat(A, x): A->ne[0] must == x->ne[0]; output ne[0]=A->ne[1] + ggml_tensor * inpL = ggml_mul_mat(ctx, w.pre_projection, inp_cat); + ggml_set_name(inpL, "mtp_pre_proj_out"); + + // ── 3. Transformer blocks ───────────────────────────────────────────────── + // Single FA mask shared across every layer that needs one. First need-mask + // layer creates the input tensor; later layers reuse it. We require every + // need-mask layer to want the same (width, kv_seq_len) — short contexts + // satisfy this because SWA cap >= attn_pos. Divergence in long contexts + // trips an error and the builder must be extended to per-layer masks. + ggml_tensor * shared_fa_mask = nullptr; + int64_t shared_fa_mask_width = 0; + int64_t shared_fa_mask_kv_seq_len = 0; + for (int il = 0; il < n_layer; ++il) { + const MtpLayerWeights & L = w.layers[il]; + const bool is_swa = L.is_swa; + + // Resolve donor KV slot + const int32_t donor_il = L.donor_target_layer; + if (donor_il < 0 || donor_il >= (int)target_cache.layer_to_kv_idx.size()) { + set_last_error("build_mtp_step_graph: invalid donor_target_layer"); + ggml_free(ctx); + return false; + } + const int kv_slot = target_cache.layer_to_kv_idx[donor_il]; + const int kv_read_slot = (kv_slot >= 0) ? kv_slot + : ((donor_il < (int)target_cache.layer_to_donor_kv.size()) + ? 
target_cache.layer_to_donor_kv[donor_il] : -1); + if (kv_read_slot < 0 || kv_read_slot >= (int)target_cache.attn_k.size()) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "build_mtp_step_graph: donor KV slot unresolvable for MTP layer %d", il); + set_last_error(buf); + ggml_free(ctx); + return false; + } + ggml_tensor * cache_k = target_cache.attn_k[kv_read_slot]; + ggml_tensor * cache_v = target_cache.attn_v[kv_read_slot]; + if (!cache_k || !cache_v) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "build_mtp_step_graph: null KV cache for MTP layer %d donor slot %d", il, kv_read_slot); + set_last_error(buf); + ggml_free(ctx); + return false; + } + + // KV cache layout: [head_dim_kv, max_ctx, n_head_kv] + const int64_t head_dim_kv = cache_k->ne[0]; + const int64_t n_head_kv = cache_k->ne[2]; + // Q dimensions: derive from wq output size and attn_q_norm shape. + // wq: [n_embd, q_out_dim] where q_out_dim = n_head_norm * head_dim_norm + // attn_q_norm:[head_dim_norm] per-head norm weight from the MTP model's own hparams + // + // head_dim_norm may differ from head_dim_kv (the target KV cache head_dim). + // Dense 31B example: MTP trained with head_dim_norm=256, target K stored at 128. + // For flash_attn Q @ K^T to succeed, Q.ne[0] must equal K.ne[0]. + // Fix: norm and RoPE run at head_dim_norm; before FA, reshape Q to [head_dim_kv, ...] + // so the dot-product dimension matches K. q_out_dim is preserved throughout. + const int64_t q_out_dim = L.wq->ne[1]; + const int64_t head_dim_norm = L.attn_q_norm->ne[0]; // MTP model's per-head norm dim + const int64_t n_head_norm = q_out_dim / head_dim_norm; + // FA head_dim must match target K; use head_dim_kv (from cache_k->ne[0]). 
+ const int64_t head_dim_fa = head_dim_kv; + const int64_t n_head_fa = q_out_dim / head_dim_fa; + + // a) RMSNorm + ggml_tensor * cur = mtp_rms_norm_mul(ctx, inpL, L.attn_norm); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_attn_norm_%d", il); + ggml_set_name(cur, name); + } + + // b) Q projection: [n_embd, 1] → [q_out_dim, 1], reshape to [head_dim_norm, n_head_norm, 1] + ggml_tensor * Qcur = ggml_mul_mat(ctx, L.wq, cur); + Qcur = ggml_reshape_3d(ctx, Qcur, head_dim_norm, n_head_norm, 1); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Qcur_%d", il); + ggml_set_name(Qcur, name); + } + + // c) Q-norm: per-head RMSNorm at head_dim_norm (attn_q_norm shape: [head_dim_norm]) + Qcur = mtp_rms_norm_mul(ctx, Qcur, L.attn_q_norm); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Qcur_normed_%d", il); + ggml_set_name(Qcur, name); + } + + // d) RoPE on Q at head_dim_norm + // Use the target's rope_theta (SWA layers) or the full-attn layer's rope_freqs. + // For MTP cross-attention: SWA layers use rope_theta_swa, full layers use rope_theta + // (with per-layer freq_factors from the donor layer). + // We use the target's SWA/full rope parameters mirroring atomicbot. + ggml_tensor * rope_freq_factors = nullptr; + float rope_theta_val = target.rope_theta_swa; + if (!is_swa) { + rope_theta_val = target.rope_theta; + // For full-attention MTP layers: prefer assistant's OWN rope_freqs + // (top-level "rope_freqs.weight" in assistant GGUF — the assistant + // was trained with its own per-dim freq factors). Fall back to + // target's per-layer rope_freqs only if the assistant didn't ship + // one (legacy GGUFs). 
+ if (w.rope_freqs) { + rope_freq_factors = w.rope_freqs; + } else if (donor_il >= 0 && donor_il < (int)target.layers.size()) { + rope_freq_factors = target.layers[donor_il].rope_freqs; + } + } + Qcur = ggml_rope_ext(ctx, Qcur, in_pos, + rope_freq_factors, + (int)head_dim_norm, GGML_ROPE_TYPE_NEOX, + /*n_ctx_orig=*/0, + rope_theta_val, /*freq_scale=*/1.0f, + /*ext_factor=*/0.0f, /*attn_factor=*/1.0f, + /*beta_fast=*/0.0f, /*beta_slow=*/0.0f); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Qcur_pos_%d", il); + ggml_set_name(Qcur, name); + } + + // e) Cross-attention (manual: Q@K^T → scale → softmax → @V) + // Sidesteps ggml_flash_attn_ext CUDA kernel shape restrictions for MTP. + // + // Make Qcur contiguous before reshape — ggml_rope_ext returns a non-contiguous + // view; ggml_reshape_3d requires a contiguous source. + Qcur = ggml_cont(ctx, Qcur); + // Reshape Q from [head_dim_norm, n_head_norm, 1] to [head_dim_fa, n_head_fa, 1] + // so Q.ne[0] == K.ne[0] == head_dim_kv. + // When head_dim_norm == head_dim_fa this is a no-op reshape. + Qcur = ggml_reshape_3d(ctx, Qcur, head_dim_fa, n_head_fa, 1); + + // K/V view from the target KV cache. + // Full-attention donors read [0, attn_pos). SWA donors use a ring buffer: + // slice only the keys admitted by atomicbot's STANDARD SWA mask for an MTP + // query at pos=attn_pos, then the remaining mask is an all-zero bias. + int64_t kv_seq_len = (int64_t)attn_pos; + int64_t kv_start_slot = 0; + bool kv_wraps = false; + int64_t kv_first_len = 0; + if (is_swa) { + const int64_t ring_len = std::min(cache_k->ne[1], cache_v->ne[1]); + const int64_t swa_prev = target.swa_window > 0 + ? 
std::max((int64_t)target.swa_window - 1, 0) : ring_len; + kv_seq_len = std::min((int64_t)attn_pos, std::min(swa_prev, ring_len)); + if (kv_seq_len > 0) { + const int64_t first_abs = (int64_t)attn_pos - kv_seq_len; + kv_start_slot = first_abs % ring_len; + const int64_t kv_end_slot = kv_start_slot + kv_seq_len; + kv_wraps = kv_end_slot > ring_len; + kv_first_len = kv_wraps ? (ring_len - kv_start_slot) : kv_seq_len; + } + } else if ((int64_t)attn_pos > cache_k->ne[1] || (int64_t)attn_pos > cache_v->ne[1]) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "build_mtp_step_graph: attn_pos %d exceeds donor KV cache length (K=%lld V=%lld) for MTP layer %d", + attn_pos, (long long)cache_k->ne[1], (long long)cache_v->ne[1], il); + set_last_error(buf); + ggml_free(ctx); + return false; + } + // Pad to 1 minimum to avoid zero-size tensors when attn_pos==0. + const int64_t kv_view_len = std::max(kv_seq_len, (int64_t)1); + + // For head_dim==512 with any K type, ggml_flash_attn_ext requires + // K->ne[1] % 256 == 0 for gqa_opt_applies to be true (and returns + // BEST_FATTN_KERNEL_NONE otherwise). Pad the K/V view to the next 256 + // multiple; the padding rows contain stale cache data but are masked + // out by the caller-provided fa_mask with -inf bias on those positions. + // This only applies to the non-wrap path (head_dim=512 layers are full-attn + // with monotone KV so no wrap occurs). + // FATTN_KQ_STRIDE alignment: TQ3_0 K is stored in blocks along ne[1] and + // the FA kernels (chunked + vec) iterate KV in 256-position groups; an + // unaligned ne[1] reads past the valid window into stale cache cells. We + // pad the view to 256 and exclude the tail with a -inf mask. + // This matches gemma4_target_graph.cpp:352-355's `need_256_pad` policy. 
+ const bool kv_cache_is_tq3 = + (cache_k->type == GGML_TYPE_TQ3_0 || cache_v->type == GGML_TYPE_TQ3_0); + if (kv_wraps && + (cache_k->type == GGML_TYPE_TQ3_0 || cache_v->type == GGML_TYPE_TQ3_0)) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "build_mtp_step_graph: refusing wrapped TQ3 donor attention for MTP layer %d donor=%d; force donor KV to Q8_0", + il, donor_il); + set_last_error(buf); + ggml_free(ctx); + return false; + } + const bool needs_kv_pad = (kv_cache_is_tq3 || head_dim_fa >= 512) + && !kv_wraps && (kv_view_len % 256 != 0); + const int64_t kv_view_len_padded = needs_kv_pad + ? ((kv_view_len + 255) / 256) * 256 + : kv_view_len; + + auto view_kv = [&](ggml_tensor * cache, int64_t start, int64_t len) { + return ggml_view_3d(ctx, cache, + head_dim_kv, len, n_head_kv, + cache->nb[1], cache->nb[2], + cache->nb[1] * (size_t)start); + }; + + ggml_tensor * Kview = nullptr; + ggml_tensor * Vview = nullptr; + if (kv_wraps) { + // ggml_concat on CUDA requires F32 src. Direct TQ3_0→F32 is unsupported + // by cpy.cu (it only does TQ3_0→F16 and F16↔F32). So go via F16 first + // when the cache is TQ3, else cast directly. + auto to_f32 = [&](ggml_tensor * v) { + if (v->type == GGML_TYPE_TQ3_0) { + v = ggml_cast(ctx, v, GGML_TYPE_F16); + } + if (v->type != GGML_TYPE_F32) { + v = ggml_cast(ctx, v, GGML_TYPE_F32); + } + return v; + }; + const int64_t kv_second_len = kv_view_len - kv_first_len; + ggml_tensor * k1 = to_f32(view_kv(cache_k, kv_start_slot, kv_first_len)); + ggml_tensor * k2 = to_f32(view_kv(cache_k, 0, kv_second_len)); + ggml_tensor * v1 = to_f32(view_kv(cache_v, kv_start_slot, kv_first_len)); + ggml_tensor * v2 = to_f32(view_kv(cache_v, 0, kv_second_len)); + Kview = ggml_concat(ctx, k1, k2, 1); + Vview = ggml_concat(ctx, v1, v2, 1); + } else { + // Use padded length for the K/V view when required. 
+ Kview = view_kv(cache_k, kv_start_slot, kv_view_len_padded); + Vview = view_kv(cache_v, kv_start_slot, kv_view_len_padded); + } + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Kview_%d", il); + ggml_set_name(Kview, name); + std::snprintf(name, sizeof(name), "mtp_Vview_%d", il); + ggml_set_name(Vview, name); + } + + // Detect if K/V is in TQ3_0 (FWHT-domain). Graph-level FWHT keeps the + // FA backends on a single contract: pre-rotate Q for TQ3 K, inverse- + // rotate output for TQ3 V, and pass the native K/V views into FA. + const bool k_is_tq3 = (Kview->type == GGML_TYPE_TQ3_0); + const bool v_is_tq3 = (Vview->type == GGML_TYPE_TQ3_0); + const bool kv_is_tq3 = k_is_tq3 || v_is_tq3; + + // Cross-attention via ggml_flash_attn_ext. + // + // Layout for ggml_flash_attn_ext: + // Q: [head_dim, n_tokens=1, n_head_q] + // K: [head_dim, kv_len, n_head_kv] (GQA: n_head_q % n_head_kv == 0) + // V: [head_dim, kv_len, n_head_kv] + // output: [head_dim, n_tokens=1, n_head_q] (reshaped to [q_out_dim, 1]) + // + // Benefits over manual matmul attention: + // - Handles GQA directly without broadcasting K/V. + // - Graph-level FWHT correction keeps TQ3 K/V in their native cache domain. + // + // For TQ3_0 + head_dim > 256 + n_tokens=1 (decode), the CUDA dispatch + // requires a non-null mask to select the CHUNKED kernel path. We create + // an all-zero (fully-admitted) mask in that case. + // + // Permute Q from [head_dim_fa, n_head_fa, 1] → [head_dim_fa, 1, n_head_fa] + // so it matches the FA expected layout. + // ggml_turbo_wht's CUDA kernel writes dst using src strides + // (turbo-wht.cu:20-21); non-contiguous input scatters writes and + // corrupts Q. Always make Q contiguous BEFORE rotating. 
+ ggml_tensor * Qfa = ggml_cont(ctx, ggml_permute(ctx, Qcur, 0, 2, 1, 3)); + if (k_is_tq3) { + Qfa = ggml_turbo_wht(ctx, Qfa, 0); + } + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Qfa_%d", il); + ggml_set_name(Qfa, name); + } + + // K/V for FA: pass the original Kview/Vview (TQ3_0, Q8_0, or concat-F32) + // directly to ggml_flash_attn_ext. Graph-level FWHT correction above/below + // accounts for TQ3_0 K/V without stripping the tensor type tag. + // For the wrap case (kv_wraps=true), Kview is already F32 (from to_f32 + concat). + ggml_tensor * Kfa = Kview; // original type (TQ3_0, Q8_0, or concat-F32) + ggml_tensor * Vfa = Vview; + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_Kfa_%d", il); + ggml_set_name(Kfa, name); + std::snprintf(name, sizeof(name), "mtp_Vfa_%d", il); + ggml_set_name(Vfa, name); + } + // For head_dim==512 (any K type): the MMA dispatcher requires + // gqa_opt_applies, which requires BOTH K->ne[1] % 256 == 0 AND + // mask != nullptr. Without mask, BEST_FATTN_KERNEL_NONE → abort + // even when K is properly aligned. Always provide the mask. + // We padded K/V to kv_view_len_padded above when needs_kv_pad is true; + // when not padding, mask width == kv_view_len (all positions admitted). + // The caller fills: positions [0..kv_seq_len-1] = 0.0 (admit), + // positions [kv_seq_len..mask_width-1] = -inf (exclude padding). + // + // For head_dim==256 (SWA) with TQ3_0 K (non-wrap): VEC kernel handles it + // without mask UNLESS needs_kv_pad triggers (KV unaligned); then mask is + // needed to exclude the padding tail. + // For wrap case (F32 K/V after concat): no TQ3_0 issues, no mask needed. + const bool need_mask = head_dim_fa >= 512 || needs_kv_pad; + // Log per-layer FA types on every graph build (no static gate so subsequent + // chains are visible; need_mask read from the variable computed above). 
+ std::printf("[mtp-fa-types] layer %d: Qfa=%s Kfa=%s Vfa=%s " + "head_dim_fa=%lld kv_is_tq3=%d need_mask=%d\n", + il, ggml_type_name(Qfa->type), ggml_type_name(Kfa->type), + ggml_type_name(Vfa->type), (long long)head_dim_fa, + (int)kv_is_tq3, (int)need_mask); + const int64_t fa_mask_width = (needs_kv_pad ? kv_view_len_padded : kv_view_len); + ggml_tensor * fa_mask = nullptr; + if (need_mask) { + if (shared_fa_mask == nullptr) { + shared_fa_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, fa_mask_width, 1); + ggml_set_name(shared_fa_mask, "mtp_fa_mask"); + ggml_set_input(shared_fa_mask); + shared_fa_mask_width = fa_mask_width; + shared_fa_mask_kv_seq_len = kv_view_len; + } else if (shared_fa_mask_width != fa_mask_width + || shared_fa_mask_kv_seq_len != kv_view_len) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "build_mtp_step_graph: per-layer FA masks diverge " + "(layer %d wants width=%lld kv_seq=%lld; existing %lld/%lld). " + "Long-context SWA cap mismatch — extend builder to per-layer masks.", + il, (long long)fa_mask_width, (long long)kv_view_len, + (long long)shared_fa_mask_width, (long long)shared_fa_mask_kv_seq_len); + set_last_error(buf); + ggml_free(ctx); + return false; + } + fa_mask = shared_fa_mask; + } + + // Gemma4 MTP: f_attention_scale = 1.0 (no pre-softmax scaling). + ggml_tensor * attn_out = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, fa_mask, + 1.0f, 0.0f, 0.0f); + if (v_is_tq3) { + attn_out = ggml_cont(ctx, attn_out); + attn_out = ggml_turbo_wht(ctx, attn_out, 1); + } + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_fa_out_%d", il); + ggml_set_name(attn_out, name); + } + + // FA output: [head_dim_fa, 1, n_head_fa]. Flatten to [q_out_dim, 1]. 
+ // Flatten heads: [head_dim_fa, 1, n_head_fa] → [q_out_dim, 1] + ggml_tensor * attn = ggml_cont(ctx, attn_out); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_attn_out_%d", il); + ggml_set_name(attn, name); + } + + // Reshape: [q_out_dim, 1] then output projection + // head_dim_fa * n_head_fa == q_out_dim == head_dim_norm * n_head_norm + attn = ggml_reshape_2d(ctx, attn, q_out_dim, 1); + cur = ggml_mul_mat(ctx, L.wo, attn); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_attn_proj_%d", il); + ggml_set_name(cur, name); + } + + // f) Post-attention norm + cur = mtp_rms_norm_mul(ctx, cur, L.attn_post_norm); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_attn_post_norm_%d", il); + ggml_set_name(cur, name); + } + + // g) Attention residual + ggml_tensor * attn_residual = ggml_add(ctx, cur, inpL); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_attn_residual_%d", il); + ggml_set_name(attn_residual, name); + } + + // h) FFN norm + ggml_tensor * ffn_in = mtp_rms_norm_mul(ctx, attn_residual, L.ffn_norm); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_ffn_norm_%d", il); + ggml_set_name(ffn_in, name); + } + + // i) GELU FFN + ggml_tensor * ffn_out = mtp_gelu_ffn(ctx, ffn_in, L); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_ffn_out_%d", il); + ggml_set_name(ffn_out, name); + } + + // j) Post-FFN norm + ffn_out = mtp_rms_norm_mul(ctx, ffn_out, L.ffn_post_norm); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_ffn_post_norm_%d", il); + ggml_set_name(ffn_out, name); + } + + // k) FFN residual + cur = ggml_add(ctx, ffn_out, attn_residual); + + // l) Optional per-layer output scale + if (L.out_scale) { + cur = ggml_mul(ctx, cur, L.out_scale); + { + char name[64]; std::snprintf(name, sizeof(name), "mtp_out_scaled_%d", il); + ggml_set_name(cur, name); + } + } + + inpL = cur; + } + + // ── 4. 
Output norm ──────────────────────────────────────────────────────── + ggml_tensor * h_inner = mtp_rms_norm_mul(ctx, inpL, w.output_norm); + ggml_set_name(h_inner, "mtp_result_norm"); + + // ── 5. Post-projection → h_post (next h_prev) ───────────────────────────── + // post_projection: [n_embd, n_embd_backbone] (ggml ne[0]=n_embd, ne[1]=n_embd_backbone) + ggml_tensor * h_post = ggml_mul_mat(ctx, w.post_projection, h_inner); + ggml_set_name(h_post, "mtp_post_proj_out"); + + // ── 6. LM head ──────────────────────────────────────────────────────────── + ggml_tensor * logits = nullptr; + + // Determine whether tok_embd supports ggml_get_rows on CUDA. + // This backend (custom llama.cpp fork) only supports F32/F16/BF16/Q4_0/Q4_1/ + // Q5_0/Q5_1/Q8_0/TQ3_0 for get_rows; K-quant types (Q4_K, Q5_K, Q6_K) are not. + // When tok_embd is a K-quant, the centroid sparse path can't use get_rows; + // fall back to dense mul_mat for logit computation instead. + const bool tok_embd_get_rows_ok = + (w.tok_embd && + (w.tok_embd->type == GGML_TYPE_F32 || + w.tok_embd->type == GGML_TYPE_F16 || + w.tok_embd->type == GGML_TYPE_BF16 || + w.tok_embd->type == GGML_TYPE_Q4_0 || + w.tok_embd->type == GGML_TYPE_Q4_1 || + w.tok_embd->type == GGML_TYPE_Q5_0 || + w.tok_embd->type == GGML_TYPE_Q5_1 || + w.tok_embd->type == GGML_TYPE_Q8_0)); + + if (w.use_ordered_embeddings && w.centroids && w.n_centroids > 0 && tok_embd_get_rows_ok) { + // Centroid-routed LM head (matches atomicbot lines 190-235). + // All mul_mat ops use h_inner [n_embd, 1] (MTP's own hidden space, n_embd=1024). + // The embedding source is the MTP model's own tok_embd [n_embd, n_vocab] (w.tok_embd), + // NOT the target's tok_embd (which is in backbone space and used only in step 1). 
+ if (!w.tok_embd) { + set_last_error("build_mtp_step_graph: use_ordered_embeddings=true but w.tok_embd is null (token_embd.weight missing from GGUF)"); + ggml_free(ctx); + return false; + } + + const int64_t n_c = (int64_t)w.n_centroids; + const int64_t top_k = (int64_t)w.centroid_top_k; + // Validate centroid-head shape and index invariants before any arithmetic. + GGML_ASSERT(n_vocab > 0 && "centroid LM head: n_vocab must be > 0"); + GGML_ASSERT(n_c > 0 && "centroid LM head: n_centroids must be > 0"); + GGML_ASSERT(n_vocab % n_c == 0 + && "centroid LM head: n_vocab must be divisible by n_centroids"); + GGML_ASSERT(top_k > 0 && top_k <= n_c + && "centroid LM head: top_k must be in [1, n_centroids]"); + // vsc: tokens per centroid slot + const int64_t vsc = (int64_t)n_vocab / n_c; + + // centroid_logits = mul_mat(centroids, h_inner) → [n_centroids, 1] + // centroids: [n_embd, n_centroids] (ne[0]=n_embd, ne[1]=n_centroids) + ggml_tensor * centroid_logits = ggml_mul_mat(ctx, w.centroids, h_inner); + ggml_set_name(centroid_logits, "mtp_centroid_logits"); + + // top-k centroid indices + ggml_tensor * topk_idx = ggml_top_k(ctx, centroid_logits, (int)top_k); + ggml_set_name(topk_idx, "mtp_centroid_topk_idx"); + + // View token_ordering as [vsc, n_centroids] (I32) + const size_t ordering_row_bytes = ggml_row_size(GGML_TYPE_I32, vsc); + ggml_tensor * ordering = ggml_view_2d(ctx, w.token_ordering, + vsc, n_c, ordering_row_bytes, /*offset=*/0); + ggml_set_name(ordering, "mtp_token_ordering_view"); + + // Gather candidate token ids for top-k centroids: [vsc, top_k, 1] + ggml_tensor * sel_ids = ggml_get_rows(ctx, ordering, topk_idx); + ggml_set_name(sel_ids, "mtp_selected_token_ids"); + + // Flatten to 1D for embedding lookup + const int64_t n_sel = top_k * vsc; + ggml_tensor * flat_ids = ggml_reshape_1d(ctx, sel_ids, n_sel); + ggml_set_name(flat_ids, "mtp_selected_token_ids_flat"); + + // Gather embeddings for selected tokens from MTP's own tok_embd [n_embd, n_vocab]. 
+ // get_rows selects n_sel rows → [n_embd, n_sel] + ggml_tensor * sel_emb = ggml_get_rows(ctx, w.tok_embd, flat_ids); + ggml_set_name(sel_emb, "mtp_selected_embd"); + + // Sparse logits: mul_mat(sel_emb, h_inner): + // sel_emb [n_embd, n_sel], h_inner [n_embd, 1] → [n_sel, 1] + ggml_tensor * sel_logits = ggml_mul_mat(ctx, sel_emb, h_inner); + ggml_set_name(sel_logits, "mtp_selected_logits"); + ggml_tensor * sel_logits_f32 = ggml_cast(ctx, sel_logits, GGML_TYPE_F32); + ggml_set_name(sel_logits_f32, "mtp_selected_logits_f32"); + + // Build full vocab row pre-filled with -1e30 + ggml_tensor * logits_full = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1); + logits_full = ggml_fill_inplace(ctx, logits_full, -1e30f); + ggml_set_name(logits_full, "mtp_logits_masked_base"); + + // Scatter selected logits into full row + ggml_tensor * scatter_dst = ggml_cont_2d(ctx, logits_full, 1, (int64_t)n_vocab); + ggml_tensor * scatter_src = ggml_cont_2d(ctx, sel_logits_f32, 1, n_sel); + logits = ggml_set_rows(ctx, scatter_dst, scatter_src, flat_ids); + logits = ggml_reshape_2d(ctx, logits, n_vocab, 1); + ggml_set_name(logits, "mtp_logits_full"); + } else if (w.use_ordered_embeddings && w.tok_embd) { + // Dense fallback for ordered-embeddings models when tok_embd type does not + // support CUDA get_rows (e.g. K-quants in this llama.cpp fork). + // mul_mat supports Q4_K/Q5_K/Q6_K on CUDA; produces exact logits + // (not the centroid approximation) which is fine for greedy/low-temp decoding. + logits = ggml_mul_mat(ctx, w.tok_embd, h_inner); + ggml_set_name(logits, "mtp_logits_dense_fallback"); + } else { + // Dense tied LM head: mul_mat(tok_embd, h_post) → [n_vocab, 1] + // For non-ordered-embeddings models (n_embd == n_embd_backbone), use h_post + // (post-projected to n_embd_backbone) so dimensions match target.tok_embd. + // Prefer w.tok_embd (MTP's own, in n_embd space) if available, else + // fall back to target.tok_embd (in n_embd_backbone space) with h_post. 
+ if (w.tok_embd) { + // MTP has its own tied LM head in n_embd space + logits = ggml_mul_mat(ctx, w.tok_embd, h_inner); + } else { + // Fallback: use target's tok_embd against the backbone-projected hidden + logits = ggml_mul_mat(ctx, target.tok_embd, h_post); + } + ggml_set_name(logits, "mtp_logits_dense"); + } + + // Optional logit softcapping (matches target's softcap=30) + if (target.logit_softcap > 0.0f) { + logits = ggml_scale(ctx, logits, 1.0f / target.logit_softcap); + logits = ggml_tanh(ctx, logits); + logits = ggml_scale(ctx, logits, target.logit_softcap); + ggml_set_name(logits, "mtp_logits_softcapped"); + } + + // ── 7. In-graph argmax ───────────────────────────────────────────────────── + ggml_tensor * argmax = ggml_argmax(ctx, logits); + ggml_set_name(argmax, "mtp_argmax"); + + // Expand all outputs into the graph + ggml_build_forward_expand(gf, argmax); + ggml_build_forward_expand(gf, h_post); + // Note: logits is already in argmax's DAG, but mark it as output for diagnostic reads. 
+ ggml_set_output(logits); + ggml_set_output(h_post); + ggml_set_output(argmax); + + // ── Populate output struct ──────────────────────────────────────────────── + out.ctx = ctx; + out.gf = gf; + out.in_tok = in_tok; + out.in_tok_embd = in_tok_embd; + out.in_h_prev = in_h_prev; + out.in_pos = in_pos; + out.fa_mask = shared_fa_mask; + out.fa_mask_kv_seq_len = shared_fa_mask_kv_seq_len; + out.out_logits = logits; + out.out_h_post = h_post; + out.out_argmax = argmax; + + return true; +} + +void free_mtp_step_graph(MtpStepGraph & g) { + if (g.ctx) { + ggml_free(g.ctx); + g.ctx = nullptr; + } + g.gf = nullptr; + g.in_tok = nullptr; + g.in_tok_embd = nullptr; + g.in_h_prev = nullptr; + g.in_pos = nullptr; + g.fa_mask = nullptr; + g.fa_mask_kv_seq_len = 0; + g.out_logits = nullptr; + g.out_h_post = nullptr; + g.out_argmax = nullptr; +} + +} // namespace dflash27b diff --git a/dflash/src/gemma4_target_graph.cpp b/dflash/src/gemma4_target_graph.cpp new file mode 100644 index 000000000..5f1e42be1 --- /dev/null +++ b/dflash/src/gemma4_target_graph.cpp @@ -0,0 +1,1189 @@ +// Forward pass of Gemma4 (pure attention) in pure ggml. +// +// Supports both Gemma4-31B (dense, 60 layers) and Gemma4-26B-A4B (MoE, 30 layers). +// All model dimensions are read from GGUF at load time via GemmaTargetWeights. +// No llama.cpp runtime is linked — only ggml ops. 
+// +// Architecture highlights: +// - ALL layers are attention (no DeltaNet/SSM) — simpler than Qwen3.5 hybrid +// - Two layer types interleaved per swa_layers[]: +// SWA (sliding window): standard RoPE (rope_theta_swa), windowed FA +// Full (global): proportional RoPE via per-layer rope_freqs, full FA +// - Attention scale = 1.0 (self.scaling = 1.0, not 1/sqrt(head_dim)) +// - Logit softcapping: output = softcap * tanh(output / softcap), softcap=30 +// - Per-Layer Embeddings (PLE): gated embedding added to residual each layer +// - Shared KV cache: some layers reuse an earlier layer's KV slot +// - MoE FFN (26B-A4B): shared_expert + routed experts (top-K) +// +// State (persisted in GemmaTargetCache across calls): +// - attn_k, attn_v : KV cache for non-shared KV layers +// - layer_to_kv_idx : maps layer index -> KV slot index (-1 = shared) +// - layer_to_donor_kv: maps layer index -> donor slot for shared layers + +#include "internal.h" +#include "kv_quant.h" + +#include +#include +#include +#include +#include + +namespace dflash27b { + +// ─── File-local constants ──────────────────────────────────────────────────── + +static constexpr float EPS = GEMMA4_RMS_EPS; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +static ggml_tensor * rms_norm_mul(ggml_context * ctx, ggml_tensor * x, + ggml_tensor * weight, float eps) { + ggml_tensor * n = ggml_rms_norm(ctx, x, eps); + return ggml_mul(ctx, n, weight); +} + +// GeGLU FFN: w_down @ (gelu(w_gate @ x) * (w_up @ x)) +static ggml_tensor * build_geglu_ffn(ggml_context * ctx, + ggml_tensor * cur, + const GemmaTargetLayer & L) { + ggml_tensor * gate = ggml_mul_mat(ctx, L.w_gate, cur); + ggml_tensor * up = ggml_mul_mat(ctx, L.w_up, cur); + ggml_tensor * gu = ggml_geglu_split(ctx, gate, up); + return ggml_mul_mat(ctx, L.w_down, gu); +} + +// MoE FFN — shared expert + softmax-gated routed experts. 
+// Matches Gemma4-26B-A4B architecture: +// shared_out = w_down @ (gelu(w_gate @ x) * (w_up @ x)) +// shared_out = rms_norm(shared_out) * ffn_post_norm_1 +// router_in = rms_norm(inpSA) / sqrt(n_embd) * ffn_gate_inp_s (bare rms_norm) +// logits = ffn_gate_inp @ router_in [n_expert, n_tokens] +// probs = softmax(logits) +// top_ids = argsort_top_k(probs, n_expert_used) [n_expert_used, n_tokens] i32 +// weights = get_rows(probs, top_ids) [1, n_expert_used, n_tokens] +// weights = weights / sum(weights) (normalize to 1.0) +// gate_up_out = mul_mat_id(ffn_gate_up_exps, x, top_ids) → gelu+mul → weighted +// expert_out = mul_mat_id(ffn_down_exps, act, top_ids) [n_embd, n_expert_used, n_tokens] +// expert_out = sum over expert dim [n_embd, n_tokens] +// expert_out = rms_norm(expert_out) * ffn_post_norm_2 +// result = shared_out + expert_out +static ggml_tensor * build_moe_ffn(ggml_context * ctx, + ggml_cgraph * gf, + const GemmaTargetWeights & w, + const GemmaTargetLayer & L, + ggml_tensor * cur_shared_ffn, + ggml_tensor * cur_moe_ffn, + ggml_tensor * cur_for_router, + int n_tokens) { + const int n_embd = w.n_embd; + const int n_expert_used = w.n_expert_used; + const int n_expert = w.n_expert; + const int n_ff_exp = w.n_ff_exp; + + // ── Shared expert (always active) ────────────────────────────────────────── + ggml_tensor * shared_out = nullptr; + if (L.w_gate && L.w_up && L.w_down) { + ggml_tensor * sg = ggml_mul_mat(ctx, L.w_gate, cur_shared_ffn); + ggml_tensor * su = ggml_mul_mat(ctx, L.w_up, cur_shared_ffn); + ggml_tensor * sgu = ggml_geglu_split(ctx, sg, su); + shared_out = ggml_mul_mat(ctx, L.w_down, sgu); + if (L.ffn_post_norm_1) { + shared_out = rms_norm_mul(ctx, shared_out, L.ffn_post_norm_1, EPS); + } + } + + // ── Router ───────────────────────────────────────────────────────────────── + // router_in = rms_norm(inpSA) / sqrt(n_embd) * ffn_gate_inp_s (bare rms_norm, no weight) + ggml_tensor * router_in = ggml_rms_norm(ctx, cur_for_router, EPS); + router_in = 
ggml_scale(ctx, router_in, 1.0f / std::sqrt((float)n_embd)); + if (L.ffn_gate_inp_s) { + router_in = ggml_mul(ctx, router_in, L.ffn_gate_inp_s); + } + // logits: [n_expert, n_tokens] + ggml_tensor * logits = ggml_mul_mat(ctx, L.ffn_gate_inp, router_in); + + // Softmax gating + ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + + // Top-K selection — returns i32 index tensor [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); + + // Routing weights: gather probs at selected indices [1, n_expert_used, n_tokens] + ggml_tensor * probs_3d = ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens); + ggml_tensor * weights = ggml_get_rows(ctx, probs_3d, selected_experts); + // weights: [1, n_expert_used, n_tokens] → normalize to sum=1.0 + { + ggml_tensor * w2d = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + ggml_tensor * wsum = ggml_sum_rows(ctx, w2d); + wsum = ggml_clamp(ctx, wsum, 6.103515625e-5f, INFINITY); + w2d = ggml_div(ctx, w2d, wsum); + weights = ggml_reshape_3d(ctx, w2d, 1, n_expert_used, n_tokens); + } + + // ── Routed experts via ggml_mul_mat_id ───────────────────────────────────── + ggml_tensor * expert_out = nullptr; + if (L.ffn_gate_up_exps && L.ffn_down_exps) { + // cur_moe_ffn is [n_embd, n_tokens]; mul_mat_id expects [n_embd, 1, n_tokens] + ggml_tensor * x = ggml_reshape_3d(ctx, cur_moe_ffn, n_embd, 1, n_tokens); + + // Gate+up projection: ffn_gate_up_exps [2*n_ff_exp, n_embd, n_expert] + // Result: [2*n_ff_exp, n_expert_used, n_tokens] + ggml_tensor * gate_up = ggml_mul_mat_id(ctx, L.ffn_gate_up_exps, + x, selected_experts); + + const size_t elt = ggml_element_size(gate_up); + // gate half: first n_ff_exp rows + ggml_tensor * g_half = ggml_view_3d(ctx, gate_up, + n_ff_exp, n_expert_used, n_tokens, + (size_t)n_ff_exp * 2 * elt, + (size_t)n_ff_exp * 2 * n_expert_used * elt, + 0); + // up half: second n_ff_exp rows + ggml_tensor * u_half = ggml_view_3d(ctx, gate_up, + 
n_ff_exp, n_expert_used, n_tokens, + (size_t)n_ff_exp * 2 * elt, + (size_t)n_ff_exp * 2 * n_expert_used * elt, + (size_t)n_ff_exp * elt); + + // GeGLU activation (views are non-contiguous; ggml_gelu requires contiguous) + g_half = ggml_cont(ctx, g_half); + u_half = ggml_cont(ctx, u_half); + ggml_tensor * activated = ggml_mul(ctx, ggml_gelu(ctx, g_half), u_half); + + // Scale by routing weights [1, n_expert_used, n_tokens] + activated = ggml_mul(ctx, activated, weights); + + // Down projection: ffn_down_exps [n_embd, n_ff_exp, n_expert] + // activated: [n_ff_exp, n_expert_used, n_tokens] + ggml_tensor * down_out = ggml_mul_mat_id(ctx, L.ffn_down_exps, + activated, selected_experts); + // down_out: [n_embd, n_expert_used, n_tokens] + + // Optional down-projection scale (ffn_down_exps_s is a per-column scale) + if (L.ffn_down_exps_s) { + down_out = ggml_mul(ctx, down_out, L.ffn_down_exps_s); + } + + // Sum over n_expert_used to get [n_embd, n_tokens]. + // down_out: [n_embd, n_expert_used, n_tokens] + // Use the proven llama.cpp pattern: ggml_build_forward_expand the full + // tensor then sum slice views with ggml_add in a loop over n_expert_used. 
+ ggml_build_forward_expand(gf, down_out); + expert_out = ggml_view_2d(ctx, down_out, + n_embd, n_tokens, + down_out->nb[2], + 0); + ggml_build_forward_expand(gf, expert_out); + for (int ei = 1; ei < n_expert_used; ++ei) { + ggml_tensor * slice = ggml_view_2d(ctx, down_out, + n_embd, n_tokens, + down_out->nb[2], + (size_t)ei * down_out->nb[1]); + ggml_build_forward_expand(gf, slice); + expert_out = ggml_add(ctx, expert_out, slice); + ggml_build_forward_expand(gf, expert_out); + } + + if (L.ffn_post_norm_2) { + expert_out = rms_norm_mul(ctx, expert_out, L.ffn_post_norm_2, EPS); + } + } + + // ── Combine shared + routed experts ──────────────────────────────────────── + if (shared_out && expert_out) { + return ggml_add(ctx, shared_out, expert_out); + } else if (shared_out) { + return shared_out; + } else if (expert_out) { + return expert_out; + } + // Fallback: should not happen with a correctly loaded MoE model + return cur_shared_ffn; +} + +// ─── SWA view geometry helper ──────────────────────────────────────────────── +// +// Compute the (abs_win_start, effective_win_len, ring_win_start) triple for a +// chunk at position kv_start with n_tokens query tokens, given swa_window and +// the ring-buffer size (swa_ctx_alloc). This is the single source of truth for +// the K/V view passed to FA and for the host-side causal mask. +SwaView compute_swa_view(int kv_start, int n_tokens, + int swa_window, int swa_ctx_alloc) +{ + SwaView v; + v.abs_win_start = (swa_window > 0 && kv_start > swa_window) + ? (kv_start - swa_window) : 0; + // K view is ALWAYS the full ring; the host-built mask handles the + // non-monotonic ring layout via abs_pos(slot) computation. + v.effective_win_len = swa_ctx_alloc; + v.ring_win_start = 0; + return v; +} + +// Sliding-Window Attention block. +// Uses standard RoPE (rope_theta_swa) and a windowed view of the KV cache. 
+static ggml_tensor * build_swa_attn_block( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaTargetWeights & w, + const GemmaTargetLayer & L, + ggml_tensor * cur, + ggml_tensor * positions, + ggml_tensor * cache_k, + ggml_tensor * cache_v, + ggml_tensor * attn_mask, + int kv_start, + int n_tokens, + ggml_type kv_k_type, + ggml_type kv_v_type, + bool write_kv, + int il) +{ + // SWA layers use the SWA head_dim (may be smaller than full-attn head_dim) + const int head_dim = w.head_dim_swa; + const int n_head = w.n_head; + const int n_head_kv = (il >= 0 && il < (int)w.head_kv_per_layer.size()) + ? w.head_kv_per_layer[il] : w.n_head_kv; + const int q_dim = n_head * head_dim; + + // Q projection + ggml_tensor * Qcur = ggml_mul_mat(ctx, L.wq, cur); + Qcur = ggml_reshape_3d(ctx, Qcur, head_dim, n_head, n_tokens); + Qcur = rms_norm_mul(ctx, Qcur, L.q_norm, EPS); + + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (write_kv) { + Kcur = ggml_mul_mat(ctx, L.wk, cur); + Kcur = ggml_reshape_3d(ctx, Kcur, head_dim, n_head_kv, n_tokens); + + Vcur = ggml_mul_mat(ctx, L.wv, cur); + Vcur = ggml_reshape_3d(ctx, Vcur, head_dim, n_head_kv, n_tokens); + + if (L.k_norm) { + Kcur = rms_norm_mul(ctx, Kcur, L.k_norm, EPS); + } + Vcur = ggml_rms_norm(ctx, Vcur, EPS); + } + + // Standard RoPE (SWA uses rope_theta_swa, no freq_factors) + Qcur = ggml_rope_ext(ctx, Qcur, positions, /*freq_factors=*/nullptr, + head_dim, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig=*/0, + w.rope_theta_swa, /*freq_scale=*/1.0f, + /*ext_factor=*/0.0f, /*attn_factor=*/1.0f, + /*beta_fast=*/0.0f, /*beta_slow=*/0.0f); + if (Kcur) { + Kcur = ggml_rope_ext(ctx, Kcur, positions, nullptr, + head_dim, GGML_ROPE_TYPE_NEOX, 0, + w.rope_theta_swa, 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + } + + // SWA ring-buffer: derive the ring size from the tensor's actual slot count. + // When swa_ctx_alloc < max_ctx (long contexts), writes use kv_start % ring_size + // so the tensor is never exceeded. 
+ const int ring_size = cache_k ? (int)cache_k->ne[1] : (kv_start + n_tokens); + + // Write K/V into cache using ring-buffer position. + // Split-on-wrap: when write_pos + n_tokens > ring_size the chunk straddles + // the ring boundary, so we issue two ggml_cpy ops (pre-wrap and post-wrap). + if (write_kv && cache_k && cache_v && Kcur && Vcur) { + ggml_tensor * Kcur_T = ggml_permute(ctx, Kcur, 0, 2, 1, 3); + ggml_tensor * Vcur_T = ggml_permute(ctx, Vcur, 0, 2, 1, 3); + + const int write_pos = kv_start % ring_size; + const int pre_n = std::min(n_tokens, ring_size - write_pos); + const int post_n = n_tokens - pre_n; + + // First slice: [write_pos .. write_pos+pre_n) + { + ggml_tensor * k_slot = ggml_view_3d(ctx, cache_k, + head_dim, pre_n, n_head_kv, + cache_k->nb[1], cache_k->nb[2], + cache_k->nb[1] * write_pos); + ggml_tensor * v_slot = ggml_view_3d(ctx, cache_v, + head_dim, pre_n, n_head_kv, + cache_v->nb[1], cache_v->nb[2], + cache_v->nb[1] * write_pos); + ggml_tensor * k_src = ggml_view_3d(ctx, Kcur_T, + head_dim, pre_n, n_head_kv, + Kcur_T->nb[1], Kcur_T->nb[2], 0); + ggml_tensor * v_src = ggml_view_3d(ctx, Vcur_T, + head_dim, pre_n, n_head_kv, + Vcur_T->nb[1], Vcur_T->nb[2], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx, k_src, k_slot)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, v_src, v_slot)); + } + + // Second slice (wrap-around): [0 .. 
post_n) + if (post_n > 0) { + ggml_tensor * k_slot = ggml_view_3d(ctx, cache_k, + head_dim, post_n, n_head_kv, + cache_k->nb[1], cache_k->nb[2], + 0); + ggml_tensor * v_slot = ggml_view_3d(ctx, cache_v, + head_dim, post_n, n_head_kv, + cache_v->nb[1], cache_v->nb[2], + 0); + ggml_tensor * k_src = ggml_view_3d(ctx, Kcur_T, + head_dim, post_n, n_head_kv, + Kcur_T->nb[1], Kcur_T->nb[2], + Kcur_T->nb[1] * pre_n); + ggml_tensor * v_src = ggml_view_3d(ctx, Vcur_T, + head_dim, post_n, n_head_kv, + Vcur_T->nb[1], Vcur_T->nb[2], + Vcur_T->nb[1] * pre_n); + ggml_build_forward_expand(gf, ggml_cpy(ctx, k_src, k_slot)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, v_src, v_slot)); + } + } + + // Determine window for SWA reads using the shared geometry helper. + // ring_win_start is always 0 (full-ring read); correctness comes from the + // host-built mask which uses abs_pos(slot) arithmetic for ring geometry. + const SwaView swa_view = compute_swa_view(kv_start, n_tokens, + w.swa_window, ring_size); + const int effective_win_len = swa_view.effective_win_len; + const int ring_win_start = swa_view.ring_win_start; // always 0 + + // swa_ctx_alloc is already aligned to fattn_stride (set in create_gemma4_cache), + // so win_len_padded == effective_win_len == ring_size. No further snap needed. + const bool need_256_pad = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0 + || head_dim >= 512); + const int fattn_stride = need_256_pad ? 256 : 1; + const int win_len_padded = ((effective_win_len + fattn_stride - 1) / fattn_stride) * fattn_stride; + + const bool q_rotate = (kv_k_type == GGML_TYPE_TQ3_0); + const bool out_rotate = (kv_v_type == GGML_TYPE_TQ3_0); + // TQ3 contract: caller pre-rotates Q forward, post-rotates FA output + // backward. ggml_turbo_wht's kernel writes dst using src strides + // (turbo-wht.cu:20-21), so non-contiguous input scatters writes and + // corrupts the result — wrap with ggml_cont before rotating, never after + // permute alone. 
(Regression-fix vs. c15f93a; see EVIDENCE.md TQ3 thread.) + ggml_tensor * Qfa = ggml_cont(ctx, ggml_permute(ctx, Qcur, 0, 2, 1, 3)); + if (q_rotate) { + Qfa = ggml_turbo_wht(ctx, Qfa, 0); + } + + ggml_tensor * Kfa = ggml_view_3d(ctx, cache_k, + head_dim, win_len_padded, n_head_kv, + cache_k->nb[1], cache_k->nb[2], + cache_k->nb[1] * ring_win_start); + ggml_tensor * Vfa = ggml_view_3d(ctx, cache_v, + head_dim, win_len_padded, n_head_kv, + cache_v->nb[1], cache_v->nb[2], + cache_v->nb[1] * ring_win_start); + // Gemma4: attn_scale = 1.0 (self.scaling = 1.0, no 1/sqrt(head_dim)) + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, attn_mask, + 1.0f, 0.0f, 0.0f); + + if (out_rotate) { + attn = ggml_cont(ctx, attn); + attn = ggml_turbo_wht(ctx, attn, 1); + } + + attn = ggml_reshape_2d(ctx, attn, q_dim, n_tokens); + attn = ggml_mul_mat(ctx, L.wo, attn); + return attn; +} + +// Full (Global) Attention block. +// Uses proportional RoPE via per-layer rope_freqs (freq_factors) and full context. +// When use_pflash is true, uses ggml_flash_attn_sparse (block-sparse) instead of +// ggml_flash_attn_ext for the attention computation. +static ggml_tensor * build_full_attn_block( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaTargetWeights & w, + const GemmaTargetLayer & L, + ggml_tensor * cur, + ggml_tensor * positions, + ggml_tensor * cache_k, + ggml_tensor * cache_v, + ggml_tensor * attn_mask, + int kv_start, + int n_tokens, + ggml_type kv_k_type, + ggml_type kv_v_type, + bool write_kv, + int fa_window, + int il, + bool use_pflash, + float pflash_alpha) +{ + // Full-attention layers use the full head_dim + const int head_dim = w.head_dim; + const int n_head = w.n_head; + const int n_head_kv = (il >= 0 && il < (int)w.head_kv_per_layer.size()) + ? 
w.head_kv_per_layer[il] : w.n_head_kv; + const int q_dim = n_head * head_dim; + + // Q projection + ggml_tensor * Qcur = ggml_mul_mat(ctx, L.wq, cur); + Qcur = ggml_reshape_3d(ctx, Qcur, head_dim, n_head, n_tokens); + Qcur = rms_norm_mul(ctx, Qcur, L.q_norm, EPS); + + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (write_kv) { + Kcur = ggml_mul_mat(ctx, L.wk, cur); + Kcur = ggml_reshape_3d(ctx, Kcur, head_dim, n_head_kv, n_tokens); + + // V = K (pre-norm) when wv absent, else separate projection + if (L.wv == L.wk) { + Vcur = Kcur; + } else { + Vcur = ggml_mul_mat(ctx, L.wv, cur); + Vcur = ggml_reshape_3d(ctx, Vcur, head_dim, n_head_kv, n_tokens); + } + + // K gets weighted RMSNorm, V gets bare RMSNorm (no learned weights) + if (L.k_norm) { + Kcur = rms_norm_mul(ctx, Kcur, L.k_norm, EPS); + } + Vcur = ggml_rms_norm(ctx, Vcur, EPS); + } + + // Proportional RoPE for full-attention layers (uses per-layer rope_freqs) + Qcur = ggml_rope_ext(ctx, Qcur, positions, L.rope_freqs, + head_dim, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig=*/0, + w.rope_theta, /*freq_scale=*/1.0f, + /*ext_factor=*/0.0f, /*attn_factor=*/1.0f, + /*beta_fast=*/0.0f, /*beta_slow=*/0.0f); + if (Kcur) { + Kcur = ggml_rope_ext(ctx, Kcur, positions, L.rope_freqs, + head_dim, GGML_ROPE_TYPE_NEOX, 0, + w.rope_theta, 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + } + + // Write K/V into cache + if (write_kv && cache_k && cache_v && Kcur && Vcur) { + ggml_tensor * Kcur_T = ggml_permute(ctx, Kcur, 0, 2, 1, 3); + ggml_tensor * Vcur_T = ggml_permute(ctx, Vcur, 0, 2, 1, 3); + + ggml_tensor * k_slot = ggml_view_3d(ctx, cache_k, + head_dim, n_tokens, n_head_kv, + cache_k->nb[1], cache_k->nb[2], + cache_k->nb[1] * kv_start); + ggml_tensor * v_slot = ggml_view_3d(ctx, cache_v, + head_dim, n_tokens, n_head_kv, + cache_v->nb[1], cache_v->nb[2], + cache_v->nb[1] * kv_start); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur_T, k_slot)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur_T, v_slot)); + } + + // For 
full-attention layers: optional windowed FA for long-context efficiency + const int win_start = (fa_window > 0 && kv_start > fa_window) + ? (kv_start - fa_window) : 0; + const int kv_len = kv_start + n_tokens; + const int win_len = kv_len - win_start; + + const bool need_256_pad = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0 + || head_dim >= 512); + const int fattn_stride = need_256_pad ? 256 : 1; + const int win_len_padded = ((win_len + fattn_stride - 1) / fattn_stride) * fattn_stride; + + const bool q_rotate = (kv_k_type == GGML_TYPE_TQ3_0); + const bool out_rotate = (kv_v_type == GGML_TYPE_TQ3_0); + // See SWA block above for the contiguity rationale (turbo-wht.cu:20-21). + ggml_tensor * Qfa = ggml_cont(ctx, ggml_permute(ctx, Qcur, 0, 2, 1, 3)); + if (q_rotate) { + Qfa = ggml_turbo_wht(ctx, Qfa, 0); + } + + ggml_tensor * Kfa = ggml_view_3d(ctx, cache_k, + head_dim, win_len_padded, n_head_kv, + cache_k->nb[1], cache_k->nb[2], + cache_k->nb[1] * win_start); + ggml_tensor * Vfa = ggml_view_3d(ctx, cache_v, + head_dim, win_len_padded, n_head_kv, + cache_v->nb[1], cache_v->nb[2], + cache_v->nb[1] * win_start); + + // pFlash sparse path supports F16, Q8_0, Q4_0, and (gated) TQ3_0 K/V. + // The CUDA dispatch in fattn-sparse.cu:170-197 dequantizes to F16 before + // the S<->H BF16 transpose. For TQ3_0 it routes through cpy_tq3_0_f16_kernel + // (cpy.cu:429-475) which is compressed-domain — exactly what the graph-level + // WHT contract requires (Q is pre-rotated above, O is inverse-rotated below). + // + // The TQ3 path is gated behind DFLASH_PFLASH_TQ3=1 for A/B rollout. Without + // the gate, TQ3 falls back to the dense FA chunked SGEMM driver (which works + // but is ~2.3x slower than BF16 MMA pflash on Dense 31B prefill). 
+ static const bool s_pflash_tq3 = []() { + const char * s = std::getenv("DFLASH_PFLASH_TQ3"); + return s && (s[0] == '1' || s[0] == 't' || s[0] == 'T' || s[0] == 'y' || s[0] == 'Y'); + }(); + auto pflash_supports = [](enum ggml_type t) { + if (t == GGML_TYPE_F16 || t == GGML_TYPE_Q8_0 || t == GGML_TYPE_Q4_0) return true; + if (t == GGML_TYPE_TQ3_0 && s_pflash_tq3) return true; + return false; + }; + const bool can_pflash = use_pflash && + pflash_supports(Kfa->type) && + pflash_supports(Vfa->type); + + // Gemma4: attn_scale = 1.0 (self.scaling = 1.0, no 1/sqrt(head_dim)) + ggml_tensor * attn; + if (can_pflash) { + attn = ggml_flash_attn_sparse(ctx, Qfa, Kfa, Vfa, 1.0f, pflash_alpha); + } else { + attn = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, attn_mask, 1.0f, 0.0f, 0.0f); + } + + if (out_rotate) { + attn = ggml_cont(ctx, attn); + attn = ggml_turbo_wht(ctx, attn, 1); + } + + attn = ggml_reshape_2d(ctx, attn, q_dim, n_tokens); + attn = ggml_mul_mat(ctx, L.wo, attn); + return attn; +} + +// ─── GemmaTargetCache allocation ───────────────────────────────────────────── + +bool create_gemma4_cache(const GemmaTargetWeights & w, + int max_ctx, + ggml_backend_t backend, + GemmaTargetCache & out, + const std::vector & extra_q8_layers, + int target_feat_cap_hint, + bool enable_dflash_capture_overrides) { + out.backend = backend; + out.max_ctx = max_ctx; + out.cur_pos = 0; + + // Resolve KV types from environment + ggml_type kv_k_type = GGML_TYPE_Q8_0; + ggml_type kv_v_type = GGML_TYPE_Q8_0; + dflash::resolve_kv_types(kv_k_type, kv_v_type); + out.kv_k_type = kv_k_type; + out.kv_v_type = kv_v_type; + + // TQ3_0 and head_dim>=512 (CUDA FA FATTN_KQ_STRIDE) require 256-alignment + const bool need_256_align = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0 + || w.head_dim >= 512); + const int align_stride = need_256_align ? 256 : 1; + const int max_ctx_alloc = need_256_align + ? 
((max_ctx + 255) / 256) * 256 + : max_ctx; + + // SWA layers only need swa_window slots (ring-buffer). Allocate + // min(max_ctx_alloc, swa_window_padded) for SWA layers, saving ~50% VRAM + // at long contexts. swa_ctx_alloc must be strictly > swa_window so the + // decode window (win_len = swa_window + n_tokens) fits within one view. + // We pad swa_window to the same alignment stride and add one alignment + // block as headroom so contiguous views always work for n_tokens=1 decode. + const int swa_window_padded = (w.swa_window > 0) + ? ((w.swa_window + align_stride - 1) / align_stride) * align_stride + : max_ctx_alloc; + // Ring sized to hold last R = 2*swa_window keys (= 2 chunks worth, since + // chunk_size <= swa_window). Combined with a non-monotonic mask in the + // test driver's build_swa_causal_mask, this lets the K view be the full + // ring while correctness comes from the mask filtering by abs_pos. + const int swa_ring_target = 2 * swa_window_padded; + const int swa_ctx_alloc = (w.swa_window > 0) + ? std::min(max_ctx_alloc, swa_ring_target) + : max_ctx_alloc; + out.swa_ctx_alloc = swa_ctx_alloc; + + // Build layer -> KV index mappings. + // Gemma4 can share KV caches across layers. The weight loader sets wk=nullptr + // for shared layers. We detect this and point them at the most recent + // non-shared layer's KV slot. 
+ out.layer_to_kv_idx.assign(w.n_layer, -1); + out.layer_to_donor_kv.assign(w.n_layer, -1); + + int n_kv_slots = 0; + for (int il = 0; il < w.n_layer; il++) { + if (w.layers[il].wk != nullptr) { + out.layer_to_kv_idx[il] = n_kv_slots++; + } + } + + // For shared layers, find the most recent layer that owns a KV slot + int last_kv_slot = -1; + for (int il = 0; il < w.n_layer; il++) { + if (out.layer_to_kv_idx[il] >= 0) { + last_kv_slot = out.layer_to_kv_idx[il]; + } else { + out.layer_to_donor_kv[il] = last_kv_slot; + } + } + + if (n_kv_slots == 0) { + set_last_error("create_gemma4_cache: no KV-owning layers found"); + return false; + } + + // Per-layer KV types. + // + // The upstream FA dispatch (deps/llama.cpp/.../fattn.cu:441) routes + // TQ3 + (Q->ne[0] > 256 || Q->ne[1] > 1) to the slow CHUNKED kernel. + // On Dense Gemma4 31B with full-attn head_dim=512, every chunked + // prefill / draft-verify hits this trap. + // + // Narrow workaround (Codex pattern, mirrors vLLM's kv-cache-dtype-skip-layers): + // when the DFlash draft is wired up, force Q8_0 KV on the small subset of + // full-attn layers whose hidden states are CAPTURED for the draft (the + // "target_feat" ring at gemma4_target_graph.cpp:971 — drafter consumes + // these in build_gemma4_draft_graph). This unblocks the pflash sparse + // fast path for the layers the draft actually depends on, without + // touching the other 8/10 full-attn layers (avoids the MoE regression + // we saw when forcing ALL full-attn -> Q8). 
+ out.kv_k_type_per_layer.assign(w.n_layer, kv_k_type); + out.kv_v_type_per_layer.assign(w.n_layer, kv_v_type); + + const bool gate = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0) + && (w.head_dim > 256) + && enable_dflash_capture_overrides + && (w.n_capture_layers > 0); + + if (gate) { + int n_overridden = 0; + for (int ci = 0; ci < w.n_capture_layers; ci++) { + const int captured_il = w.capture_layer_ids[ci]; + if (captured_il < 0 || captured_il >= w.n_layer) continue; + const bool is_swa = (captured_il < (int)w.swa_layers.size()) + && w.swa_layers[captured_il]; + if (is_swa) continue; // SWA layers don't hit the trap + if (kv_k_type == GGML_TYPE_TQ3_0) { + out.kv_k_type_per_layer[captured_il] = GGML_TYPE_Q8_0; + } + if (kv_v_type == GGML_TYPE_TQ3_0) { + out.kv_v_type_per_layer[captured_il] = GGML_TYPE_Q8_0; + } + n_overridden++; + } + // Count total full-attn layers for the log message + int n_full_attn = 0; + for (int il = 0; il < w.n_layer; il++) { + const bool is_swa = (il < (int)w.swa_layers.size()) && w.swa_layers[il]; + if (!is_swa && out.layer_to_kv_idx[il] >= 0) n_full_attn++; + } + std::fprintf(stderr, + "[cache] narrow asymmetric: forced Q8_0 on %d captured full-attn layer(s) " + "(remaining %d full-attn keep TQ3)\n", + n_overridden, n_full_attn - n_overridden); + } + + // Extra override: force Q8_0 on caller-specified layer indices (e.g. MTP donor layers). + // These layers must NOT use TQ3_0 because MTP cross-attention reads them via ggml_cast + // (no FWHT inverse applied), so TQ3_0 FWHT-domain values would corrupt attention scores. 
+ if (!extra_q8_layers.empty() && + (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0)) { + int n_mtp_overridden = 0; + for (int il : extra_q8_layers) { + if (il < 0 || il >= w.n_layer) continue; + if (kv_k_type == GGML_TYPE_TQ3_0) out.kv_k_type_per_layer[il] = GGML_TYPE_Q8_0; + if (kv_v_type == GGML_TYPE_TQ3_0) out.kv_v_type_per_layer[il] = GGML_TYPE_Q8_0; + n_mtp_overridden++; + } + if (n_mtp_overridden > 0) { + std::fprintf(stderr, + "[cache] MTP donor override: forced Q8_0 on %d layer(s) to avoid TQ3/FWHT cross-attn mismatch\n", + n_mtp_overridden); + } + } + + // (head_dim and n_head_kv are resolved per-layer in the allocation loop below) + + const int n_capture_layers = w.n_capture_layers; + const int n_embd = w.n_embd; + + // Tensor count: 2 (K+V) per KV slot + 1 target_feat + const int n_tensors = 2 * n_kv_slots + 1; + ggml_init_params ip{}; + ip.mem_size = (size_t)(n_tensors + 16) * ggml_tensor_overhead(); + ip.mem_buffer = nullptr; + ip.no_alloc = true; + out.base_ctx = ggml_init(ip); + if (!out.base_ctx) { + set_last_error("create_gemma4_cache: ggml_init failed"); + return false; + } + + out.attn_k.assign(n_kv_slots, nullptr); + out.attn_v.assign(n_kv_slots, nullptr); + + // Create KV tensors — iterate layers to preserve name <-> layer correlation. + // Each layer's KV slot uses the head_dim and n_head_kv appropriate to its + // attention type (SWA vs full-attention may have different dimensions). + for (int il = 0; il < w.n_layer; il++) { + const int kv_idx = out.layer_to_kv_idx[il]; + if (kv_idx < 0) continue; + + const bool is_swa_layer = (il < (int)w.swa_layers.size()) && w.swa_layers[il]; + const int layer_head_dim = is_swa_layer ? w.head_dim_swa : w.head_dim; + const int layer_n_head_kv = (il < (int)w.head_kv_per_layer.size()) + ? w.head_kv_per_layer[il] : w.n_head_kv; + + // SWA layers use a ring buffer of swa_ctx_alloc slots; full-attn layers + // need the full max_ctx_alloc to cover the entire context. 
+ const int layer_ctx_alloc = is_swa_layer ? swa_ctx_alloc : max_ctx_alloc; + + const ggml_type layer_kv_k_type = out.kv_k_type_per_layer[il]; + const ggml_type layer_kv_v_type = out.kv_v_type_per_layer[il]; + ggml_tensor * K = ggml_new_tensor_3d(out.base_ctx, layer_kv_k_type, + layer_head_dim, layer_ctx_alloc, layer_n_head_kv); + ggml_tensor * V = ggml_new_tensor_3d(out.base_ctx, layer_kv_v_type, + layer_head_dim, layer_ctx_alloc, layer_n_head_kv); + char name[64]; + std::snprintf(name, sizeof(name), "gemma4_cache_k_%d", il); + ggml_set_name(K, name); + std::snprintf(name, sizeof(name), "gemma4_cache_v_%d", il); + ggml_set_name(V, name); + out.attn_k[kv_idx] = K; + out.attn_v[kv_idx] = V; + } + + // target_feat ring buffer: [n_capture_layers * n_embd, cap] bf16 + constexpr int TARGET_FEAT_CAP_DEFAULT = 4096; + const int target_feat_cap_req = std::max(TARGET_FEAT_CAP_DEFAULT, target_feat_cap_hint); + out.target_feat_cap = std::min(max_ctx, target_feat_cap_req); + { + const int fc_in = n_capture_layers * n_embd; + out.target_feat = ggml_new_tensor_2d(out.base_ctx, GGML_TYPE_BF16, + fc_in, out.target_feat_cap); + ggml_set_name(out.target_feat, "gemma4_target_feat"); + } + + out.base_buf = ggml_backend_alloc_ctx_tensors(out.base_ctx, backend); + if (!out.base_buf) { + set_last_error("create_gemma4_cache: ggml_backend_alloc_ctx_tensors failed"); + ggml_free(out.base_ctx); + out.base_ctx = nullptr; + return false; + } + + // Count full-attn vs SWA KV-owning layers for VRAM savings log. + int n_full_kv = 0, n_swa_kv = 0; + for (int il = 0; il < w.n_layer; il++) { + if (out.layer_to_kv_idx[il] < 0) continue; + const bool is_swa = (il < (int)w.swa_layers.size()) && w.swa_layers[il]; + if (is_swa) n_swa_kv++; else n_full_kv++; + } + const float full_slots = (float)n_full_kv * max_ctx_alloc; + const float swa_slots = (float)n_swa_kv * swa_ctx_alloc; + const float old_slots = (float)(n_full_kv + n_swa_kv) * max_ctx_alloc; + const float saved_pct = old_slots > 0.0f + ? 
100.0f * (1.0f - (full_slots + swa_slots) / old_slots) + : 0.0f; + // Find a representative SWA layer index and a representative full-attn layer index + // for the diagnostic log (first of each kind that owns a KV slot). + int repr_swa_il = -1, repr_full_il = -1; + for (int il = 0; il < w.n_layer; il++) { + if (out.layer_to_kv_idx[il] < 0) continue; + const bool is_swa = (il < (int)w.swa_layers.size()) && w.swa_layers[il]; + if (is_swa && repr_swa_il < 0) repr_swa_il = il; + if (!is_swa && repr_full_il < 0) repr_full_il = il; + if (repr_swa_il >= 0 && repr_full_il >= 0) break; + } + const char * swa_k_name = (repr_swa_il >= 0) + ? ggml_type_name(out.kv_k_type_per_layer[repr_swa_il]) : "n/a"; + const char * full_k_name = (repr_full_il >= 0) + ? ggml_type_name(out.kv_k_type_per_layer[repr_full_il]) : "n/a"; + std::fprintf(stderr, + "[cache] created max_ctx=%d (full_attn=%d, swa=%d), kv_layers=%d, saved %.1f%%\n", + max_ctx, max_ctx_alloc, swa_ctx_alloc, n_kv_slots, saved_pct); + std::fprintf(stderr, "[cache] kv types: SWA=%s, full=%s\n", swa_k_name, full_k_name); + + // Zero-initialize all tensors + std::vector zeros(1 * 1024 * 1024, 0); + for (ggml_tensor * t = ggml_get_first_tensor(out.base_ctx); t != nullptr; + t = ggml_get_next_tensor(out.base_ctx, t)) { + size_t nb = ggml_nbytes(t); + size_t off = 0; + while (off < nb) { + size_t chunk = std::min(nb - off, zeros.size()); + ggml_backend_tensor_set(t, zeros.data(), off, chunk); + off += chunk; + } + } + + return true; +} + +void free_gemma4_cache(GemmaTargetCache & c) { + free_draft_kv_cache(c); + if (c.base_buf) { ggml_backend_buffer_free(c.base_buf); c.base_buf = nullptr; } + if (c.base_ctx) { ggml_free(c.base_ctx); c.base_ctx = nullptr; } + c.attn_k.clear(); + c.attn_v.clear(); + c.layer_to_kv_idx.clear(); + c.layer_to_donor_kv.clear(); + c.target_feat = nullptr; + c.cur_pos = 0; + c.last_tok = -1; + c.swa_ctx_alloc = 0; +} + +void reset_gemma4_cache(GemmaTargetCache & c) { + c.cur_pos = 0; + c.last_tok = -1; + 
c.draft_kv_pos = 0; + std::vector zeros(1 * 1024 * 1024, 0); + if (!c.base_ctx) return; + for (ggml_tensor * t = ggml_get_first_tensor(c.base_ctx); t != nullptr; + t = ggml_get_next_tensor(c.base_ctx, t)) { + size_t nb = ggml_nbytes(t); + size_t off = 0; + while (off < nb) { + size_t chunk = std::min(nb - off, zeros.size()); + ggml_backend_tensor_set(t, zeros.data(), off, chunk); + off += chunk; + } + } +} + +// ─── Draft KV cache allocation ─────────────────────────────────────────────── + +bool create_draft_kv_cache(const GemmaDraftWeights & dw, + ggml_backend_t backend, + GemmaTargetCache & cache, + int cap_override) { + // Capacity: sliding window + one block + headroom + const int default_cap = dw.sliding_window + dw.block_size + 32; + const int draft_kv_cap = cap_override > 0 ? cap_override : default_cap; + if (draft_kv_cap < dw.block_size + 1) { + set_last_error("create_draft_kv_cache: cap_override is smaller than block_size+1"); + return false; + } + + const size_t n_tensors = (size_t)(2 * dw.n_layer); // K + V per layer + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * n_tensors + 256; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + cache.draft_kv_ctx = ggml_init(ip); + if (!cache.draft_kv_ctx) { + set_last_error("create_draft_kv_cache: ggml_init failed"); + return false; + } + + cache.draft_k.reserve((size_t)dw.n_layer); + cache.draft_v.reserve((size_t)dw.n_layer); + + for (int il = 0; il < dw.n_layer; il++) { + ggml_tensor * K = ggml_new_tensor_3d(cache.draft_kv_ctx, GGML_TYPE_F32, + dw.head_dim, dw.n_head_kv, draft_kv_cap); + ggml_tensor * V = ggml_new_tensor_3d(cache.draft_kv_ctx, GGML_TYPE_F32, + dw.head_dim, dw.n_head_kv, draft_kv_cap); + char name[64]; + std::snprintf(name, sizeof(name), "draft_k_%d", il); + ggml_set_name(K, name); + std::snprintf(name, sizeof(name), "draft_v_%d", il); + ggml_set_name(V, name); + cache.draft_k.push_back(K); + cache.draft_v.push_back(V); + } + + cache.draft_kv_buf = 
ggml_backend_alloc_ctx_tensors(cache.draft_kv_ctx, backend); + if (!cache.draft_kv_buf) { + set_last_error("create_draft_kv_cache: ggml_backend_alloc_ctx_tensors failed"); + ggml_free(cache.draft_kv_ctx); + cache.draft_kv_ctx = nullptr; + cache.draft_k.clear(); + cache.draft_v.clear(); + return false; + } + + cache.draft_kv_cap = draft_kv_cap; + cache.draft_kv_pos = 0; + + ggml_backend_buffer_clear(cache.draft_kv_buf, 0); + + return true; +} + +void free_draft_kv_cache(GemmaTargetCache & cache) { + if (cache.draft_kv_buf) { + ggml_backend_buffer_free(cache.draft_kv_buf); + cache.draft_kv_buf = nullptr; + } + if (cache.draft_kv_ctx) { + ggml_free(cache.draft_kv_ctx); + cache.draft_kv_ctx = nullptr; + } + cache.draft_k.clear(); + cache.draft_v.clear(); + cache.draft_kv_cap = 0; + cache.draft_kv_pos = 0; +} + +// ─── Main graph builder ─────────────────────────────────────────────────────── + +GemmaGraphOutputs build_gemma4_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaTargetWeights & w, + GemmaTargetCache & cache, + const GemmaGraphInputs & in) +{ + const int n_tokens = in.n_tokens; + const int kv_start = in.kv_start; + const int n_embd = w.n_embd; + + // CUDA FA for head_dim>=512 requires a non-null mask to enable the GQA + // optimization path (gqa_opt_applies=true). Auto-create a causal mask + // when the caller did not supply one so that full-attention layers don't + // hit BEST_FATTN_KERNEL_NONE → abort. + ggml_tensor * attn_mask = in.attn_mask; + if (!attn_mask && w.head_dim >= 512) { + const int kv_len = kv_start + n_tokens; + // Pad to 256 — required by FATTN_KQ_STRIDE for TQ3 / large head_dim. 
+ const int kv_len_padded = ((kv_len + 255) / 256) * 256; + attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, kv_len_padded, n_tokens); + ggml_set_name(attn_mask, "auto_causal_mask"); + ggml_set_input(attn_mask); + } + + ggml_tensor * inpL = in.inp_embed; // [n_embd, n_tokens] f32 + + // Gemma4 scales embeddings by sqrt(n_embd) (matches HF Gemma4TextScaledWordEmbedding) + inpL = ggml_scale(ctx, inpL, std::sqrt((float)n_embd)); + + for (int il = 0; il < w.n_layer; il++) { + const GemmaTargetLayer & L = w.layers[il]; + const bool is_swa = (il < (int)w.swa_layers.size()) ? w.swa_layers[il] : true; + + // ── a) Pre-attention RMSNorm ──────────────────────────────────────────── + ggml_tensor * inpSA = inpL; + ggml_tensor * cur = rms_norm_mul(ctx, inpL, L.attn_norm, EPS); + + // ── b-f) Attention (SWA or Full) ─────────────────────────────────────── + const int kv_idx = cache.layer_to_kv_idx[il]; + const bool write_kv = (kv_idx >= 0); + + // Determine which KV cache buffers to use for reading + const int read_kv_idx = write_kv ? kv_idx : cache.layer_to_donor_kv[il]; + ggml_tensor * cache_k = (read_kv_idx >= 0) ? cache.attn_k[read_kv_idx] : nullptr; + ggml_tensor * cache_v = (read_kv_idx >= 0) ? cache.attn_v[read_kv_idx] : nullptr; + + // Resolve per-layer KV types (asymmetric: TQ3 on SWA, Q8 on full-attn). + const ggml_type layer_kv_k = !cache.kv_k_type_per_layer.empty() + ? cache.kv_k_type_per_layer[il] : cache.kv_k_type; + const ggml_type layer_kv_v = !cache.kv_v_type_per_layer.empty() + ? cache.kv_v_type_per_layer[il] : cache.kv_v_type; + + if (is_swa) { + ggml_tensor * effective_mask = in.swa_mask ? 
in.swa_mask : attn_mask; + cur = build_swa_attn_block(ctx, gf, w, L, cur, in.positions, + cache_k, cache_v, effective_mask, + kv_start, n_tokens, + layer_kv_k, layer_kv_v, + write_kv, il); + } else { + cur = build_full_attn_block(ctx, gf, w, L, cur, in.positions, + cache_k, cache_v, attn_mask, + kv_start, n_tokens, + layer_kv_k, layer_kv_v, + write_kv, in.fa_window, il, + in.use_pflash, in.pflash_alpha); + } + + // ── g) Output projection already done inside attn block ──────────────── + + // ── h) Post-attention norm + residual ────────────────────────────────── + if (L.attn_post_norm) { + cur = rms_norm_mul(ctx, cur, L.attn_post_norm, EPS); + } + // NOTE: out_scale is applied AFTER the full layer (after FFN), not here + ggml_tensor * inpSA_post = ggml_add(ctx, cur, inpSA); + + // ── i) FFN ───────────────────────────────────────────────────────────── + ggml_tensor * ffn_residual = inpSA_post; + ggml_tensor * ffn_in = rms_norm_mul(ctx, inpSA_post, L.ffn_norm, EPS); + + ggml_tensor * ffn_out = nullptr; + if (L.ffn_gate_inp != nullptr) { + // MoE path (26B-A4B): shared expert uses ffn_norm, routed use ffn_pre_norm_2 + ggml_tensor * moe_in = L.ffn_pre_norm_2 + ? 
rms_norm_mul(ctx, inpSA_post, L.ffn_pre_norm_2, EPS) + : ffn_in; + ffn_out = build_moe_ffn(ctx, gf, w, L, + ffn_in, moe_in, inpSA_post, + n_tokens); + } else { + // Dense path (31B) + ffn_out = build_geglu_ffn(ctx, ffn_in, L); + } + + // Post-FFN norm + if (L.ffn_post_norm) { + ffn_out = rms_norm_mul(ctx, ffn_out, L.ffn_post_norm, EPS); + } + + cur = ggml_add(ctx, ffn_out, ffn_residual); + + // ── layer_output_scale: applied after full layer (attn + FFN residuals) ─ + // Matches HF: hidden_states = layer_scalar * (attn_residual + ffn_residual) + if (L.out_scale) { + cur = ggml_mul(ctx, cur, L.out_scale); + } + + // ── j) Per-Layer Embedding (PLE) ─────────────────────────────────────── + if (in.per_layer_inp && L.ple_inp_gate && L.ple_proj) { + // ple_inp_gate: gate projection + ggml_tensor * ple_gate = ggml_mul_mat(ctx, L.ple_inp_gate, cur); + ple_gate = ggml_gelu(ctx, ple_gate); + + // per_layer_inp is [n_embd_per_layer, n_tokens, n_layer] or similar; + // we select the slice for this layer along axis 2. + // Assuming per_layer_inp is [n_embd_per_layer, n_tokens] for this layer + // (caller pre-selects by layer index) — or it is [n_embd_per_layer, n_layer] + // shaped with the layer axis being dim 1. + // Use a view to extract the il-th column if per_layer_inp has n_layer cols. + const int n_embd_per_layer = w.n_embd_per_layer > 0 ? 
w.n_embd_per_layer + : (int)in.per_layer_inp->ne[0]; + ggml_tensor * ple_emb; + if (ggml_n_dims(in.per_layer_inp) >= 3 || (int)in.per_layer_inp->ne[1] == w.n_layer) { + // Shape [n_embd_per_layer, n_layer] or [n_embd_per_layer, n_tokens, n_layer] + ple_emb = ggml_view_2d(ctx, in.per_layer_inp, + n_embd_per_layer, n_tokens, + in.per_layer_inp->nb[1], + (size_t)il * n_tokens * in.per_layer_inp->nb[1]); + } else { + // Already sliced per-layer by caller + ple_emb = in.per_layer_inp; + } + + ggml_tensor * ple = ggml_mul(ctx, ple_gate, ple_emb); + ple = ggml_mul_mat(ctx, L.ple_proj, ple); + if (L.ple_post_norm) { + ple = rms_norm_mul(ctx, ple, L.ple_post_norm, EPS); + } + cur = ggml_add(ctx, cur, ple); + } + + // ── k) Target feature capture ────────────────────────────────────────── + if (in.capture_layers && cache.target_feat) { + int capture_idx = -1; + for (int k = 0; k < w.n_capture_layers; k++) { + if (w.capture_layer_ids[k] == il) { capture_idx = k; break; } + } + if (capture_idx >= 0) { + const size_t elt = ggml_element_size(cache.target_feat); + const size_t col_stride = cache.target_feat->nb[1]; + const int cap = cache.target_feat_cap; + const int slot_start = kv_start % cap; + const int pre_n = std::min(n_tokens, cap - slot_start); + const int post_n = n_tokens - pre_n; + + ggml_tensor * cur_2d = ggml_reshape_2d(ctx, cur, n_embd, n_tokens); + + // First slice: [slot_start..slot_start+pre_n) in the ring + { + const size_t offset = + (size_t)slot_start * col_stride + + (size_t)capture_idx * n_embd * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + n_embd, pre_n, col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + n_embd, pre_n, cur_2d->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + + // Second slice: wrap-around at [0..post_n) if needed + if (post_n > 0) { + const size_t offset = + (size_t)capture_idx * n_embd * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + n_embd, post_n, 
col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + n_embd, post_n, cur_2d->nb[1], + (size_t)pre_n * cur_2d->nb[1]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + } + } + + // ── l) Advance residual stream ────────────────────────────────────────── + inpL = cur; + } + + // ── Final norm ───────────────────────────────────────────────────────────── + ggml_tensor * out = rms_norm_mul(ctx, inpL, w.out_norm, EPS); + + // ── MTP h_prev capture (post-output-norm, last token) ────────────────────── + // h_prev must be the backbone hidden AFTER final RMSNorm — the same vector + // fed to lm_head — so the MTP draft head sees the same representation as + // the target's token prediction. Capturing inside the layer loop (pre-norm) + // caused accept_rate=0 because the draft head was trained on post-norm hiddens. + // Source: vLLM PR #41745:569-621 + llama.cpp #22738. + if (cache.mtp_h_prev_enabled && cache.mtp_h_prev_capture_mode == 1 + && cache.mtp_h_prev_batch && n_tokens > 1) { + // Approach B: write all n_tokens rows of post-final-norm hidden into + // the first n_tokens columns of mtp_h_prev_batch. The γ>1 driver + // then picks the correct column host-side after greedy match; no + // extra re-capture forward is needed. + const int n_embd_hp = (int)cache.mtp_h_prev_batch->ne[0]; + GGML_ASSERT(n_tokens <= (int)cache.mtp_h_prev_batch->ne[1]); + ggml_tensor * src = out; // [n_embd, n_tokens] + if (src->type != GGML_TYPE_F32) { + src = ggml_cast(ctx, src, GGML_TYPE_F32); + } + // Destination view: first n_tokens columns of mtp_h_prev_batch. + ggml_tensor * dst_view = ggml_view_2d(ctx, cache.mtp_h_prev_batch, + n_embd_hp, n_tokens, + ggml_row_size(cache.mtp_h_prev_batch->type, n_embd_hp), + /* offset = */ 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, dst_view)); + } else if (cache.mtp_h_prev_enabled && cache.mtp_h_prev) { + const int n_embd_hp = (int)cache.mtp_h_prev->ne[0]; + // Row to capture from the [n_embd, n_tokens] tensor. 
Default (sentinel + // -1) is the last row, matching the γ=1 contract. For γ>1 partial + // accept, the driver sets cache.mtp_h_prev_row = accept_n - 1 so we + // capture the last *accepted* hidden, not the last *speculative* one. + // See plan: /home/peppi/.claude/plans/wild-growing-ember.md Phase 2. + const int capture_row = (cache.mtp_h_prev_row >= 0) + ? cache.mtp_h_prev_row + : (n_tokens - 1); + GGML_ASSERT(capture_row >= 0 && capture_row < n_tokens); + ggml_tensor * h_prev_src = out; + if (n_tokens > 1) { + h_prev_src = ggml_view_2d(ctx, out, + n_embd_hp, 1, + ggml_row_size(out->type, n_embd_hp), + ggml_row_size(out->type, n_embd_hp) * capture_row); + } + if (h_prev_src->type != GGML_TYPE_F32) { + h_prev_src = ggml_cast(ctx, h_prev_src, GGML_TYPE_F32); + } + h_prev_src = ggml_reshape_2d(ctx, h_prev_src, n_embd_hp, 1); + ggml_build_forward_expand(gf, ggml_cpy(ctx, h_prev_src, cache.mtp_h_prev)); + } + + // ── last_token_logits_only: slice to the final token before lm_head ──────── + // During chunked prefill we only need the last token's logits to seed decode. + // Slicing here reduces lm_head compute from O(n_tokens) to O(1) and avoids + // allocating a [vocab, n_tokens] output tensor (saves ~1 GB for chunk_size=1024). 
+ if (in.last_token_logits_only && n_tokens > 1) { + out = ggml_view_2d(ctx, out, + n_embd, 1, + ggml_row_size(out->type, n_embd), + ggml_row_size(out->type, n_embd) * (n_tokens - 1)); + } + + // ── LM head ──────────────────────────────────────────────────────────────── + ggml_tensor * logits = ggml_mul_mat(ctx, w.output, out); + + // ── Logit softcapping: logits = softcap * tanh(logits / softcap) ────────── + if (w.logit_softcap > 0.0f) { + logits = ggml_scale(ctx, logits, 1.0f / w.logit_softcap); + logits = ggml_tanh(ctx, logits); + logits = ggml_scale(ctx, logits, w.logit_softcap); + } + + ggml_set_name(logits, "logits"); + ggml_build_forward_expand(gf, logits); + + GemmaGraphOutputs og{}; + og.logits = logits; + return og; +} + +} // namespace dflash27b diff --git a/dflash/src/gemma4_target_loader.cpp b/dflash/src/gemma4_target_loader.cpp new file mode 100644 index 000000000..d8823770d --- /dev/null +++ b/dflash/src/gemma4_target_loader.cpp @@ -0,0 +1,1200 @@ +// Loads a Gemma4 target model (31B Dense or 26B-A4B MoE) from a GGUF file into +// a GemmaTargetWeights struct backed by the supplied ggml backend (typically +// CUDA). +// +// The expected GGUF architecture string is "gemma4". The loader supports both +// the dense variant (60 layers, pure SwiGLU FFN) and the MoE variant (30 +// layers, sparse expert FFN on the "26B-A4B" config). 
+// +// Tensor naming follows llama.cpp's gemma4-iswa.cpp conventions: +// +// Global: +// token_embd.weight [n_embd, n_vocab] +// output_norm.weight [n_embd] +// output.weight [n_vocab, n_embd] (optional; falls back) +// +// Per-Layer Embedding (PLE, present when n_embd_per_layer > 0): +// per_layer_token_embd.weight [n_embd_per_layer * n_layer, n_vocab] +// per_layer_model_proj.weight [n_embd, n_embd_per_layer * n_layer] +// per_layer_proj_norm.weight [n_embd_per_layer] +// blk.{i}.inp_gate.weight [n_embd, n_embd_per_layer] +// blk.{i}.proj.weight [n_embd_per_layer, n_embd] +// blk.{i}.post_norm.weight [n_embd] +// +// Per-Layer Attention: +// blk.{i}.attn_norm.weight [n_embd] +// blk.{i}.attn_q.weight [n_embd, n_head * head_dim] +// blk.{i}.attn_k.weight [n_embd, n_head_kv * head_dim] (optional) +// blk.{i}.attn_v.weight [n_embd, n_head_kv * head_dim] (optional) +// blk.{i}.attn_output.weight [n_head * head_dim, n_embd] +// blk.{i}.attn_q_norm.weight [head_dim] +// blk.{i}.attn_k_norm.weight [head_dim] (optional) +// blk.{i}.attn_post_norm.weight [n_embd] +// blk.{i}.rope_freqs.weight [head_dim/2] (full-attn layers only) +// blk.{i}.out_scale.weight [1] (optional) +// +// Per-Layer FFN (SwiGLU): +// blk.{i}.ffn_norm.weight [n_embd] +// blk.{i}.ffn_gate.weight [n_embd, n_ff] +// blk.{i}.ffn_up.weight [n_embd, n_ff] +// blk.{i}.ffn_down.weight [n_ff, n_embd] +// blk.{i}.ffn_post_norm.weight [n_embd] +// +// Per-Layer MoE (26B-A4B only, present when n_expert > 0): +// blk.{i}.ffn_gate_inp.weight [n_embd, n_expert] +// blk.{i}.ffn_gate_inp.scale [n_embd] (optional) +// blk.{i}.ffn_pre_norm_2.weight [n_embd] +// blk.{i}.ffn_gate_up_exps.weight [n_embd, n_ff_exp*2, n_expert] +// blk.{i}.ffn_down_exps.weight [n_ff_exp, n_embd, n_expert] +// blk.{i}.ffn_down_exps.scale [n_expert] (optional) +// blk.{i}.ffn_post_norm_1.weight [n_embd] +// blk.{i}.ffn_post_norm_2.weight [n_embd] +// +// KV-sharing: layers with index >= (n_layer - n_kv_shared_layers) omit wk, wv, +// 
k_norm. Their KV is borrowed from the last non-shared layer of the same +// attention type. layer_to_kv_idx maps each layer to its KV cache slot; +// layer_to_donor_kv maps shared layers to their donor layer index. + +#include "internal.h" + +#include +#include +#include +#include +#include +#include +#include + +#if !defined(_WIN32) +#include +#include +#include +#include +#include +#endif + +namespace dflash27b { + +namespace { + +// ─── Thin mmap wrapper ─────────────────────────────────────────────────────── +// Mirrors the Mmap struct from gguf_target_loader.cpp. Ownership can be +// transferred to a CpuEmbedder via release(). + +struct Mmap { + void * addr = nullptr; + size_t len = 0; +#if defined(_WIN32) + HANDLE hFile = INVALID_HANDLE_VALUE; + HANDLE hMap = nullptr; +#else + int fd = -1; +#endif + + bool open_ro(const std::string & path, std::string & err) { +#if defined(_WIN32) + hFile = CreateFileA(path.c_str(), GENERIC_READ, FILE_SHARE_READ, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + if (hFile == INVALID_HANDLE_VALUE) { + err = "CreateFileA: " + path + ": error " + std::to_string(GetLastError()); + return false; + } + LARGE_INTEGER sz; + if (!GetFileSizeEx(hFile, &sz)) { + err = "GetFileSizeEx: error " + std::to_string(GetLastError()); + return false; + } + len = (size_t)sz.QuadPart; + hMap = CreateFileMappingA(hFile, nullptr, PAGE_READONLY, 0, 0, nullptr); + if (!hMap) { + err = "CreateFileMappingA: error " + std::to_string(GetLastError()); + return false; + } + addr = MapViewOfFile(hMap, FILE_MAP_READ, 0, 0, 0); + if (!addr) { + err = "MapViewOfFile: error " + std::to_string(GetLastError()); + return false; + } +#else + fd = ::open(path.c_str(), O_RDONLY); + if (fd < 0) { err = "open: " + path + ": " + std::strerror(errno); return false; } + struct stat st; + if (::fstat(fd, &st) < 0) { err = "fstat: " + std::string(std::strerror(errno)); return false; } + len = (size_t)st.st_size; + addr = ::mmap(nullptr, len, PROT_READ, MAP_PRIVATE, 
fd, 0); + if (addr == MAP_FAILED) { err = "mmap: " + std::string(std::strerror(errno)); addr = nullptr; return false; } +#endif + return true; + } + + void release() { + addr = nullptr; + len = 0; +#if defined(_WIN32) + hFile = INVALID_HANDLE_VALUE; + hMap = nullptr; +#else + fd = -1; +#endif + } + + ~Mmap() { +#if defined(_WIN32) + if (addr) UnmapViewOfFile(addr); + if (hMap) CloseHandle(hMap); + if (hFile != INVALID_HANDLE_VALUE) CloseHandle(hFile); +#else + if (addr) ::munmap(addr, len); + if (fd >= 0) ::close(fd); +#endif + } +}; + +// ─── GGUF metadata helpers ─────────────────────────────────────────────────── + +static uint32_t get_u32_or(const gguf_context * g, const char * key, uint32_t fallback) { + int64_t id = gguf_find_key(g, key); + if (id < 0) return fallback; + return gguf_get_val_u32(g, id); +} + +static float get_f32_or(const gguf_context * g, const char * key, float fallback) { + int64_t id = gguf_find_key(g, key); + if (id < 0) return fallback; + return gguf_get_val_f32(g, id); +} + +static size_t align_up(size_t x, size_t a) { + if (a == 0) return x; + const size_t r = x % a; + return r == 0 ? x : x + (a - r); +} + +// ─── Tensor selection filter ───────────────────────────────────────────────── +// +// All tensors go to GPU, including token_embd.weight which doubles as the LM +// head (tied weights in Gemma4-26B-A4B). The CPU embedder keeps its own +// read-only mmap view of tok_embd for the input embedding path, so placing +// it on GPU as well is safe and necessary for correct LM head logits. + +static bool is_gemma4_gpu_tensor(const char * name) { + (void)name; + return true; +} + +} // namespace + +// ─── load_gemma4_target_gguf ───────────────────────────────────────────────── + +bool load_gemma4_target_gguf(const std::string & path, + ggml_backend_t backend, + GemmaTargetWeights & out) { + + // ── 1. 
Parse GGUF metadata ──────────────────────────────────────────────── + + ggml_context * meta_ctx = nullptr; + gguf_init_params gip{}; + gip.no_alloc = true; + gip.ctx = &meta_ctx; + gguf_context * gctx = gguf_init_from_file(path.c_str(), gip); + if (!gctx) { + set_last_error("gguf_init_from_file failed: " + path); + return false; + } + + // Validate architecture string. + { + int64_t arch_id = gguf_find_key(gctx, "general.architecture"); + if (arch_id < 0) { + set_last_error("missing general.architecture"); + gguf_free(gctx); + return false; + } + const char * arch = gguf_get_val_str(gctx, arch_id); + if (std::string(arch) != "gemma4") { + set_last_error(std::string("unexpected arch: ") + arch + " (expected gemma4)"); + gguf_free(gctx); + return false; + } + } + + // Read required architecture hyperparameters. + const uint32_t n_embd = get_u32_or(gctx, "gemma4.embedding_length", 0); + const uint32_t n_layer = get_u32_or(gctx, "gemma4.block_count", 0); + const uint32_t n_ff = get_u32_or(gctx, "gemma4.feed_forward_length", 0); + const uint32_t n_head = get_u32_or(gctx, "gemma4.attention.head_count", 0); + // Fix A: head_count_kv may be a per-layer INT32 array, not a scalar + std::vector head_kv_per_layer; + uint32_t n_head_kv_max = 0; + { + int64_t kv_id = gguf_find_key(gctx, "gemma4.attention.head_count_kv"); + if (kv_id >= 0) { + enum gguf_type kv_type = gguf_get_kv_type(gctx, kv_id); + if (kv_type == GGUF_TYPE_ARRAY) { + size_t arr_n = gguf_get_arr_n(gctx, kv_id); + const int32_t * arr = (const int32_t *)gguf_get_arr_data(gctx, kv_id); + head_kv_per_layer.resize(arr_n); + for (size_t i = 0; i < arr_n; i++) { + head_kv_per_layer[i] = (int)arr[i]; + if ((uint32_t)arr[i] > n_head_kv_max) n_head_kv_max = (uint32_t)arr[i]; + } + } else { + // Scalar fallback + n_head_kv_max = gguf_get_val_u32(gctx, kv_id); + } + } + } + const uint32_t n_head_kv = n_head_kv_max; + + // Fix D: read both full-attn and SWA head dims + const uint32_t head_dim = get_u32_or(gctx, 
"gemma4.attention.key_length", 0); + const uint32_t head_dim_swa = get_u32_or(gctx, "gemma4.attention.key_length_swa", head_dim); + + // Fix B: vocab_size key may be absent — fall back to tokenizer array length + uint32_t n_vocab = get_u32_or(gctx, "gemma4.vocab_size", 0); + if (n_vocab == 0) { + int64_t tok_id = gguf_find_key(gctx, "tokenizer.ggml.tokens"); + if (tok_id >= 0) n_vocab = (uint32_t)gguf_get_arr_n(gctx, tok_id); + } + const uint32_t swa_win = get_u32_or(gctx, "gemma4.attention.sliding_window", 1024); + const uint32_t n_kv_shared = get_u32_or(gctx, "gemma4.attention.shared_kv_layers", 0); + const uint32_t n_embd_per_layer = get_u32_or(gctx, "gemma4.embedding_length_per_layer_input", 0); + const uint32_t n_expert = get_u32_or(gctx, "gemma4.expert_count", 0); + const uint32_t n_expert_used = get_u32_or(gctx, "gemma4.expert_used_count", 0); + const uint32_t n_ff_exp = get_u32_or(gctx, "gemma4.expert_feed_forward_length", 0); + + const float rope_theta = get_f32_or(gctx, "gemma4.rope.freq_base", 1000000.0f); + const float rope_theta_swa = get_f32_or(gctx, "gemma4.rope.freq_base_swa", 1000000.0f); + const float logit_softcap = get_f32_or(gctx, "gemma4.final_logit_softcapping", 30.0f); + + if (n_embd == 0 || n_layer == 0 || n_ff == 0 || + n_head == 0 || n_head_kv == 0 || head_dim == 0 || n_vocab == 0) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "missing or zero required hparams: n_embd=%u n_layer=%u n_ff=%u " + "n_head=%u n_head_kv=%u head_dim=%u n_vocab=%u", + n_embd, n_layer, n_ff, n_head, n_head_kv, head_dim, n_vocab); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + // ── 2. Build the per-layer SWA pattern ─────────────────────────────────── + // + // swa_layers[il] != 0 → sliding-window attention; == 0 → full attention. + // The array is stored as GGUF_TYPE_ARRAY of INT32 or BOOL. If absent we + // default to alternating: odd layers use SWA, even layers use full attn + // (matches Gemma4-31B's default pattern). 
+ + std::vector swa_layers(n_layer, false); + { + int64_t swa_arr_id = gguf_find_key(gctx, "gemma4.attention.sliding_window_pattern"); + // Fix C: sliding_window_pattern may be BOOL array (1-byte), not INT32 + if (swa_arr_id >= 0) { + size_t arr_n = gguf_get_arr_n(gctx, swa_arr_id); + enum gguf_type arr_type = gguf_get_arr_type(gctx, swa_arr_id); + const void * arr_data = gguf_get_arr_data(gctx, swa_arr_id); + for (size_t i = 0; i < arr_n && i < n_layer; i++) { + if (arr_type == GGUF_TYPE_BOOL || arr_type == GGUF_TYPE_INT8 || arr_type == GGUF_TYPE_UINT8) { + swa_layers[i] = (((const uint8_t *)arr_data)[i] != 0); + } else { + swa_layers[i] = (((const int32_t *)arr_data)[i] != 0); + } + } + } else { + // Fallback: odd-indexed layers → SWA, even → full attention. + for (uint32_t i = 0; i < n_layer; i++) { + swa_layers[i] = ((i % 2) == 1); + } + } + } + + // ── 3. Build KV-sharing maps ────────────────────────────────────────────── + // + // Layers [0, n_layer - n_kv_shared_layers) own their own KV cache slot. + // Layers [n_layer - n_kv_shared_layers, n_layer) are KV-shared: they borrow + // KV from the last non-shared layer that has the same attention type (SWA + // or full). layer_to_kv_idx[il] == -1 for shared layers. + + const int n_non_shared = (int)n_layer - (int)n_kv_shared; + if (n_non_shared < 0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "n_kv_shared_layers=%u > n_layer=%u", n_kv_shared, n_layer); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + std::vector layer_to_kv_idx((size_t)n_layer, -1); + std::vector layer_to_donor_kv((size_t)n_layer, -1); + { + int kv_slot = 0; + for (int il = 0; il < n_non_shared; il++) { + layer_to_kv_idx[il] = kv_slot++; + } + // Shared layers find their donor: the last non-shared layer with the + // same attention type. 
+ for (int il = n_non_shared; il < (int)n_layer; il++) { + bool is_swa = swa_layers[(size_t)il]; + int donor = -1; + for (int j = n_non_shared - 1; j >= 0; j--) { + if (swa_layers[(size_t)j] == is_swa) { donor = j; break; } + } + layer_to_donor_kv[il] = donor; + // kv_idx stays -1 (no dedicated slot). + } + } + const int n_kv_slots = n_non_shared; // total distinct KV cache entries + + // ── 4. Populate struct metadata ────────────────────────────────────────── + + out.ctx = meta_ctx; + out.backend = backend; + + out.n_embd = (int)n_embd; + out.n_head = (int)n_head; + out.n_head_kv = (int)n_head_kv; + out.head_dim = (int)head_dim; + out.head_dim_swa = (int)head_dim_swa; + out.head_kv_per_layer = head_kv_per_layer; + out.n_layer = (int)n_layer; + out.n_ff = (int)n_ff; + out.n_vocab = (int)n_vocab; + out.n_embd_per_layer = (int)n_embd_per_layer; + out.swa_window = (int)swa_win; + out.swa_layers = swa_layers; + out.n_kv_shared_layers = (int)n_kv_shared; + out.n_layer_kv = n_kv_slots; + out.rope_theta = rope_theta; + out.rope_theta_swa = rope_theta_swa; + out.n_expert = (int)n_expert; + out.n_expert_used = (int)n_expert_used; + out.n_ff_exp = (int)n_ff_exp; + out.logit_softcap = logit_softcap; + + // BOS / EOS tokens (missing key → -1) + { + const uint32_t kMissing = 0xFFFFFFFFu; + const uint32_t raw_bos = get_u32_or(gctx, "tokenizer.ggml.bos_token_id", kMissing); + const uint32_t raw_eos = get_u32_or(gctx, "tokenizer.ggml.eos_token_id", kMissing); + const uint32_t raw_eot = get_u32_or(gctx, "tokenizer.ggml.eot_token_id", kMissing); + out.bos_id = (raw_bos == kMissing) ? -1 : (int32_t)raw_bos; + out.eos_id = (raw_eos == kMissing) ? -1 : (int32_t)raw_eos; + out.eos_chat_id = (raw_eot == kMissing) ? -1 : (int32_t)raw_eot; + + // Gemma4 fallback: (107) is the chat stop token. + // Many GGUFs omit eot_token_id; default to 107 when missing. 
+ if (out.eos_chat_id < 0) { + out.eos_chat_id = 107; + } + + std::printf("[gemma4_loader] bos_id=%d eos_id=%d eos_chat_id=%d\n", + out.bos_id, out.eos_id, out.eos_chat_id); + } + + // ── 5. Compute capture_layer_ids ───────────────────────────────────────── + // + // Use hardcoded values from the DFlash draft model config.json. + // Fallback to evenly-spaced formula for unknown layer counts. + { + const int N = GEMMA4_DRAFT_N_TARGET_LAYERS; // 6 + if ((int)n_layer == 30) { + // Gemma4-26B-A4B — from z-lab/gemma-4-26B-A4B-it-DFlash config.json + const int ids[6] = {1, 6, 11, 17, 22, 27}; + for (int k = 0; k < N; k++) out.capture_layer_ids[k] = ids[k]; + } else if ((int)n_layer == 60) { + // Gemma4-31B — from z-lab/gemma-4-31B-it-DFlash config.json + const int ids[6] = {1, 12, 23, 35, 46, 57}; + for (int k = 0; k < N; k++) out.capture_layer_ids[k] = ids[k]; + } else { + // Fallback: evenly spaced + const int step = ((int)n_layer - 2) / (N - 1); + for (int k = 0; k < N; k++) out.capture_layer_ids[k] = 1 + k * step; + } + std::printf("[gemma4_loader] capture_layer_ids:"); + for (int k = 0; k < N; k++) std::printf(" %d", out.capture_layer_ids[k]); + std::printf("\n"); + } + + // ── 6. Wire tensor pointers ─────────────────────────────────────────────── + + auto g = [&](const char * name) -> ggml_tensor * { + return ggml_get_tensor(meta_ctx, name); + }; + + out.tok_embd = g("token_embd.weight"); + out.out_norm = g("output_norm.weight"); + // output.weight is optional; fall back to token_embd for tied weights. 
+ out.output = g("output.weight"); + if (!out.output) out.output = out.tok_embd; + + if (!out.tok_embd || !out.out_norm) { + set_last_error("missing top-level tensors (token_embd.weight / output_norm.weight)"); + gguf_free(gctx); + return false; + } + + // Global PLE tensors (present only when n_embd_per_layer > 0) + if (n_embd_per_layer > 0) { + out.per_layer_tok_embd = g("per_layer_token_embd.weight"); + out.per_layer_model_proj = g("per_layer_model_proj.weight"); + out.per_layer_proj_norm = g("per_layer_proj_norm.weight"); + if (!out.per_layer_tok_embd || !out.per_layer_model_proj || !out.per_layer_proj_norm) { + set_last_error("n_embd_per_layer > 0 but PLE global tensors missing"); + gguf_free(gctx); + return false; + } + } + + // Load global rope_freqs tensor (full-attention layers use this for proportional RoPE). + // Gemma4 stores one shared rope_freqs.weight (not per-layer blk.{i}.rope_freqs.weight). + // All full-attention layers share this single tensor, matching llama.cpp's TENSOR_DUPLICATED + // pattern (llama-model.cpp:4657-4658). + ggml_tensor * global_rope_freqs = g("rope_freqs.weight"); + + // Per-layer tensors. 
+ out.layers.assign((size_t)n_layer, GemmaTargetLayer{}); + + for (int il = 0; il < (int)n_layer; il++) { + char name[160]; + auto fnd = [&](const char * suffix) -> ggml_tensor * { + std::snprintf(name, sizeof(name), "blk.%d.%s", il, suffix); + return ggml_get_tensor(meta_ctx, name); + }; + + GemmaTargetLayer & L = out.layers[(size_t)il]; + + // ── Attention (always present) ──────────────────────────────────────── + L.attn_norm = fnd("attn_norm.weight"); + L.wq = fnd("attn_q.weight"); + L.wo = fnd("attn_output.weight"); + L.q_norm = fnd("attn_q_norm.weight"); + // This GGUF uses "post_attention_norm.weight"; fall back to legacy name + L.attn_post_norm = fnd("post_attention_norm.weight"); + if (!L.attn_post_norm) L.attn_post_norm = fnd("attn_post_norm.weight"); + + if (!L.attn_norm || !L.wq || !L.wo || !L.q_norm || !L.attn_post_norm) { + char b[128]; + std::snprintf(b, sizeof(b), "layer %d: missing required attention tensor", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + + // wk, wv, k_norm — absent for KV-shared layers (il >= n_non_shared). + const bool is_kv_owner = (il < n_non_shared); + if (is_kv_owner) { + L.wk = fnd("attn_k.weight"); + L.wv = fnd("attn_v.weight"); + L.k_norm = fnd("attn_k_norm.weight"); + if (!L.wk) { + char b[128]; + std::snprintf(b, sizeof(b), "layer %d: expected wk (non-shared), missing", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + // V may be absent on full-attention layers where V == K (shared K/V). + if (!L.wv) { + L.wv = L.wk; + } + // k_norm may be absent for SWA layers in some checkpoints; allow nullptr. + } + + // Optional per-layer tensors + L.rope_freqs = fnd("rope_freqs.weight"); + // Full-attention layers use proportional RoPE via rope_freqs (freq_factors). + // Gemma4 stores a single global rope_freqs.weight (no per-layer blk.{i} variant). 
+ // Fall back to the global tensor for full-attention layers when the per-layer + // variant is absent (which is always the case for this GGUF format). + if (!L.rope_freqs && !swa_layers[(size_t)il] && global_rope_freqs) { + L.rope_freqs = global_rope_freqs; + } + // This GGUF uses "layer_output_scale.weight"; fall back to legacy name + L.out_scale = fnd("layer_output_scale.weight"); + if (!L.out_scale) L.out_scale = fnd("out_scale.weight"); + + // ── FFN (always present) ────────────────────────────────────────────── + L.ffn_norm = fnd("ffn_norm.weight"); + L.w_gate = fnd("ffn_gate.weight"); + L.w_up = fnd("ffn_up.weight"); + L.w_down = fnd("ffn_down.weight"); + // This GGUF uses "post_ffw_norm.weight"; fall back to legacy name + L.ffn_post_norm = fnd("post_ffw_norm.weight"); + if (!L.ffn_post_norm) L.ffn_post_norm = fnd("ffn_post_norm.weight"); + + if (!L.ffn_norm || !L.w_gate || !L.w_up || !L.w_down || !L.ffn_post_norm) { + char b[128]; + std::snprintf(b, sizeof(b), "layer %d: missing required FFN tensor", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + + // ── MoE (26B-A4B — present when n_expert > 0) ──────────────────────── + if (n_expert > 0) { + L.ffn_gate_inp = fnd("ffn_gate_inp.weight"); + L.ffn_gate_inp_s = fnd("ffn_gate_inp.scale"); + // This GGUF uses "pre_ffw_norm_2.weight"; fall back to legacy name + L.ffn_pre_norm_2 = fnd("pre_ffw_norm_2.weight"); + if (!L.ffn_pre_norm_2) L.ffn_pre_norm_2 = fnd("ffn_pre_norm_2.weight"); + L.ffn_gate_up_exps = fnd("ffn_gate_up_exps.weight"); + L.ffn_down_exps = fnd("ffn_down_exps.weight"); + L.ffn_down_exps_s = fnd("ffn_down_exps.scale"); + // This GGUF uses "post_ffw_norm_1/2.weight"; fall back to legacy names + L.ffn_post_norm_1 = fnd("post_ffw_norm_1.weight"); + if (!L.ffn_post_norm_1) L.ffn_post_norm_1 = fnd("ffn_post_norm_1.weight"); + L.ffn_post_norm_2 = fnd("post_ffw_norm_2.weight"); + if (!L.ffn_post_norm_2) L.ffn_post_norm_2 = fnd("ffn_post_norm_2.weight"); + + if (!L.ffn_gate_inp || 
!L.ffn_pre_norm_2 || + !L.ffn_gate_up_exps || !L.ffn_down_exps || + !L.ffn_post_norm_1 || !L.ffn_post_norm_2) { + char b[128]; + std::snprintf(b, sizeof(b), "layer %d: MoE model but missing expert tensor", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + // ffn_gate_inp_s, ffn_down_exps_s are optional quantization scales. + } + + // ── Per-Layer Embedding (PLE) ───────────────────────────────────────── + if (n_embd_per_layer > 0) { + L.ple_inp_gate = fnd("inp_gate.weight"); + L.ple_proj = fnd("proj.weight"); + L.ple_post_norm = fnd("post_norm.weight"); + if (!L.ple_inp_gate || !L.ple_proj || !L.ple_post_norm) { + char b[128]; + std::snprintf(b, sizeof(b), "layer %d: PLE model but missing per-layer embedding tensor", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + } + } + + // ── 7. Allocate GPU buffer ──────────────────────────────────────────────── + // + // Walk all GGUF tensors, skip token_embd.weight (stays CPU), accumulate + // aligned sizes, allocate one contiguous backend buffer, assign each tensor. 
+ + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + const size_t alignment = ggml_backend_buft_get_alignment(buft); + + struct TensorSlot { + ggml_tensor * tensor = nullptr; + size_t file_offset = 0; + size_t file_size = 0; + size_t buf_offset = 0; + }; + + std::vector slots; + size_t total_gpu = 0; + const int64_t n_tensors = gguf_get_n_tensors(gctx); + for (int64_t tid = 0; tid < n_tensors; tid++) { + const char * tname = gguf_get_tensor_name(gctx, tid); + if (!is_gemma4_gpu_tensor(tname)) continue; + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t) continue; + total_gpu = align_up(total_gpu, alignment); + TensorSlot s; + s.tensor = t; + s.file_offset = gguf_get_data_offset(gctx) + gguf_get_tensor_offset(gctx, tid); + s.file_size = gguf_get_tensor_size(gctx, tid); + s.buf_offset = total_gpu; + total_gpu += ggml_backend_buft_get_alloc_size(buft, t); + slots.push_back(s); + } + + if (slots.empty()) { + set_last_error("no GPU tensors found in gemma4 GGUF"); + gguf_free(gctx); + return false; + } + + // Cleanup helper: release any GPU buffer and ggml context already assigned + // to `out` before returning false. Must be called on every failure path + // after out.buf has been (or is about to be) allocated. + auto cleanup_out = [&]() { + if (out.buf) { + ggml_backend_buffer_free(out.buf); + out.buf = nullptr; + } + // out.ctx == meta_ctx; free it so the caller doesn't leak the graph. 
+ if (out.ctx) { + ggml_free(out.ctx); + out.ctx = nullptr; + } + out = GemmaTargetWeights{}; + }; + + out.buf = ggml_backend_alloc_buffer(backend, total_gpu); + if (!out.buf) { + set_last_error("ggml_backend_alloc_buffer failed (gemma4 target)"); + gguf_free(gctx); + cleanup_out(); + return false; + } + ggml_backend_buffer_set_usage(out.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + char * base = (char *)ggml_backend_buffer_get_base(out.buf); + for (const TensorSlot & s : slots) { + if (ggml_backend_tensor_alloc(out.buf, s.tensor, base + s.buf_offset) != GGML_STATUS_SUCCESS) { + set_last_error("ggml_backend_tensor_alloc failed (gemma4 target)"); + gguf_free(gctx); + cleanup_out(); + return false; + } + } + + // ── 8. mmap file, upload GPU tensors, keep tok_embd on CPU ─────────────── + + std::string err; + Mmap mm; + if (!mm.open_ro(path, err)) { + set_last_error(err); + gguf_free(gctx); + cleanup_out(); + return false; + } + + const size_t data_start = gguf_get_data_offset(gctx); + size_t gpu_bytes_uploaded = 0; + size_t tok_embd_off = 0; + size_t tok_embd_sz = 0; + ggml_type tok_embd_type = GGML_TYPE_COUNT; + + for (int64_t tid = 0; tid < n_tensors; tid++) { + const char * tname = gguf_get_tensor_name(gctx, tid); + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t) continue; + const size_t off = data_start + gguf_get_tensor_offset(gctx, tid); + const size_t sz = gguf_get_tensor_size(gctx, tid); + if (off + sz > mm.len) { + set_last_error(std::string("tensor '") + tname + "' overflows file"); + gguf_free(gctx); + cleanup_out(); + return false; + } + if (std::strcmp(tname, "token_embd.weight") == 0) { + tok_embd_off = off; + tok_embd_sz = sz; + tok_embd_type = gguf_get_tensor_type(gctx, tid); + // fall through: also upload to GPU for LM head (tied weights) + } + ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); + gpu_bytes_uploaded += sz; + } + + gguf_free(gctx); + + if (tok_embd_off == 0 || tok_embd_type == GGML_TYPE_COUNT) { + 
set_last_error("token_embd.weight not found or invalid type"); + cleanup_out(); + return false; + } + + // Fix 2: validate tok_embd_sz divisibility before computing row stride. + if (n_vocab == 0 || tok_embd_sz % (size_t)n_vocab != 0) { + set_last_error("malformed GGUF: tok_embd_sz=" + std::to_string(tok_embd_sz) + + " not divisible by n_vocab=" + std::to_string(n_vocab)); + cleanup_out(); + return false; + } + + // ── 9. Transfer mmap ownership to CpuEmbedder ──────────────────────────── + + out.embedder.mmap_addr = mm.addr; + out.embedder.mmap_len = mm.len; +#if defined(_WIN32) + out.embedder.mmap_hfile = mm.hFile; + out.embedder.mmap_hmap = mm.hMap; +#else + out.embedder.mmap_fd = mm.fd; +#endif + out.embedder.tok_embd_bytes = (const uint8_t *)mm.addr + tok_embd_off; + out.embedder.tok_embd_type = tok_embd_type; + out.embedder.n_embd = (int64_t)n_embd; + out.embedder.n_vocab = (int64_t)n_vocab; + out.embedder.row_bytes = tok_embd_sz / (size_t)n_vocab; + mm.release(); + + char summary[256]; + std::snprintf(summary, sizeof(summary), + "gemma4 target loaded: n_layer=%u n_embd=%u n_ff=%u n_expert=%u " + "n_kv_slots=%d n_kv_shared=%u, %zu GPU tensors %.2f GiB, " + "tok_embd %.0f MiB GPU+CPU-mmap (%s, tied LM head)", + n_layer, n_embd, n_ff, n_expert, n_kv_slots, n_kv_shared, + slots.size(), (double)gpu_bytes_uploaded / (1024.0 * 1024.0 * 1024.0), + (double)tok_embd_sz / (1024.0 * 1024.0), ggml_type_name(tok_embd_type)); + set_last_error(summary); + + return true; +} + +// ─── load_gemma4_mtp_assistant ─────────────────────────────────────────────── +// +// Loads a Gemma4 MTP assistant GGUF (gemma4_assistant architecture) into +// MtpDrafterWeights. The loader: +// 1. Reads metadata: n_embd_backbone, attention_k_eq_v, n_centroids, etc. +// 2. Reads per-MTP-layer SWA type from gemma4_assistant.attention.sliding_window_pattern. +// 3. 
Resolves each MTP layer's donor_target_layer = LAST target layer whose +// SWA type matches that MTP layer's SWA type, assuming Dense 31B: +// 60 target layers, alternating pattern (odd-indexed = SWA, even = full attn). +// 4. Loads all tensors into a GPU backend buffer. +// +// Tensor names follow llama.cpp's gemma4-assistant.cpp conventions: +// mtp.pre_projection.weight [2*n_bb, n_embd] +// mtp.post_projection.weight [n_embd, n_bb] +// output_norm.weight [n_embd] +// blk.{i}.attn_norm.weight [n_embd] +// blk.{i}.attn_q.weight [n_embd, n_head*head_dim] +// blk.{i}.attn_q_norm.weight [head_dim] +// blk.{i}.attn_output.weight [n_head*head_dim, n_embd] +// blk.{i}.post_attention_norm.weight [n_embd] +// blk.{i}.ffn_norm.weight [n_embd] +// blk.{i}.ffn_gate.weight [n_embd, n_ff] +// blk.{i}.ffn_up.weight [n_embd, n_ff] +// blk.{i}.ffn_down.weight [n_ff, n_embd] +// blk.{i}.post_ffw_norm.weight [n_embd] +// blk.{i}.layer_output_scale.weight [1] (optional) +// +// Metadata keys (prefix = "gemma4_assistant"): +// gemma4_assistant.n_embd_backbone u32 +// gemma4_assistant.n_centroids u32 +// gemma4_assistant.centroid_top_k u32 +// gemma4_assistant.attention.k_eq_v bool +// gemma4_assistant.use_ordered_embeddings bool +// gemma4_assistant.requires_target_arch string + +bool load_gemma4_mtp_assistant(const std::string & gguf_path, + ggml_backend_t backend, + MtpDrafterWeights & out) { + + // ── 1. Open GGUF and read metadata ──────────────────────────────────────── + + ggml_context * meta_ctx = nullptr; + gguf_init_params gip{}; + gip.no_alloc = true; + gip.ctx = &meta_ctx; + gguf_context * gctx = gguf_init_from_file(gguf_path.c_str(), gip); + if (!gctx) { + set_last_error("load_gemma4_mtp_assistant: gguf_init_from_file failed: " + gguf_path); + return false; + } + + // Validate architecture string. 
+ { + int64_t arch_id = gguf_find_key(gctx, "general.architecture"); + if (arch_id < 0) { + set_last_error("load_gemma4_mtp_assistant: missing general.architecture"); + gguf_free(gctx); + return false; + } + const char * arch = gguf_get_val_str(gctx, arch_id); + if (std::string(arch) != "gemma4_assistant") { + set_last_error(std::string("load_gemma4_mtp_assistant: unexpected arch: ") + + arch + " (expected gemma4_assistant)"); + gguf_free(gctx); + return false; + } + } + + // Read MTP-specific metadata. + const uint32_t n_embd = get_u32_or(gctx, "gemma4_assistant.embedding_length", 0); + const uint32_t n_embd_backbone = get_u32_or(gctx, "gemma4_assistant.n_embd_backbone", 0); + const uint32_t n_centroids = get_u32_or(gctx, "gemma4_assistant.n_centroids", 0); + const uint32_t centroid_top_k = get_u32_or(gctx, "gemma4_assistant.centroid_top_k", 0); + bool attention_k_eq_v = false; + bool use_ordered_embeddings = false; + std::string requires_target_arch; + { + int64_t kid = gguf_find_key(gctx, "gemma4_assistant.attention.k_eq_v"); + if (kid >= 0) attention_k_eq_v = gguf_get_val_bool(gctx, kid); + } + { + int64_t kid = gguf_find_key(gctx, "gemma4_assistant.use_ordered_embeddings"); + if (kid >= 0) use_ordered_embeddings = gguf_get_val_bool(gctx, kid); + } + { + int64_t kid = gguf_find_key(gctx, "gemma4_assistant.requires_target_arch"); + if (kid >= 0) requires_target_arch = gguf_get_val_str(gctx, kid); + } + + // Validate n_embd_backbone. + if (n_embd_backbone == 0) { + set_last_error("load_gemma4_mtp_assistant: missing or zero gemma4_assistant.n_embd_backbone"); + gguf_free(gctx); + return false; + } + + // Validate requires_target_arch. + if (requires_target_arch != "gemma4") { + set_last_error(std::string("load_gemma4_mtp_assistant: requires_target_arch='") + + requires_target_arch + "' expected 'gemma4'"); + gguf_free(gctx); + return false; + } + + // Read MTP model's own layer count and SWA pattern. 
+ const uint32_t n_mtp_layer = get_u32_or(gctx, "gemma4_assistant.block_count", 4); + + std::vector mtp_swa_layers(n_mtp_layer, false); + { + int64_t swa_arr_id = gguf_find_key(gctx, "gemma4_assistant.attention.sliding_window_pattern"); + if (swa_arr_id >= 0) { + size_t arr_n = gguf_get_arr_n(gctx, swa_arr_id); + enum gguf_type arr_type = gguf_get_arr_type(gctx, swa_arr_id); + const void * arr_data = gguf_get_arr_data(gctx, swa_arr_id); + for (size_t i = 0; i < arr_n && i < (size_t)n_mtp_layer; i++) { + if (arr_type == GGUF_TYPE_BOOL || arr_type == GGUF_TYPE_INT8 || arr_type == GGUF_TYPE_UINT8) { + mtp_swa_layers[i] = (((const uint8_t *)arr_data)[i] != 0); + } else { + mtp_swa_layers[i] = (((const int32_t *)arr_data)[i] != 0); + } + } + } + // If absent, default all MTP layers to non-SWA (full attention). + } + + // ── 2. Resolve donor_target_layer per MTP layer ─────────────────────────── + // + // Per atomicbot's gemma4-assistant.cpp:12-27 + 126: + // For each MTP layer il, find the LAST target layer whose SWA type == mtp_swa_layers[il]. + // We assume Dense 31B target: 60 layers, alternating (odd-indexed = SWA, even = full attn). + // This matches the fallback in load_gemma4_target_gguf when no swa pattern key is found. + + const int target_n_layer = 60; // Dense 31B + // Build target SWA pattern: odd = SWA, even = full. + std::vector target_swa(target_n_layer, false); + for (int il = 0; il < target_n_layer; il++) { + target_swa[il] = ((il % 2) == 1); + } + + std::vector donor_per_mtp_layer(n_mtp_layer, -1); + for (uint32_t mil = 0; mil < n_mtp_layer; mil++) { + bool want_swa = mtp_swa_layers[mil]; + int32_t best = -1; + for (int til = 0; til < target_n_layer; til++) { + if (target_swa[til] == want_swa) { + best = til; + } + } + donor_per_mtp_layer[mil] = best; + } + + // ── 3. 
Wire tensor pointers ─────────────────────────────────────────────── + + auto g = [&](const char * name) -> ggml_tensor * { + return ggml_get_tensor(meta_ctx, name); + }; + + // Global tensors. + ggml_tensor * pre_proj = g("mtp.pre_projection.weight"); + ggml_tensor * post_proj = g("mtp.post_projection.weight"); + ggml_tensor * out_norm = g("output_norm.weight"); + // Token embedding (tied LM head for the MTP model). Used by the centroid + // LM head for get_rows(tok_embd, candidate_ids) → mul_mat(·, h_inner). + // Optional: absent in stripped GGUFs; graph falls back gracefully. + ggml_tensor * tok_embd_t = g("token_embd.weight"); + // Assistant's own RoPE per-dim freq factors (top-level tensor, used for + // proportional RoPE on the full-attn MTP layer's Q rotation). The assistant + // was trained with ITS OWN rope_freqs which may differ from target's. + ggml_tensor * rope_freqs_t = g("rope_freqs.weight"); + + if (!pre_proj || !post_proj || !out_norm) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "load_gemma4_mtp_assistant: missing global tensors " + "(pre_projection=%s post_projection=%s output_norm=%s)", + pre_proj ? "ok" : "MISSING", + post_proj ? "ok" : "MISSING", + out_norm ? "ok" : "MISSING"); + set_last_error(buf); + gguf_free(gctx); + return false; + } + + // Optional centroid tensors. Load them when n_centroids > 0, regardless of + // use_ordered_embeddings flag — some GGUFs may have the flag wrong while the + // centroid tensors are present. The graph builder decides whether to use them. 
+ ggml_tensor * centroids_t = nullptr; + ggml_tensor * token_ordering_t = nullptr; + if (n_centroids > 0) { + centroids_t = g("mtp.centroids.weight"); + token_ordering_t = g("mtp.token_ordering.weight"); + if (use_ordered_embeddings && !centroids_t) { + set_last_error("load_gemma4_mtp_assistant: use_ordered_embeddings=true but mtp.centroids.weight missing"); + gguf_free(gctx); + return false; + } + // centroids/token_ordering are optional when use_ordered_embeddings=false + // (may be present anyway for future use). + } + + // Per-layer tensors. + std::vector mtp_layers(n_mtp_layer); + for (uint32_t il = 0; il < n_mtp_layer; il++) { + char name[160]; + auto fnd = [&](const char * suffix) -> ggml_tensor * { + std::snprintf(name, sizeof(name), "blk.%u.%s", il, suffix); + return ggml_get_tensor(meta_ctx, name); + }; + + MtpLayerWeights & L = mtp_layers[il]; + L.is_swa = mtp_swa_layers[il]; + L.donor_target_layer = donor_per_mtp_layer[il]; + + L.attn_norm = fnd("attn_norm.weight"); + L.wq = fnd("attn_q.weight"); + L.attn_q_norm = fnd("attn_q_norm.weight"); + L.wo = fnd("attn_output.weight"); + L.attn_post_norm = fnd("post_attention_norm.weight"); + L.ffn_norm = fnd("ffn_norm.weight"); + L.ffn_up = fnd("ffn_up.weight"); + L.ffn_gate = fnd("ffn_gate.weight"); + L.ffn_down = fnd("ffn_down.weight"); + L.ffn_post_norm = fnd("post_ffw_norm.weight"); + L.out_scale = fnd("layer_output_scale.weight"); // optional + + // Validate required tensors. + if (!L.attn_norm || !L.wq || !L.attn_q_norm || !L.wo || !L.attn_post_norm || + !L.ffn_norm || !L.ffn_up || !L.ffn_gate || !L.ffn_down || !L.ffn_post_norm) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "load_gemma4_mtp_assistant: layer %u missing required tensor " + "(attn_norm=%s wq=%s attn_q_norm=%s wo=%s attn_post_norm=%s " + "ffn_norm=%s ffn_up=%s ffn_gate=%s ffn_down=%s ffn_post_norm=%s)", + il, + L.attn_norm ? "ok" : "MISSING", L.wq ? "ok" : "MISSING", + L.attn_q_norm ? "ok" : "MISSING", L.wo ? 
"ok" : "MISSING", + L.attn_post_norm ? "ok" : "MISSING", + L.ffn_norm ? "ok" : "MISSING", L.ffn_up ? "ok" : "MISSING", + L.ffn_gate ? "ok" : "MISSING", L.ffn_down ? "ok" : "MISSING", + L.ffn_post_norm ? "ok" : "MISSING"); + set_last_error(buf); + gguf_free(gctx); + return false; + } + } + + // ── 4. Allocate GPU buffer ──────────────────────────────────────────────── + + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + const size_t alignment = ggml_backend_buft_get_alignment(buft); + + struct TensorSlot { + ggml_tensor * tensor = nullptr; + size_t file_offset = 0; + size_t file_size = 0; + size_t buf_offset = 0; + }; + + std::vector slots; + size_t total_gpu = 0; + const int64_t n_tensors = gguf_get_n_tensors(gctx); + for (int64_t tid = 0; tid < n_tensors; tid++) { + const char * tname = gguf_get_tensor_name(gctx, tid); + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t) continue; + total_gpu = align_up(total_gpu, alignment); + TensorSlot s; + s.tensor = t; + s.file_offset = gguf_get_data_offset(gctx) + gguf_get_tensor_offset(gctx, tid); + s.file_size = gguf_get_tensor_size(gctx, tid); + s.buf_offset = total_gpu; + total_gpu += ggml_backend_buft_get_alloc_size(buft, t); + slots.push_back(s); + } + + if (slots.empty()) { + set_last_error("load_gemma4_mtp_assistant: no tensors found in GGUF"); + gguf_free(gctx); + return false; + } + + auto cleanup_out = [&]() { + if (out.buffer) { ggml_backend_buffer_free(out.buffer); out.buffer = nullptr; } + if (out.ctx) { ggml_free(out.ctx); out.ctx = nullptr; } + out = MtpDrafterWeights{}; + }; + + out.buffer = ggml_backend_alloc_buffer(backend, total_gpu); + if (!out.buffer) { + set_last_error("load_gemma4_mtp_assistant: ggml_backend_alloc_buffer failed"); + gguf_free(gctx); + cleanup_out(); + return false; + } + ggml_backend_buffer_set_usage(out.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + char * base = (char *)ggml_backend_buffer_get_base(out.buffer); + for (const TensorSlot & 
s : slots) { + if (ggml_backend_tensor_alloc(out.buffer, s.tensor, base + s.buf_offset) != GGML_STATUS_SUCCESS) { + set_last_error("load_gemma4_mtp_assistant: ggml_backend_tensor_alloc failed"); + gguf_free(gctx); + cleanup_out(); + return false; + } + } + + // ── 5. mmap and upload tensors ──────────────────────────────────────────── + + std::string err; + Mmap mm; + if (!mm.open_ro(gguf_path, err)) { + set_last_error(err); + gguf_free(gctx); + cleanup_out(); + return false; + } + + const size_t data_start = gguf_get_data_offset(gctx); + for (int64_t tid = 0; tid < n_tensors; tid++) { + const char * tname = gguf_get_tensor_name(gctx, tid); + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t) continue; + const size_t off = data_start + gguf_get_tensor_offset(gctx, tid); + const size_t sz = gguf_get_tensor_size(gctx, tid); + if (off + sz > mm.len) { + set_last_error(std::string("load_gemma4_mtp_assistant: tensor '") + tname + "' overflows file"); + gguf_free(gctx); + cleanup_out(); + return false; + } + ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); + } + + gguf_free(gctx); + + // ── 6. 
Populate output struct ───────────────────────────────────────────── + + out.ctx = meta_ctx; + out.backend = backend; + out.pre_projection = pre_proj; + out.post_projection = post_proj; + out.output_norm = out_norm; + out.tok_embd = tok_embd_t; + out.rope_freqs = rope_freqs_t; + out.centroids = centroids_t; + out.token_ordering = token_ordering_t; + out.layers = std::move(mtp_layers); + out.n_embd = (int32_t)n_embd; + out.n_embd_backbone = (int32_t)n_embd_backbone; + out.n_centroids = (int32_t)n_centroids; + out.centroid_top_k = (int32_t)centroid_top_k; + out.use_ordered_embeddings = use_ordered_embeddings; + out.attention_k_eq_v = attention_k_eq_v; + out.requires_target_arch = requires_target_arch; + + std::printf("[mtp_loader] loaded: n_embd_backbone=%u n_mtp_layers=%u " + "attention_k_eq_v=%d n_centroids=%u use_ordered_embeddings=%d " + "requires_target_arch=%s tensors=%zu GPU %.2f MiB\n", + n_embd_backbone, n_mtp_layer, + (int)attention_k_eq_v, n_centroids, (int)use_ordered_embeddings, + requires_target_arch.c_str(), + slots.size(), + (double)total_gpu / (1024.0 * 1024.0)); + + for (uint32_t mil = 0; mil < n_mtp_layer; mil++) { + std::printf("[mtp_loader] layer[%u]: is_swa=%d donor_target_layer=%d\n", + mil, (int)out.layers[mil].is_swa, out.layers[mil].donor_target_layer); + } + + return true; +} + +// ─── free_gemma4_mtp_assistant ──────────────────────────────────────────────── + +void free_gemma4_mtp_assistant(MtpDrafterWeights & w) { + if (w.buffer) { ggml_backend_buffer_free(w.buffer); w.buffer = nullptr; } + if (w.ctx) { ggml_free(w.ctx); w.ctx = nullptr; } + w.layers.clear(); + w.pre_projection = nullptr; + w.post_projection = nullptr; + w.output_norm = nullptr; + w.tok_embd = nullptr; + w.centroids = nullptr; + w.token_ordering = nullptr; + w = MtpDrafterWeights{}; +} + +// ─── get_mtp_swa_pattern ────────────────────────────────────────────────────── + +bool get_mtp_swa_pattern(const std::string & gguf_path, + std::vector & out_mtp_swa_layers) { + 
ggml_context * meta_ctx = nullptr; + gguf_init_params gip{}; + gip.no_alloc = true; + gip.ctx = &meta_ctx; + gguf_context * gctx = gguf_init_from_file(gguf_path.c_str(), gip); + if (!gctx) return false; + + // Validate arch + { + int64_t aid = gguf_find_key(gctx, "general.architecture"); + if (aid < 0) { gguf_free(gctx); if (meta_ctx) ggml_free(meta_ctx); return false; } + if (std::string(gguf_get_val_str(gctx, aid)) != "gemma4_assistant") { + gguf_free(gctx); if (meta_ctx) ggml_free(meta_ctx); return false; + } + } + + const uint32_t n_mtp_layer = get_u32_or(gctx, "gemma4_assistant.block_count", 4); + out_mtp_swa_layers.assign(n_mtp_layer, false); + + int64_t swa_arr_id = gguf_find_key(gctx, "gemma4_assistant.attention.sliding_window_pattern"); + if (swa_arr_id >= 0) { + size_t arr_n = gguf_get_arr_n(gctx, swa_arr_id); + enum gguf_type arr_type = gguf_get_arr_type(gctx, swa_arr_id); + const void * arr_data = gguf_get_arr_data(gctx, swa_arr_id); + for (size_t i = 0; i < arr_n && i < (size_t)n_mtp_layer; i++) { + if (arr_type == GGUF_TYPE_BOOL || arr_type == GGUF_TYPE_INT8 || arr_type == GGUF_TYPE_UINT8) { + out_mtp_swa_layers[i] = (((const uint8_t *)arr_data)[i] != 0); + } else { + out_mtp_swa_layers[i] = (((const int32_t *)arr_data)[i] != 0); + } + } + } + + gguf_free(gctx); + if (meta_ctx) ggml_free(meta_ctx); + return true; +} + +// ─── resolve_mtp_donor_layers ───────────────────────────────────────────────── + +void resolve_mtp_donor_layers(MtpDrafterWeights & mtp, + const std::vector & target_swa_layers) { + const int n_target = (int)target_swa_layers.size(); + for (auto & L : mtp.layers) { + // Find the LAST target layer whose SWA type matches this MTP layer. 
+ bool want_swa = L.is_swa; + int32_t best = -1; + for (int til = 0; til < n_target; ++til) { + if ((int)target_swa_layers.size() > til && target_swa_layers[(size_t)til] == want_swa) { + best = til; + } + } + L.donor_target_layer = best; + } +} + +// ─── free_gemma4_target_weights ────────────────────────────────────────────── + +void free_gemma4_target_weights(GemmaTargetWeights & w) { + if (w.buf) { ggml_backend_buffer_free(w.buf); w.buf = nullptr; } + if (w.ctx) { ggml_free(w.ctx); w.ctx = nullptr; } + // CpuEmbedder destructor handles the mmap automatically. + w.layers.clear(); + w.tok_embd = nullptr; + w.out_norm = nullptr; + w.output = nullptr; + w.per_layer_tok_embd = nullptr; + w.per_layer_model_proj = nullptr; + w.per_layer_proj_norm = nullptr; + w.swa_layers.clear(); +} + +} // namespace dflash27b diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 7a65b41a3..2ae7e8e5c 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -23,6 +23,7 @@ #include "gguf.h" #include "dflash27b.h" +#include "gemma4.h" namespace dflash27b { @@ -487,6 +488,464 @@ ggml_tensor * build_qwen35_layer( ggml_tensor * q_tail_capture = nullptr, int q_tail_start = 0); +// ============ Gemma4 Architecture ============ + +struct GemmaTargetLayer { + // Attention (ALL layers are attention in Gemma4) + ggml_tensor * attn_norm = nullptr; + ggml_tensor * wq = nullptr; + ggml_tensor * wk = nullptr; // nullptr for KV-shared layers + ggml_tensor * wv = nullptr; // nullptr for KV-shared layers + ggml_tensor * wo = nullptr; + ggml_tensor * q_norm = nullptr; + ggml_tensor * k_norm = nullptr; // nullptr for KV-shared layers + ggml_tensor * attn_post_norm = nullptr; + + // p-RoPE freq factors (full-attention layers only) + ggml_tensor * rope_freqs = nullptr; + + ggml_tensor * out_scale = nullptr; + + // FFN (SwiGLU) + ggml_tensor * ffn_norm = nullptr; + ggml_tensor * w_gate = nullptr; + ggml_tensor * w_up = nullptr; + ggml_tensor * w_down = nullptr; + ggml_tensor * 
ffn_post_norm = nullptr; + + // MoE (26B-A4B only) + ggml_tensor * ffn_gate_inp = nullptr; + ggml_tensor * ffn_gate_inp_s = nullptr; + ggml_tensor * ffn_pre_norm_2 = nullptr; + ggml_tensor * ffn_gate_up_exps = nullptr; + ggml_tensor * ffn_down_exps = nullptr; + ggml_tensor * ffn_down_exps_s = nullptr; + ggml_tensor * ffn_post_norm_1 = nullptr; + ggml_tensor * ffn_post_norm_2 = nullptr; + + // Per-Layer Embedding (PLE) + ggml_tensor * ple_inp_gate = nullptr; + ggml_tensor * ple_proj = nullptr; + ggml_tensor * ple_post_norm = nullptr; +}; + +struct GemmaTargetWeights { + ggml_context * ctx = nullptr; + ggml_backend_t backend = nullptr; + ggml_backend_buffer_t buf = nullptr; + CpuEmbedder embedder; + + ggml_tensor * tok_embd = nullptr; + std::vector layers; + ggml_tensor * out_norm = nullptr; + ggml_tensor * output = nullptr; + + // Per-Layer Embedding global tensors + ggml_tensor * per_layer_tok_embd = nullptr; + ggml_tensor * per_layer_model_proj = nullptr; + ggml_tensor * per_layer_proj_norm = nullptr; + + // Architecture metadata (loaded from GGUF) + int n_embd = 4096; + int n_head = 32; + int n_head_kv = 8; // max n_head_kv across layers (used for cache alloc) + int head_dim = 128; // full-attention head dim + int head_dim_swa = 128; // SWA head dim (may differ from head_dim) + std::vector head_kv_per_layer; // per-layer n_head_kv (empty = use n_head_kv for all) + int n_layer = 60; + int n_ff = 16384; + int n_vocab = 262144; + int n_embd_per_layer = 0; + + int swa_window = 1024; + std::vector swa_layers; + + int n_kv_shared_layers = 0; + int n_layer_kv = 0; + + float rope_theta = 1000000.0f; + float rope_theta_swa = 1000000.0f; + + int n_expert = 0; + int n_expert_used = 0; + int n_ff_exp = 0; + + float logit_softcap = 30.0f; + float attn_scale = 1.0f; + + int32_t bos_id = -1; + int32_t eos_id = -1; + int32_t eos_chat_id = -1; + + int n_capture_layers = GEMMA4_DRAFT_N_TARGET_LAYERS; + int capture_layer_ids[GEMMA4_DRAFT_N_TARGET_LAYERS] = {0}; +}; + +struct 
GemmaTargetCache { + ggml_context * base_ctx = nullptr; + ggml_backend_buffer_t base_buf = nullptr; + ggml_context * rollback_ctx = nullptr; + ggml_backend_buffer_t rollback_buf = nullptr; + ggml_backend_t backend = nullptr; + + int max_ctx = 0; + int swa_ctx_alloc = 0; // Actual KV-slot count for SWA layers (ring-buffer size). + // Derived as min(max_ctx_alloc, swa_window_padded). + // Full-attention layers always use max_ctx_alloc. + int cur_pos = 0; + int last_tok = -1; + + ggml_type kv_k_type = GGML_TYPE_Q8_0; + ggml_type kv_v_type = GGML_TYPE_Q8_0; + + // Per-layer override: if non-empty, use these instead of kv_k_type / kv_v_type. + // Used for asymmetric KV: TQ3_0 on SWA layers, Q8_0 on full-attn layers so + // those layers can ride the pflash block-sparse fast path (which excludes TQ3). + std::vector kv_k_type_per_layer; + std::vector kv_v_type_per_layer; + + std::vector attn_k; + std::vector attn_v; + + std::vector layer_to_kv_idx; + std::vector layer_to_donor_kv; + + ggml_tensor * target_feat = nullptr; + int target_feat_cap = 0; + + // MTP h_prev: last committed token's post-block hidden state from the + // last full-attention layer. Shape [n_embd_backbone, 1] f32. + // Allocated only when MTP is enabled (mtp_h_prev_enabled flag on cache). + // Written by the target graph at the end of every decode step. + ggml_tensor * mtp_h_prev = nullptr; + bool mtp_h_prev_enabled = false; + // Index of the last full-attention layer in the target (Dense 31B = 58). + // Computed once at cache init from w.swa_layers (highest il with swa==false). + int mtp_last_full_layer = -1; + // γ>1 MTP partial-accept correctness: when set to a non-negative value < + // n_tokens, the h_prev capture slices that row instead of the default + // (n_tokens - 1). Sentinel -1 preserves the existing γ=1 behavior. + // Set by the γ>1 driver after greedy match: mtp_h_prev_row = accept_n - 1. 
+ int mtp_h_prev_row = -1; + + // Approach B: when mtp_h_prev_capture_mode == 1, the target graph writes + // all n_tokens rows of post-final-norm hidden into mtp_h_prev_batch + // instead of slicing one row into mtp_h_prev. After verify, the γ>1 + // driver picks the column matching accept_drafts and copies it host-side + // into mtp_h_prev for the next MTP chain to read. No re-capture forward. + // Width = max gamma + 1 = 17 (matches the --gamma CLI cap of 16). + ggml_tensor * mtp_h_prev_batch = nullptr; // [n_embd_backbone, 17] + int mtp_h_prev_capture_mode = 0; // 0 = single-row, 1 = batch + + // Draft KV cache (prefix-direct: projected target features → K/V per layer) + ggml_context * draft_kv_ctx = nullptr; + ggml_backend_buffer_t draft_kv_buf = nullptr; + std::vector draft_k; // [head_dim, n_kv_heads, draft_kv_cap] f32 + std::vector draft_v; // [head_dim, n_kv_heads, draft_kv_cap] f32 + int draft_kv_cap = 0; + int draft_kv_pos = 0; +}; + +struct GemmaGraphInputs { + ggml_tensor * inp_embed = nullptr; + ggml_tensor * positions = nullptr; // [n_tokens] i32 + ggml_tensor * attn_mask = nullptr; + ggml_tensor * swa_mask = nullptr; // sliding-window causal mask (required for ANY SWA dispatch — prefill AND single-token decode) + ggml_tensor * per_layer_inp = nullptr; // PLE pre-computed embeddings + int n_tokens = 0; + int kv_start = 0; + bool capture_layers = false; + int fa_window = 0; + ggml_tensor * parent_ids = nullptr; + // pFlash: when true, full-attention layers use ggml_flash_attn_sparse + // instead of ggml_flash_attn_ext, keeping the single-graph-per-chunk + // architecture while enabling block-sparse attention during prefill. + bool use_pflash = false; + float pflash_alpha = 0.12f; + // When true, slice hidden to the last token before lm_head so the output + // tensor has shape [vocab, 1] instead of [vocab, n_tokens]. + // Only safe for prefill chunks where we discard all but the last logit. 
+ bool last_token_logits_only = false; +}; + +struct GemmaGraphOutputs { + ggml_tensor * logits = nullptr; +}; + +// Gemma4 target loading +bool load_gemma4_target_gguf(const std::string & path, ggml_backend_t backend, + GemmaTargetWeights & out); +void free_gemma4_target_weights(GemmaTargetWeights & w); + +// Gemma4 cache +// extra_q8_layers: additional layer indices to force Q8_0 KV regardless of the +// global kv type (e.g. MTP donor layers that need to avoid the TQ3_0/FWHT mismatch). +bool create_gemma4_cache(const GemmaTargetWeights & w, int max_ctx, + ggml_backend_t backend, GemmaTargetCache & out, + const std::vector & extra_q8_layers = {}, + int target_feat_cap_hint = 0, + bool enable_dflash_capture_overrides = false); +void free_gemma4_cache(GemmaTargetCache & c); +void reset_gemma4_cache(GemmaTargetCache & c); + +// Gemma4 graph +GemmaGraphOutputs build_gemma4_graph(ggml_context * ctx, ggml_cgraph * gf, + const GemmaTargetWeights & w, + GemmaTargetCache & cache, + const GemmaGraphInputs & in); + +// SWA window geometry for a chunk at position kv_start with n_tokens query tokens. +// Returns the triple that build_swa_attn_block uses for the K/V view. +// The mask must be sized [effective_win_len, n_tokens] (both aligned) and filled +// with view-relative indices: mask[q][k_view] where abs_k = abs_win_start + k_view. 
+struct SwaView { + int abs_win_start; // absolute KV position of view slot 0 + int effective_win_len; // number of valid tokens in the view + int ring_win_start; // ring-buffer modular offset (for graph K view) +}; + +SwaView compute_swa_view(int kv_start, + int n_tokens, + int swa_window, + int swa_ctx_alloc /* ring size */); + + +// ─── Gemma4 Draft weights ───────────────────────────────────────── + +struct GemmaDraftLayer { + ggml_tensor * attn_norm = nullptr; + ggml_tensor * ffn_norm = nullptr; + ggml_tensor * wq = nullptr; + ggml_tensor * wk = nullptr; + ggml_tensor * wv = nullptr; + ggml_tensor * wo = nullptr; + ggml_tensor * q_norm = nullptr; + ggml_tensor * k_norm = nullptr; + ggml_tensor * w_gate = nullptr; + ggml_tensor * w_up = nullptr; + ggml_tensor * w_down = nullptr; +}; + +struct GemmaDraftWeights { + ggml_context * ctx = nullptr; + ggml_backend_t backend = nullptr; + ggml_backend_buffer_t buf = nullptr; + + ggml_tensor * fc = nullptr; // [6*target_hidden, draft_hidden] (ggml ne[0]=6*th, ne[1]=dh) + ggml_tensor * hidden_norm = nullptr; // [draft_hidden] + ggml_tensor * out_norm = nullptr; // [draft_hidden] + ggml_tensor * tok_embd = nullptr; // [draft_hidden, n_vocab] — tied lm_head + + std::vector layers; + std::vector layer_is_swa; + + int n_layer = GEMMA4_DRAFT_LAYERS; // 5 + int n_head = 0; + int n_head_kv = 0; + int head_dim = 128; + int n_embd = 0; // draft hidden size + int n_ff = 0; // draft intermediate size + int n_vocab = GEMMA4_31B_VOCAB; // 262144 + int block_size = GEMMA4_DRAFT_BLOCK_SIZE; // 16 + int n_target_layers = GEMMA4_DRAFT_N_TARGET_LAYERS; // 6 + int target_hidden = 0; // target model hidden dim (4096 for all Gemma4 variants) + float logit_softcap = GEMMA4_LOGIT_SOFTCAP; // 30.0 + float rope_theta = GEMMA4_ROPE_THETA; // 1e6 + int mask_token_id = GEMMA4_31B_DRAFT_MASK_TOKEN_ID; // 4 + int sliding_window = 2048; +}; + +// ─── Gemma4 MTP (Multi-Token Prediction) assistant weights ─────────────────── +// +// Loaded from a 
gemma4_assistant GGUF (e.g. gemma-4-31B-it-assistant.Q4_K_M.gguf). +// These are the 4 cross-attention transformer blocks that run after the target +// model's forward pass to predict the next speculative token. + +struct MtpLayerWeights { + // Q-only attention (no wk/wv — V is always read from the donor target KV cache; + // attention_k_eq_v=true means V stored as rms-normed non-rotated K, so MTP + // MUST read V from cache, not reuse K. use_k_as_v=false hardcoded per + // atomicbot:gemma4-assistant.cpp:134). + ggml_tensor * attn_norm = nullptr; // [n_embd] + ggml_tensor * wq = nullptr; // [n_embd, n_head * head_dim] + ggml_tensor * attn_q_norm = nullptr; // [head_dim] + ggml_tensor * wo = nullptr; // [n_head * head_dim, n_embd] + ggml_tensor * attn_post_norm = nullptr; // [n_embd] + ggml_tensor * ffn_norm = nullptr; // [n_embd] + ggml_tensor * ffn_up = nullptr; // [n_embd, n_ff] + ggml_tensor * ffn_gate = nullptr; // [n_embd, n_ff] + ggml_tensor * ffn_down = nullptr; // [n_ff, n_embd] + ggml_tensor * ffn_post_norm = nullptr; // [n_embd] + ggml_tensor * out_scale = nullptr; // [1] optional; nullptr if absent + // Donor target layer resolved per-MTP-layer: LAST target layer whose + // attention type (SWA vs full) matches this MTP layer's type. + int32_t donor_target_layer = -1; + bool is_swa = false; // this MTP layer's attention type +}; + +struct MtpDrafterWeights { + // Pre/post projection (concat tok_emb + h_prev → n_embd, and back) + ggml_tensor * pre_projection = nullptr; // [2*n_embd_backbone, n_embd] + ggml_tensor * post_projection = nullptr; // [n_embd, n_embd_backbone] + ggml_tensor * output_norm = nullptr; // [n_embd] + // Token embedding (shared / tied LM head for the MTP assistant model). + // Used ONLY in the centroid-routed LM head (get_rows + mul_mat) and in + // the dense fallback. This is the MTP model's own embedding, NOT the + // target's tok_embd (which is used only for the step-1 input embedding). 
+ // Loaded from "token_embd.weight" in the assistant GGUF. + // nullptr if absent (some stripped GGUFs omit it; dense path then uses + // target.tok_embd projected through h_post). + ggml_tensor * tok_embd = nullptr; // [n_embd, n_vocab] + // Per-dim RoPE freq factors (assistant's own; for proportional RoPE on full-attn MTP layer). + // Loaded from "rope_freqs.weight" in the assistant GGUF (top-level, NOT per-layer). + // nullptr if absent (legacy GGUFs); MTP graph then falls back to target's per-layer rope_freqs. + ggml_tensor * rope_freqs = nullptr; // [head_dim/2] f32 + // Optional centroid head (Edge models only; nullptr for Dense 31B) + ggml_tensor * centroids = nullptr; // [n_embd, n_centroids] + ggml_tensor * token_ordering = nullptr; // [n_vocab] I32 invariant if present + // MTP transformer layers (always 4 per atomicbot spec) + std::vector layers; + // Metadata + int32_t n_embd = 0; // MTP model's own hidden size (e.g. 1024 for compressed MTP) + int32_t n_embd_backbone = 0; // target backbone hidden size (must match target's n_embd) + int32_t n_centroids = 0; + int32_t centroid_top_k = 0; + bool use_ordered_embeddings = false; + bool attention_k_eq_v = false; + std::string requires_target_arch; + // Backend that owns the tensors + ggml_backend_t backend = nullptr; + ggml_context * ctx = nullptr; + ggml_backend_buffer_t buffer = nullptr; +}; + +// Load Gemma4 MTP assistant weights from a GGUF file. +// The loader reads n_embd_backbone from GGUF metadata and resolves each MTP +// layer's donor target KV layer assuming Dense 31B (60 target layers, alternating +// SWA pattern: odd-indexed = SWA, even-indexed = full attention). +bool load_gemma4_mtp_assistant(const std::string & gguf_path, + ggml_backend_t backend, + MtpDrafterWeights & out); + +void free_gemma4_mtp_assistant(MtpDrafterWeights & w); + +// Read only the MTP SWA layer pattern from the GGUF (lightweight — no tensor loading). 
+// Returns false if the GGUF can't be opened or lacks the required architecture. +// out_mtp_swa_layers[il] = true if MTP layer il uses sliding-window attention. +bool get_mtp_swa_pattern(const std::string & gguf_path, + std::vector & out_mtp_swa_layers); + +// Re-resolve MTP donor layers using the actual target SWA pattern instead of the +// hardcoded alternating assumption used during loading. Call this after both the +// target model and MTP assistant are loaded, passing the target's swa_layers vector. +// Each MTP layer's donor_target_layer is updated to the LAST target layer whose +// SWA type matches the MTP layer's SWA type per the provided pattern. +void resolve_mtp_donor_layers(MtpDrafterWeights & mtp, + const std::vector & target_swa_layers); + +// ─── Gemma4 MTP step graph ──────────────────────────────────────────────────── +// +// Build a single MTP step graph that maps: +// inputs: in_tok (i32 [1]) — last token id +// in_h_prev (f32 [n_embd_backbone, 1]) — last target full-attn hidden +// in_pos (i32 [1]) — absolute target position for RoPE +// outputs: out_logits (f32 [n_vocab, 1]) — full vocab row +// out_h_post (f32 [n_embd_backbone, 1]) — next h_prev +// out_argmax (i32 [1]) — greedy token (in-graph argmax) +// +// Each MTP layer reads target K/V from w.layers[il].donor_target_layer +// (resolved at load time). V always read from cache (attention_k_eq_v quirk). +// KV mask is nullptr: all committed positions ≤ attn_pos are uniformly admitted. +// +// attn_pos is the number of committed target tokens (cache.cur_pos at call time). +// The caller passes it separately because the graph is rebuilt per-step in the +// chained γ loop (attn_pos is constant across steps, pos advances per step). 
+struct MtpStepGraph { + ggml_context * ctx = nullptr; + ggml_cgraph * gf = nullptr; + // Inputs (caller sets via ggml_backend_tensor_set before each step) + ggml_tensor * in_tok = nullptr; // I32 [1] — the token id (unused in graph; kept for API compat) + ggml_tensor * in_tok_embd = nullptr; // F32 [n_embd_backbone, 1] — pre-dequantised embedding + ggml_tensor * in_h_prev = nullptr; + ggml_tensor * in_pos = nullptr; + // Single FA mask shared across all MTP layers that need padding (currently + // every TQ3_0 layer with non-256-aligned kv_view_len, and every head_dim≥512 + // layer with non-256-aligned kv_view_len). The builder asserts at compile + // time that every need-mask layer wants the same `(width, kv_seq_len)`; if + // they ever diverge (e.g. SWA window cap < full-attn pos in long context) + // the assert fires and the builder must be extended to per-layer masks. + // Caller must fill before each compute: + // positions [0..fa_mask_kv_seq_len-1]: 0x0000 (F16 0.0 = admit) + // positions [fa_mask_kv_seq_len..width-1]: 0xFC00 (F16 -inf = exclude) + ggml_tensor * fa_mask = nullptr; // F16 [width, 1] or null + int64_t fa_mask_kv_seq_len = 0; + // Outputs (caller reads via ggml_backend_tensor_get after compute) + ggml_tensor * out_logits = nullptr; + ggml_tensor * out_h_post = nullptr; + ggml_tensor * out_argmax = nullptr; +}; + +// Build the MTP step graph. attn_pos = cache.cur_pos at submit time. +// Returns false and sets last_error on failure. +bool build_mtp_step_graph(const MtpDrafterWeights & w, + const GemmaTargetCache & target_cache, + const GemmaTargetWeights & target, + MtpStepGraph & out, + int attn_pos); + +// Free the ggml context owned by the graph (tensors only; backend buffers +// for KV views are owned by target_cache and must not be freed here). +void free_mtp_step_graph(MtpStepGraph & g); + +// Load Gemma4 DFlash draft weights from a directory containing safetensors shards. 
+bool load_gemma4_draft_safetensors(const std::string & dir_path, + ggml_backend_t backend, + GemmaDraftWeights & out); + +// Load Gemma4 DFlash draft weights from a Q8_0-quantized GGUF file. +bool load_gemma4_draft_gguf(const std::string & path, + ggml_backend_t backend, + GemmaDraftWeights & out); + +void free_gemma4_draft_weights(GemmaDraftWeights & w); + +// Allocate draft KV cache tensors on the given backend. +bool create_draft_kv_cache(const GemmaDraftWeights & dw, + ggml_backend_t backend, + GemmaTargetCache & cache, + int cap_override = 0); +void free_draft_kv_cache(GemmaTargetCache & cache); + +// Build graph that projects target features → draft KV cache (prefix-direct). +// Materializes K,V for n_tokens new positions starting at cache.draft_kv_pos. +// target_feat [6*target_hidden, n_tokens] f32 +// positions [n_tokens] i32 (absolute positions for RoPE) +ggml_tensor * build_draft_kv_prefill_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaDraftWeights & w, + GemmaTargetCache & cache, + ggml_tensor * target_feat, + ggml_tensor * positions, + int n_tokens); + +// Build the Gemma4 draft model forward graph with KV cache attention. +// draft_embed [draft_hidden, n_tokens] f32 (MASK token embeddings) +// positions [n_tokens] i32 (absolute positions) +// attn_mask [kv_pad, q_pad] f16 (causal over context+block) +// kv_start = cache.draft_kv_pos (context length before this block) +// Returns logits [n_vocab, n_tokens] f32 (softcapped). 
+ggml_tensor * build_gemma4_draft_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const GemmaDraftWeights & w, + GemmaTargetCache & cache, + ggml_tensor * draft_embed, + ggml_tensor * positions, + ggml_tensor * attn_mask, + int n_tokens, + int kv_start); + } // namespace dflash27b #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIP) diff --git a/dflash/src/pflash_ggml_adapter.cpp b/dflash/src/pflash_ggml_adapter.cpp new file mode 100644 index 000000000..4862d379b --- /dev/null +++ b/dflash/src/pflash_ggml_adapter.cpp @@ -0,0 +1,33 @@ +#include "flashprefill.h" + +// Forward-declare the registration function from ggml-cuda (defined in fattn-sparse.cu). +// No extern "C" — nvcc compiles .cu as C++ and the symbol has C++ linkage. +void ggml_cuda_flash_attn_sparse_set_kernel( + int (*fn)(const void*, const void*, const void*, void*, + int, int, int, int, int, float, float)); + +static int pflash_adapter( + const void * Q, const void * K, const void * V, void * O, + int batch, int seq_len, int n_q_heads, int n_k_heads, int head_dim, + float scale, float alpha) +{ + dflash27b::flashprefill::FlashPrefillConfig cfg; + if (alpha >= 1.0f) { + // alpha >= 1.0 means "select all blocks" — configure for dense attention + cfg.alpha = 0.0f; + cfg.attention_sink = seq_len; // all blocks are "sinks" + cfg.window = seq_len; // window covers everything + cfg.last_n_full = seq_len; // all query blocks attend fully + } else { + cfg.alpha = alpha; + } + return dflash27b::flashprefill::flash_prefill_forward_bf16( + Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, + scale, cfg); +} + +// Call this once at init time before running any ggml_flash_attn_sparse graphs. 
+void pflash_register_ggml_kernel() { + ggml_cuda_flash_attn_sparse_set_kernel(&pflash_adapter); +} diff --git a/dflash/src/pflash_ggml_adapter.h b/dflash/src/pflash_ggml_adapter.h new file mode 100644 index 000000000..f3bd5c56e --- /dev/null +++ b/dflash/src/pflash_ggml_adapter.h @@ -0,0 +1,2 @@ +#pragma once +void pflash_register_ggml_kernel(); diff --git a/dflash/test/gemma4/smoke_gemma4_draft_forward.cpp b/dflash/test/gemma4/smoke_gemma4_draft_forward.cpp new file mode 100644 index 000000000..4d963eb9c --- /dev/null +++ b/dflash/test/gemma4/smoke_gemma4_draft_forward.cpp @@ -0,0 +1,310 @@ +// Smoke test: load Gemma4 DFlash draft weights, build a forward graph with +// synthetic inputs, run on CUDA, and validate logits. +// +// Usage: smoke_gemma4_draft_forward + +#include "internal.h" +#include "gemma4.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static void fail(const char * msg) { + std::fprintf(stderr, "FAIL: %s\n", msg); + std::exit(1); +} + +int main(int argc, char ** argv) { + if (argc < 3) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { std::fprintf(stderr, "cuda init failed\n"); return 1; } + + GemmaDraftWeights dw; + if (!load_gemma4_draft_safetensors(argv[1], backend, dw)) { + std::fprintf(stderr, "load_gemma4_draft_safetensors: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + + // Load target to get tok_embd (shared between target and draft for LM head). + // tok_embd is not in the draft safetensors — it must come from the target at runtime. + // The target loader keeps tok_embd CPU-side (CpuEmbedder / mmap) to avoid uploading + // ~400 MiB to VRAM for every inference. For this smoke test we upload it once. 
+ GemmaTargetWeights tw; + if (!load_gemma4_target_gguf(argv[2], backend, tw)) { + std::fprintf(stderr, "load_gemma4_target_gguf: %s\n", dflash27b_last_error()); + free_gemma4_draft_weights(dw); + ggml_backend_free(backend); + return 1; + } + + // tw.tok_embd is metadata-only (data = nullptr); actual bytes live in tw.embedder. + // Allocate a dedicated GPU tensor for tok_embd and upload the quantized bytes. + ggml_context * tok_embd_ctx = nullptr; + ggml_backend_buffer_t tok_embd_buf = nullptr; + { + ggml_init_params ep{}; + ep.mem_size = ggml_tensor_overhead() * 2; + ep.mem_buffer = nullptr; + ep.no_alloc = true; + tok_embd_ctx = ggml_init(ep); + if (!tok_embd_ctx) { + std::fprintf(stderr, "ggml_init for tok_embd failed\n"); + free_gemma4_target_weights(tw); + free_gemma4_draft_weights(dw); + ggml_backend_free(backend); + return 1; + } + + const ggml_type emb_type = tw.embedder.tok_embd_type; + const int64_t n_embd_t = tw.embedder.n_embd; + const int64_t n_vocab_t = tw.embedder.n_vocab; + + // ggml convention: ne[0] = n_embd (fast axis), ne[1] = n_vocab + ggml_tensor * te = ggml_new_tensor_2d(tok_embd_ctx, emb_type, n_embd_t, n_vocab_t); + ggml_set_name(te, "tok_embd_gpu"); + + tok_embd_buf = ggml_backend_alloc_ctx_tensors(tok_embd_ctx, backend); + if (!tok_embd_buf) { + std::fprintf(stderr, "ggml_backend_alloc_ctx_tensors for tok_embd failed\n"); + ggml_free(tok_embd_ctx); + free_gemma4_target_weights(tw); + free_gemma4_draft_weights(dw); + ggml_backend_free(backend); + return 1; + } + + const size_t emb_bytes = (size_t)tw.embedder.row_bytes * (size_t)n_vocab_t; + ggml_backend_tensor_set(te, tw.embedder.tok_embd_bytes, 0, emb_bytes); + std::printf("[tok_embd] uploaded %.1f MiB to GPU (%s [%" PRId64 ", %" PRId64 "])\n", + (double)emb_bytes / (1024.0 * 1024.0), + ggml_type_name(emb_type), n_embd_t, n_vocab_t); + + dw.tok_embd = te; + dw.n_vocab = (int)n_vocab_t; + } + + std::printf("[draft] n_layer=%d n_head=%d n_embd=%d n_vocab=%d target_hidden=%d\n", + 
dw.n_layer, dw.n_head, dw.n_embd, dw.n_vocab, dw.target_hidden); + + const int n_tokens = 16; // one block + const int target_feat_w = dw.n_target_layers * dw.target_hidden; // 6*4096 = 24576 + const int draft_hidden = dw.n_embd; + const int n_vocab = dw.n_vocab; + const int kq_mask_pad = 32; + + auto align_up = [](int x, int a) { return ((x + a - 1) / a) * a; }; + + // Allocate draft KV cache + GemmaTargetCache cache; + cache.backend = backend; + if (!create_draft_kv_cache(dw, backend, cache)) { + std::fprintf(stderr, "create_draft_kv_cache failed\n"); + return 1; + } + std::printf("[draft kv] cap=%d\n", cache.draft_kv_cap); + + // ── Step 1: Prefill draft KV with synthetic target features ────── + // Simulate n_tokens context positions with random target features + { + ggml_init_params ip{}; + ip.mem_size = 256 * 1024 * 1024; + ip.no_alloc = true; + ggml_context * pctx = ggml_init(ip); + if (!pctx) { fail("ggml_init for prefill failed"); } + + ggml_tensor * pf_target_feat = ggml_new_tensor_2d(pctx, GGML_TYPE_F32, target_feat_w, n_tokens); + ggml_tensor * pf_positions = ggml_new_tensor_1d(pctx, GGML_TYPE_I32, n_tokens); + ggml_set_name(pf_target_feat, "pf_target_feat"); + ggml_set_name(pf_positions, "pf_positions"); + ggml_set_input(pf_target_feat); + ggml_set_input(pf_positions); + + ggml_cgraph * pf_gf = ggml_new_graph_custom(pctx, 4096, false); + build_draft_kv_prefill_graph(pctx, pf_gf, dw, cache, + pf_target_feat, pf_positions, n_tokens); + + ggml_gallocr_t pf_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(pf_alloc, pf_gf)) { fail("prefill alloc failed"); } + + std::mt19937 rng_pf(42); + std::uniform_real_distribution u_pf(-0.05f, 0.05f); + { + std::vector data((size_t)target_feat_w * n_tokens); + for (auto & v : data) v = u_pf(rng_pf); + ggml_backend_tensor_set(pf_target_feat, data.data(), 0, sizeof(float) * data.size()); + } + { + std::vector pos(n_tokens); + for (int i = 0; i < n_tokens; i++) pos[i] 
= i; + ggml_backend_tensor_set(pf_positions, pos.data(), 0, sizeof(int32_t) * n_tokens); + } + + auto st = ggml_backend_graph_compute(backend, pf_gf); + if (st != GGML_STATUS_SUCCESS) { fail("prefill compute failed"); } + cache.draft_kv_pos = n_tokens; + std::printf("[prefill] KV materialized for %d positions\n", n_tokens); + + ggml_gallocr_free(pf_alloc); + ggml_free(pctx); + } + + // ── Step 2: Draft forward with KV cache ────────────────────────── + const int kv_start = cache.draft_kv_pos; // context length = n_tokens + + ggml_init_params ip{}; + ip.mem_size = 256 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * gctx = ggml_init(ip); + if (!gctx) { std::fprintf(stderr, "ggml_init failed\n"); return 1; } + + ggml_tensor * draft_embed = ggml_new_tensor_2d(gctx, GGML_TYPE_F32, draft_hidden, n_tokens); + ggml_tensor * positions = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, n_tokens); + const int kv_len = kv_start + n_tokens; + const int kv_pad = align_up(kv_len, kq_mask_pad); + const int q_pad = align_up(n_tokens, kq_mask_pad); + ggml_tensor * attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, kv_pad, q_pad); + + ggml_set_name(draft_embed, "draft_embed"); + ggml_set_name(positions, "positions"); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(draft_embed); + ggml_set_input(positions); + ggml_set_input(attn_mask); + + ggml_cgraph * gf = ggml_new_graph_custom(gctx, 8192, false); + ggml_tensor * logits = build_gemma4_draft_graph( + gctx, gf, dw, cache, + draft_embed, positions, attn_mask, + n_tokens, kv_start); + if (!logits) { std::fprintf(stderr, "build_gemma4_draft_graph returned null\n"); return 1; } + ggml_set_output(logits); + ggml_build_forward_expand(gf, logits); + std::printf("[graph] nodes=%d\n", ggml_graph_n_nodes(gf)); + + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "ggml_gallocr_alloc_graph failed\n"); + 
return 1; + } + + // Fill inputs with deterministic random data + std::mt19937 rng(1234); + std::uniform_real_distribution u(-0.05f, 0.05f); + + // draft_embed: [draft_hidden, 16] f32 + { + std::vector data((size_t)draft_hidden * n_tokens); + for (auto & v : data) v = u(rng); + ggml_backend_tensor_set(draft_embed, data.data(), 0, sizeof(float) * data.size()); + } + // positions: [kv_start, kv_start+1, ..., kv_start+15] + { + std::vector pos(n_tokens); + for (int i = 0; i < n_tokens; i++) pos[i] = kv_start + i; + ggml_backend_tensor_set(positions, pos.data(), 0, sizeof(int32_t) * n_tokens); + } + // attn_mask: causal over full kv_len, block queries attend to all context + causal within block + { + const ggml_fp16_t zero_h = ggml_fp32_to_fp16(0.0f); + const ggml_fp16_t ninf_h = ggml_fp32_to_fp16(-INFINITY); + std::vector mask((size_t)kv_pad * q_pad, ninf_h); + for (int q = 0; q < n_tokens; q++) { + int max_kv = kv_start + q; // attend to all context + block[0..q] + for (int k = 0; k <= max_kv; k++) { + mask[(size_t)q * kv_pad + k] = zero_h; + } + } + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(ggml_fp16_t) * mask.size()); + } + + // Compute + auto status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "ggml_backend_graph_compute failed: %d\n", (int)status); + return 1; + } + std::printf("[compute] OK\n"); + + // Validate expected output shape + if (logits->ne[0] != (int64_t)n_vocab || logits->ne[1] != (int64_t)n_tokens) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "logits shape [%" PRId64 ", %" PRId64 "] expected [%d, %d]", + logits->ne[0], logits->ne[1], n_vocab, n_tokens); + fail(buf); + } + std::printf("[logits] shape: [%" PRId64 ", %" PRId64 "]\n", + logits->ne[0], logits->ne[1]); + + // Read logits for position 0 + std::vector logit_buf((size_t)n_vocab * n_tokens); + ggml_backend_tensor_get(logits, logit_buf.data(), 0, sizeof(float) * logit_buf.size()); + + // Check for NaN and 
softcap bounds across all positions + int n_nan = 0, n_oob = 0; + float vmin = 1e30f, vmax = -1e30f; + for (auto v : logit_buf) { + if (std::isnan(v)) { n_nan++; continue; } + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + if (v < -30.0f || v > 30.0f) n_oob++; + } + std::printf("[logits] nan=%d oob=%d min=%.4g max=%.4g\n", + n_nan, n_oob, vmin, vmax); + + if (n_nan > 0) fail("NaN values in logits"); + if (n_oob > 0) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "%d logit values outside [-30, 30] softcap bounds", n_oob); + fail(buf); + } + + // Top-5 tokens for position 0 + const float * pos0_logits = logit_buf.data(); + std::vector> sorted; + sorted.reserve((size_t)n_vocab); + for (int i = 0; i < n_vocab; i++) sorted.emplace_back(pos0_logits[i], i); + std::partial_sort(sorted.begin(), sorted.begin() + 5, sorted.end(), + [](const auto & a, const auto & b) { return a.first > b.first; }); + std::printf("[top 5 pos=0]"); + for (int i = 0; i < 5; i++) { + std::printf(" id=%d l=%.3f", sorted[i].second, sorted[i].first); + } + std::printf("\n"); + + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_draft_kv_cache(cache); + // dw.tok_embd points into tok_embd_ctx/buf — null it before freeing the draft + // so free_gemma4_draft_weights doesn't double-free or access freed memory. + dw.tok_embd = nullptr; + free_gemma4_draft_weights(dw); + // Free tok_embd GPU allocation (must outlive the compute graph). + if (tok_embd_buf) ggml_backend_buffer_free(tok_embd_buf); + if (tok_embd_ctx) ggml_free(tok_embd_ctx); + // Target weights own the mmap that backs tok_embd_bytes; free after GPU upload. 
+ free_gemma4_target_weights(tw); + ggml_backend_free(backend); + std::printf("PASS\n"); + return 0; +} diff --git a/dflash/test/gemma4/smoke_gemma4_target_forward.cpp b/dflash/test/gemma4/smoke_gemma4_target_forward.cpp new file mode 100644 index 000000000..596cd790e --- /dev/null +++ b/dflash/test/gemma4/smoke_gemma4_target_forward.cpp @@ -0,0 +1,198 @@ +// Smoke test: load Gemma4 target, run a single-token forward pass, validate logits. +// +// Usage: smoke_gemma4_target_forward + +#include "internal.h" +#include "gemma4.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static void fail(const char * msg) { + std::fprintf(stderr, "FAIL: %s\n", msg); + std::exit(1); +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { std::fprintf(stderr, "cuda init failed\n"); return 1; } + + // Load target weights + GemmaTargetWeights w; + if (!load_gemma4_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_gemma4_target_gguf: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + std::printf("[target] n_layer=%d n_embd=%d n_vocab=%d\n", + w.n_layer, w.n_embd, w.n_vocab); + + // Create target cache + GemmaTargetCache cache; + const int max_ctx = 512; + if (!create_gemma4_cache(w, max_ctx, backend, cache)) { + std::fprintf(stderr, "create_gemma4_cache: %s\n", dflash27b_last_error()); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + return 1; + } + std::printf("[cache] created max_ctx=%d kv_layers=%zu\n", + cache.max_ctx, cache.attn_k.size()); + + // Build graph context + ggml_init_params ip{}; + ip.mem_size = 512 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * gctx = ggml_init(ip); + if 
(!gctx) { std::fprintf(stderr, "ggml_init failed\n"); return 1; } + + // Input tensors for a single token at position 0 + const int n_tokens = 1; + const int hidden = w.n_embd; + const int kv_start = 0; + + // Gemma4 uses 1D positions (not M-RoPE with 4 values like Qwen) + ggml_tensor * inp_embed = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, hidden, n_tokens, 1); + ggml_tensor * positions = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, n_tokens); + ggml_set_name(inp_embed, "inp_embed"); + ggml_set_name(positions, "positions"); + ggml_set_input(inp_embed); + ggml_set_input(positions); + + // CUDA flash attention for head_dim>=512 (Gemma4-26B has head_dim=512 on full-attn + // layers) requires a non-null mask so the GQA optimisation path is taken. + // Provide a causal attention mask: shape [kv_len_padded, n_tokens], F32. + // Entries are 0.0 for positions we attend to and -INF for positions we don't. + const int kv_len = kv_start + n_tokens; // 1 + const int kv_len_padded = ((kv_len + 255) / 256) * 256; // 256 + ggml_tensor * attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, kv_len_padded, n_tokens); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + GemmaGraphInputs gi{}; + gi.inp_embed = inp_embed; + gi.positions = positions; + gi.attn_mask = attn_mask; + gi.n_tokens = n_tokens; + gi.kv_start = kv_start; + gi.capture_layers = true; + + // Build graph + ggml_cgraph * gf = ggml_new_graph_custom(gctx, 16384, false); + GemmaGraphOutputs go = build_gemma4_graph(gctx, gf, w, cache, gi); + if (!go.logits) { std::fprintf(stderr, "build_gemma4_graph returned null logits\n"); return 1; } + ggml_set_output(go.logits); + ggml_build_forward_expand(gf, go.logits); + std::printf("[graph] nodes=%d\n", ggml_graph_n_nodes(gf)); + + // Allocate graph memory + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "ggml_gallocr_alloc_graph failed\n"); + return 1; + } + 
+ // Fill causal attention mask (F16). + // mask[k, q] = 0.0 if k <= q (position k is visible from query q) + // = -INF otherwise (masked out / padding) + { + const ggml_fp16_t zero_h = ggml_fp32_to_fp16(0.0f); + const ggml_fp16_t ninf_h = ggml_fp32_to_fp16(-INFINITY); + std::vector mask_data((size_t)kv_len_padded * n_tokens, ninf_h); + for (int q = 0; q < n_tokens; q++) { + for (int k = 0; k <= kv_start + q; k++) { + mask_data[(size_t)q * kv_len_padded + k] = zero_h; + } + } + ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, + sizeof(ggml_fp16_t) * mask_data.size()); + } + + // Embed token id=2 (BOS) using the CPU embedder + int32_t bos_id = 2; + std::vector embed_buf((size_t)hidden * n_tokens); + if (!w.embedder.embed(&bos_id, n_tokens, embed_buf.data())) { + std::fprintf(stderr, "embedder.embed failed\n"); + return 1; + } + ggml_backend_tensor_set(inp_embed, embed_buf.data(), 0, sizeof(float) * embed_buf.size()); + + // Position 0 + int32_t pos0 = 0; + ggml_backend_tensor_set(positions, &pos0, 0, sizeof(int32_t)); + + // Compute + auto status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "ggml_backend_graph_compute failed: %d\n", (int)status); + return 1; + } + std::printf("[compute] OK\n"); + + // Read logits back + const int64_t vocab = w.n_vocab; + std::vector logits((size_t)vocab); + ggml_backend_tensor_get(go.logits, logits.data(), 0, sizeof(float) * vocab); + + // Check for NaN / Inf and validate softcap bounds + int n_nan = 0, n_inf = 0, n_oob = 0; + float vmin = 1e30f, vmax = -1e30f; + for (auto v : logits) { + if (std::isnan(v)) { n_nan++; continue; } + if (std::isinf(v)) { n_inf++; continue; } + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + // Logit softcap = 30.0 means values are in (-30, 30) + if (v < -30.0f || v > 30.0f) n_oob++; + } + std::printf("[logits] vocab=%" PRId64 " nan=%d inf=%d oob=%d min=%.4g max=%.4g\n", + vocab, n_nan, n_inf, n_oob, vmin, vmax); + + if (n_nan > 0) 
fail("NaN values in logits"); + if (n_inf > 0) fail("Inf values in logits"); + if (n_oob > 0) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "%d logit values outside [-30, 30] softcap bounds", n_oob); + fail(buf); + } + + // Print top-5 tokens + std::vector> sorted; + sorted.reserve((size_t)vocab); + for (int i = 0; i < (int)vocab; i++) sorted.emplace_back(logits[i], i); + std::partial_sort(sorted.begin(), sorted.begin() + 5, sorted.end(), + [](const auto & a, const auto & b) { return a.first > b.first; }); + std::printf("[top 5]"); + for (int i = 0; i < 5; i++) { + std::printf(" id=%d l=%.3f", sorted[i].second, sorted[i].first); + } + std::printf("\n"); + + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_gemma4_cache(cache); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + std::printf("PASS\n"); + return 0; +} diff --git a/dflash/test/gemma4/smoke_load_gemma4_draft.cpp b/dflash/test/gemma4/smoke_load_gemma4_draft.cpp new file mode 100644 index 000000000..1cc7d6389 --- /dev/null +++ b/dflash/test/gemma4/smoke_load_gemma4_draft.cpp @@ -0,0 +1,113 @@ +// Smoke test: load Gemma4 DFlash draft weights from a safetensors directory. 
+// +// Usage: smoke_load_gemma4_draft + +#include "internal.h" +#include "gemma4.h" + +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static void fail(const char * msg) { + std::fprintf(stderr, "FAIL: %s\n", msg); + std::exit(1); +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { + std::fprintf(stderr, "ggml_backend_cuda_init(0) failed\n"); + return 1; + } + std::printf("cuda backend: %s\n", ggml_backend_name(backend)); + + GemmaDraftWeights dw; + if (!load_gemma4_draft_safetensors(argv[1], backend, dw)) { + std::fprintf(stderr, "load_gemma4_draft_safetensors failed: %s\n", + dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + + // Print loaded metadata + std::printf("n_layer=%d n_head=%d n_head_kv=%d head_dim=%d n_embd=%d n_ff=%d n_vocab=%d\n", + dw.n_layer, dw.n_head, dw.n_head_kv, dw.head_dim, + dw.n_embd, dw.n_ff, dw.n_vocab); + std::printf("n_target_layers=%d target_hidden=%d logit_softcap=%.1f\n", + dw.n_target_layers, dw.target_hidden, dw.logit_softcap); + + // Assert expected draft topology + if (dw.n_layer != 5) { + char buf[64]; + std::snprintf(buf, sizeof(buf), "n_layer=%d expected 5", dw.n_layer); + fail(buf); + } + if (dw.n_vocab != 262144) { + char buf[64]; + std::snprintf(buf, sizeof(buf), "n_vocab=%d expected 262144", dw.n_vocab); + fail(buf); + } + if (!dw.fc) fail("fc is null"); + + // Validate fc shape: ne[0] = 6*target_hidden (input features), ne[1] = draft_hidden (output) + // In ggml convention: ne[0] is the fast (inner) dimension of matrix multiply, + // so fc has ne[0]=6*target_hidden and ne[1]=draft_hidden. 
+ const int64_t expected_fc_ne0 = (int64_t)dw.n_target_layers * dw.target_hidden; + std::printf("fc: ne=[%" PRId64 ", %" PRId64 "] type=%s\n", + dw.fc->ne[0], dw.fc->ne[1], + ggml_type_name(dw.fc->type)); + if (dw.fc->ne[0] != expected_fc_ne0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "fc->ne[0]=%" PRId64 " expected %" PRId64 " (n_target_layers=%d * target_hidden=%d)", + dw.fc->ne[0], expected_fc_ne0, dw.n_target_layers, dw.target_hidden); + fail(buf); + } + + // Assert layers vector size + if ((int)dw.layers.size() != dw.n_layer) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "layers.size()=%zu expected %d", dw.layers.size(), dw.n_layer); + fail(buf); + } + + // Spot-check layer 0 key tensors + if (!dw.layers[0].wq) fail("layers[0].wq is null"); + if (!dw.layers[0].wk) fail("layers[0].wk is null"); + if (!dw.layers[0].w_gate) fail("layers[0].w_gate is null"); + + // Print layer 0 shape as spot check + std::printf("layers[0].wq: ne=[%" PRId64 ", %" PRId64 "] type=%s\n", + dw.layers[0].wq->ne[0], dw.layers[0].wq->ne[1], + ggml_type_name(dw.layers[0].wq->type)); + + // Validate hidden_norm and out_norm + if (!dw.hidden_norm) fail("hidden_norm is null"); + if (!dw.out_norm) fail("out_norm is null"); + // tok_embd is NOT loaded from the draft safetensors; it is injected at + // runtime from the target model's token embedding table. 
+ if (dw.tok_embd) fail("tok_embd should be null after loading draft (shared with target)"); + + std::printf("hidden_norm: ne[0]=%" PRId64 " type=%s\n", + dw.hidden_norm->ne[0], ggml_type_name(dw.hidden_norm->type)); + + free_gemma4_draft_weights(dw); + ggml_backend_free(backend); + std::printf("PASS\n"); + return 0; +} diff --git a/dflash/test/gemma4/smoke_load_gemma4_target.cpp b/dflash/test/gemma4/smoke_load_gemma4_target.cpp new file mode 100644 index 000000000..fc70c1dc9 --- /dev/null +++ b/dflash/test/gemma4/smoke_load_gemma4_target.cpp @@ -0,0 +1,115 @@ +// Smoke test: load a Gemma4 target GGUF, validate metadata and tensor shapes. +// +// Usage: smoke_load_gemma4_target + +#include "internal.h" +#include "gemma4.h" + +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static void fail(const char * msg) { + std::fprintf(stderr, "FAIL: %s\n", msg); + std::exit(1); +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { std::fprintf(stderr, "cuda init failed\n"); return 1; } + + GemmaTargetWeights w; + if (!load_gemma4_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_gemma4_target_gguf failed: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + + // Print architecture metadata + std::printf("hparams: n_layer=%d n_embd=%d n_head=%d n_head_kv=%d head_dim=%d " + "n_vocab=%d n_ff=%d\n", + w.n_layer, w.n_embd, w.n_head, w.n_head_kv, w.head_dim, + w.n_vocab, w.n_ff); + + // Count SWA vs full-attention layers + int n_swa = 0, n_full = 0; + for (int il = 0; il < w.n_layer; il++) { + if (il < (int)w.swa_layers.size() && w.swa_layers[il]) n_swa++; + else n_full++; + } + std::printf("swa_layers: swa=%d full=%d (total=%d)\n", n_swa, n_full, w.n_layer); + + // Print KV-sharing config + 
std::printf("kv_sharing: n_kv_shared_layers=%d n_layer_kv=%d\n", + w.n_kv_shared_layers, w.n_layer_kv); + + // Print Per-Layer Embedding dimension + std::printf("n_embd_per_layer=%d\n", w.n_embd_per_layer); + + // Print MoE config (0 for dense) + std::printf("moe: n_expert=%d n_expert_used=%d\n", w.n_expert, w.n_expert_used); + + // Print attention config + std::printf("logit_softcap=%.2f attn_scale=%.4f rope_theta=%.0f\n", + w.logit_softcap, w.attn_scale, w.rope_theta); + + // Print captured layer IDs for the DFlash draft + std::printf("capture_layer_ids:"); + for (int i = 0; i < w.n_capture_layers; i++) { + std::printf(" %d", w.capture_layer_ids[i]); + } + std::printf("\n"); + + // Assertions + if (w.n_vocab != 262144) { + char buf[64]; + std::snprintf(buf, sizeof(buf), "n_vocab=%d expected 262144", w.n_vocab); + fail(buf); + } + if (w.logit_softcap != 30.0f) { + char buf[64]; + std::snprintf(buf, sizeof(buf), "logit_softcap=%.2f expected 30.0", w.logit_softcap); + fail(buf); + } + if (w.n_layer_kv <= 0) { + fail("n_layer_kv must be > 0"); + } + if (w.n_layer_kv > w.n_layer) { + char buf[64]; + std::snprintf(buf, sizeof(buf), "n_layer_kv=%d > n_layer=%d", w.n_layer_kv, w.n_layer); + fail(buf); + } + + // Spot-check layer 0 tensors + if (!w.layers[0].wq) fail("layers[0].wq is null"); + if (!w.layers[0].wo) fail("layers[0].wo is null"); + if (!w.layers[0].w_gate) fail("layers[0].w_gate is null"); + + // Spot-check tok_embd and output + if (!w.tok_embd) fail("tok_embd is null"); + if (!w.output) fail("output (lm_head) is null"); + if (!w.out_norm) fail("out_norm is null"); + + std::printf("tok_embd: ne=[%" PRId64 ", %" PRId64 "] type=%s nbytes=%.2f MiB\n", + w.tok_embd->ne[0], w.tok_embd->ne[1], + ggml_type_name(w.tok_embd->type), + ggml_nbytes(w.tok_embd) / (1024.0 * 1024.0)); + + free_gemma4_target_weights(w); + ggml_backend_free(backend); + std::printf("PASS\n"); + return 0; +} diff --git a/dflash/test/gemma4/test_gemma4_dflash.cpp 
b/dflash/test/gemma4/test_gemma4_dflash.cpp new file mode 100644 index 000000000..eb62e738b --- /dev/null +++ b/dflash/test/gemma4/test_gemma4_dflash.cpp @@ -0,0 +1,2962 @@ +// Gemma4 DFlash speculative decoding end-to-end test / benchmark driver. +// +// Pipeline: +// 1. Load target (Gemma4-31B or 26B-A4B GGUF) + draft (z-lab Gemma4-DFlash +// safetensors directory). +// 2. Prefill: chunked batched forward over prompt tokens (up to swa_window +// tokens per chunk), capture_layers=true so target_feat gets populated. +// 3. Decode loop (until n_predict): +// a. [target-only path, always active] +// Run target forward for last committed token → logits → sample next. +// b. [speculative path, active when draft is loaded] +// i. Get target_feat from cache. +// ii. Run draft model to propose a block of tokens. +// iii. Verify proposals against target in one batched forward. +// iv. Accept longest verified prefix + bonus token, advance cache. +// 4. Print generated text and timing stats. +// +// Usage: +// test_gemma4_dflash --model [--draft ] +// [--prompt ] [--n-predict ] +// [--ctx-size ] [--kv-k ] [--kv-v ] +// [--seed ] [--temp ] [--top-k ] [--top-p ] +// [--budget ] [--gpu ] [--bench] + +#include "internal.h" +#include "dflash27b.h" +#include "gemma4.h" +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" +#include +#include "../src/pflash_ggml_adapter.h" + +#ifdef _WIN32 +#define setenv(name, value, overwrite) _putenv_s(name, value) +#define unsetenv(name) _putenv_s(name, "") +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +using namespace dflash27b; + +// Copy n_tokens rows of width feat_w from a BF16 ring-buffer tensor (src_bf16) +// starting at ring slot ring_slot0 into a contiguous F32 tensor (dst_f32). 
+// Uses ggml_cpy with ggml_view_2d for type conversion on the GPU backend — +// replaces the former dflash27b_launch_bf16_to_f32 custom kernel (f16_convert.cu), +// removed per howard0su's review (r3214289240): ggml_cpy does the same thing. +static void copy_target_feat_bf16_to_f32( + ggml_backend_t backend, + const ggml_tensor * src_bf16, // [feat_w, cap] BF16 (cache.target_feat) + ggml_tensor * dst_f32, // [feat_w, n_tokens] F32 (pkg.target_feat) + int ring_slot0, + int n_tokens, + int feat_w) { + const int cap = (int)src_bf16->ne[1]; + const int pre_n = std::min(n_tokens, cap - ring_slot0); + const int post_n = n_tokens - pre_n; + + ggml_init_params ip{}; + ip.mem_size = 256 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * tmp_ctx = ggml_init(ip); + + ggml_cgraph * gf = ggml_new_graph(tmp_ctx); + + // ggml_view_2d wants non-const but we promise not to mutate the source. + ggml_tensor * src_bf16_nc = const_cast(src_bf16); + // Pre-wrap segment: rows [ring_slot0 .. 
// Index of the largest value in x[0..n-1]. Ties resolve to the lowest index
// because only a strictly greater value replaces the running best.
// x[0] is read unconditionally, so the caller must supply at least one element.
static int argmax_f32(const float * x, int n) {
    int   arg_best = 0;
    float val_best = x[0];
    int   idx      = 1;
    while (idx < n) {
        const float cur = x[idx];
        if (cur > val_best) {
            val_best = cur;
            arg_best = idx;
        }
        ++idx;
    }
    return arg_best;
}
std::unordered_set seen; + for (int i = from; i < (int)history.size(); i++) seen.insert(history[i]); + for (auto & c : cand) { + if (seen.count(c.second)) { + c.first = (c.first > 0.0f) ? c.first / cfg.rep_pen + : c.first * cfg.rep_pen; + } + } + } + + if (cfg.top_k > 0 && cfg.top_k < vocab) { + std::partial_sort(cand.begin(), cand.begin() + cfg.top_k, cand.end(), + [](const auto & a, const auto & b) { return a.first > b.first; }); + cand.resize(cfg.top_k); + } else { + std::sort(cand.begin(), cand.end(), + [](const auto & a, const auto & b) { return a.first > b.first; }); + } + + const float inv_t = 1.0f / std::max(1e-3f, cfg.temp); + float maxv = cand.front().first * inv_t; + double Z = 0.0; + std::vector probs(cand.size()); + for (size_t i = 0; i < cand.size(); i++) { + probs[i] = std::exp(cand[i].first * inv_t - maxv); + Z += probs[i]; + } + for (auto & p : probs) p = (float)(p / Z); + + if (cfg.top_p > 0.0f && cfg.top_p < 1.0f) { + double cum = 0.0; + size_t cut = probs.size(); + for (size_t i = 0; i < probs.size(); i++) { + cum += probs[i]; + if (cum >= cfg.top_p) { cut = i + 1; break; } + } + probs.resize(cut); + cand.resize(cut); + double zz = 0.0; + for (auto p : probs) zz += p; + for (auto & p : probs) p = (float)(p / zz); + } + + std::uniform_real_distribution u(0.0, 1.0); + double r = u(rng); + double acc = 0.0; + for (size_t i = 0; i < probs.size(); i++) { + acc += probs[i]; + if (r <= acc) return cand[i].second; + } + return cand.back().second; +} + +// ─── Causal mask builder ────────────────────────────────────────────────── + +static void build_causal_mask(std::vector & out, + int kv_len, int n_tokens, int kv_start) { + const int kv_pad = align_up(kv_len, g_kq_stride_pad); + const int q_pad = align_up(n_tokens, KQ_MASK_PAD); + out.assign((size_t)kv_pad * q_pad, F16_NEG_INF); + for (int q = 0; q < n_tokens; q++) { + const int abs_q = kv_start + q; + for (int k = 0; k <= abs_q && k < kv_len; k++) { + out[(size_t)q * kv_pad + k] = F16_ZERO; + } + } +} 
+ +// ─── SWA causal mask builder (for chunked batched prefill) ─────────────────── +// +// Non-monotonic ring mask. The K view is always the full ring (ring_size slots, +// ring_win_start==0). Slot k_view maps to absolute position via: +// latest_slot = (kv_end - 1) % ring_size +// offset_back = (latest_slot - k_view + ring_size) % ring_size +// abs_k = (kv_end - 1) - offset_back +// +// mask[q_idx][k_view_idx] = 0 (attend) iff: +// abs_k >= (abs_q - swa_window + 1) AND abs_k <= abs_q AND abs_k >= 0 +// else -inf. +static void build_swa_causal_mask(std::vector & out, + int kv_start, + int n_tokens, + int swa_window, + int ring_size, // = swa_view.effective_win_len = swa_ctx_alloc + int kv_end) { // = kv_start + n_tokens + const int kv_pad = align_up(ring_size, g_kq_stride_pad); + const int q_pad = align_up(n_tokens, KQ_MASK_PAD); + out.assign((size_t)kv_pad * q_pad, F16_NEG_INF); + const int latest_slot = ((kv_end - 1) % ring_size + ring_size) % ring_size; + for (int q = 0; q < n_tokens; q++) { + const int abs_q = kv_start + q; + const int q_lo = std::max(0, abs_q - swa_window + 1); + for (int k_view = 0; k_view < ring_size; k_view++) { + const int offset_back = (latest_slot - k_view + ring_size) % ring_size; + const int abs_k = (kv_end - 1) - offset_back; + const bool valid = (abs_k >= q_lo && abs_k <= abs_q && abs_k >= 0); + if (valid) { + out[(size_t)q * kv_pad + k_view] = F16_ZERO; + } + } + } +} + +// ─── Per-step graph state (rebuilt each forward pass since kv_len varies) ─ + +struct StepGraph { + ggml_context * ctx = nullptr; + ggml_cgraph * gf = nullptr; + ggml_gallocr_t alloc = nullptr; + ggml_tensor * inp_embed = nullptr; + ggml_tensor * positions = nullptr; + ggml_tensor * attn_mask = nullptr; + ggml_tensor * swa_mask = nullptr; + ggml_tensor * logits = nullptr; +}; + +static void step_graph_free(StepGraph & sg) { + if (sg.ctx) { ggml_free(sg.ctx); sg.ctx = nullptr; } + sg.gf = nullptr; + sg.inp_embed = nullptr; + sg.positions = nullptr; + sg.attn_mask 
// ─── Draft step graph state ───────────────────────────────────────────────

// Handles for one draft-model forward graph. build_draft_step() repopulates
// everything except `alloc` on each call; draft_step_free() drops the ggml
// context and clears the tensor/graph pointers, while draft_step_destroy()
// additionally frees `alloc`.
struct DraftStepGraph {
    ggml_context *  ctx         = nullptr;  // context holding the tensors and graph below; freed by draft_step_free()
    ggml_cgraph *   gf          = nullptr;  // forward graph created inside ctx
    ggml_gallocr_t  alloc       = nullptr;  // graph allocator; created lazily and kept across rebuilds
    ggml_tensor *   draft_embed = nullptr;  // input, F32 [dw.n_embd, n_tokens]
    ggml_tensor *   positions   = nullptr;  // input, I32 [n_tokens]
    ggml_tensor *   attn_mask   = nullptr;  // input, F16 [kv_pad, q_pad]
    ggml_tensor *   logits      = nullptr;  // output tensor returned by build_gemma4_draft_graph()
};
// Build (or rebuild) the target forward graph for one step and allocate its
// buffers.
//   n_tokens  - number of tokens in this forward (1 for decode, >1 for prefill)
//   kv_start  - index of the first new token in the KV cache
//   with_mask - whether to allocate an attention-mask input (required for n_tokens > 1)
//   capture   - whether to write captured layer features to cache.target_feat
//   use_pflash, pflash_alpha, fa_window, last_token_logits_only
//             - forwarded unchanged to build_gemma4_graph via GemmaGraphInputs
//
// Any graph state left in `sg` from a previous call is released first; the
// gallocr in sg.alloc is created on first use and reused afterwards.
// Returns false if the ggml context cannot be created, if the graph builder
// yields no logits tensor, or if graph allocation fails.
static bool build_gemma4_step(StepGraph & sg,
                              const GemmaTargetWeights & w,
                              GemmaTargetCache & cache,
                              ggml_backend_t backend,
                              int kv_start,
                              int n_tokens,
                              bool with_mask,
                              bool capture,
                              bool use_pflash = false,
                              float pflash_alpha = 0.12f,
                              int fa_window = 0,
                              bool last_token_logits_only = false) {
    // Drop the previous context and tensor pointers (sg.alloc is kept).
    step_graph_free(sg);

    // no_alloc = true: tensor data is placed later by the graph allocator.
    ggml_init_params ip{};
    ip.mem_size   = 512 * 1024 * 1024;
    ip.mem_buffer = nullptr;
    ip.no_alloc   = true;
    sg.ctx = ggml_init(ip);
    if (!sg.ctx) return false;

    // Token embeddings input: F32 [n_embd, n_tokens, 1].
    sg.inp_embed = ggml_new_tensor_3d(sg.ctx, GGML_TYPE_F32, w.n_embd, n_tokens, 1);
    ggml_set_name(sg.inp_embed, "inp_embed");
    ggml_set_input(sg.inp_embed);

    // Position ids input: I32 [n_tokens].
    sg.positions = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, n_tokens);
    ggml_set_name(sg.positions, "positions");
    ggml_set_input(sg.positions);

    if (with_mask) {
        // Mask dimensions mirror the host-side builders: key axis padded to
        // g_kq_stride_pad, query axis padded to KQ_MASK_PAD.
        const int kv_len = kv_start + n_tokens;
        const int kv_pad = align_up(kv_len, g_kq_stride_pad);
        const int q_pad  = align_up(n_tokens, KQ_MASK_PAD);

        sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, kv_pad, q_pad);
        ggml_set_name(sg.attn_mask, "attn_mask");
        ggml_set_input(sg.attn_mask);
        ggml_set_output(sg.attn_mask);  // force gallocr to allocate even if no op references it

        // SWA mask is required for every SWA dispatch — including single-token
        // decode (n_tokens==1). When swa_mask is null, gemma4_target_graph falls
        // back to attn_mask, which is sized for kv_len rather than the SWA window;
        // the resulting dimension mismatch lets FA read past the populated cache
        // region and corrupts attention. Catastrophic with TQ3_0 KV (it amplifies
        // uninitialized-cache noise into a fixed-point repetition loop), benign
        // but technically wrong with Q8_0 KV.
        // NOTE(review): this tensor is only created when with_mask is true, so
        // the guarantee above depends on callers passing with_mask=true for
        // decode as well — confirm at the call sites.
        const SwaView swa_view = compute_swa_view(kv_start, n_tokens,
                                                  w.swa_window, cache.swa_ctx_alloc);
        const int swa_kv_pad = align_up(swa_view.effective_win_len, g_kq_stride_pad);
        sg.swa_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, swa_kv_pad, q_pad);
        ggml_set_name(sg.swa_mask, "swa_mask");
        ggml_set_input(sg.swa_mask);
        ggml_set_output(sg.swa_mask);  // force gallocr to allocate even if no op references it
    }

    // Graph with room for up to 16384 nodes; no gradient bookkeeping.
    sg.gf = ggml_new_graph_custom(sg.ctx, 16384, false);

    // Hand the inputs and per-step options to the model graph builder.
    GemmaGraphInputs gi{};
    gi.inp_embed              = sg.inp_embed;
    gi.positions              = sg.positions;
    gi.attn_mask              = sg.attn_mask;
    gi.swa_mask               = sg.swa_mask;
    gi.n_tokens               = n_tokens;
    gi.kv_start               = kv_start;
    gi.capture_layers         = capture;
    gi.fa_window              = fa_window;
    gi.use_pflash             = use_pflash;
    gi.pflash_alpha           = pflash_alpha;
    gi.last_token_logits_only = last_token_logits_only;

    GemmaGraphOutputs go = build_gemma4_graph(sg.ctx, sg.gf, w, cache, gi);
    if (!go.logits) return false;
    sg.logits = go.logits;
    ggml_set_output(sg.logits);
    ggml_build_forward_expand(sg.gf, sg.logits);

    // The allocator outlives individual graphs; create it once.
    if (!sg.alloc) {
        sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    }
    return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
}
+static bool build_draft_kv_prefill(DraftKVPrefillGraph & pkg, + const GemmaDraftWeights & dw, + GemmaTargetCache & cache, + ggml_backend_t backend, + int n_tokens) { + // Free previous graph state + if (pkg.ctx) { ggml_free(pkg.ctx); pkg.ctx = nullptr; } + pkg.gf = nullptr; + pkg.target_feat = nullptr; + pkg.positions = nullptr; + + const int target_feat_w = dw.n_target_layers * dw.target_hidden; + + ggml_init_params ip{}; + ip.mem_size = 256 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + pkg.ctx = ggml_init(ip); + if (!pkg.ctx) return false; + + pkg.target_feat = ggml_new_tensor_2d(pkg.ctx, GGML_TYPE_F32, target_feat_w, n_tokens); + ggml_set_name(pkg.target_feat, "prefill_target_feat"); + ggml_set_input(pkg.target_feat); + + pkg.positions = ggml_new_tensor_1d(pkg.ctx, GGML_TYPE_I32, n_tokens); + ggml_set_name(pkg.positions, "prefill_positions"); + ggml_set_input(pkg.positions); + + pkg.gf = ggml_new_graph_custom(pkg.ctx, 4096, false); + + build_draft_kv_prefill_graph(pkg.ctx, pkg.gf, dw, cache, + pkg.target_feat, pkg.positions, n_tokens); + + if (!pkg.alloc) { + pkg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + } + return ggml_gallocr_alloc_graph(pkg.alloc, pkg.gf); +} + +// Build a draft model forward graph for one diffusion step. 
// Build a draft model forward graph for one diffusion step.
//
// Inputs created here: draft_embed (F32 [dw.n_embd, n_tokens]), positions
// (I32 [n_tokens]) and attn_mask (F16 [kv_pad, q_pad]). The logits tensor is
// whatever build_gemma4_draft_graph() returns. Graph state left in `dsg` from a
// previous call is released first; dsg.alloc is created on first use and reused.
// Returns false if the ggml context cannot be created, if the draft graph
// builder returns no logits tensor, or if graph allocation fails.
static bool build_draft_step(DraftStepGraph & dsg,
                             const GemmaDraftWeights & dw,
                             GemmaTargetCache & cache,
                             ggml_backend_t backend,
                             int n_tokens,
                             int kv_start) {
    // Drop the previous context and tensor pointers (dsg.alloc is kept).
    draft_step_free(dsg);

    // no_alloc = true: tensor data is placed later by the graph allocator.
    ggml_init_params ip{};
    ip.mem_size   = 256 * 1024 * 1024;
    ip.mem_buffer = nullptr;
    ip.no_alloc   = true;
    dsg.ctx = ggml_init(ip);
    if (!dsg.ctx) return false;

    dsg.draft_embed = ggml_new_tensor_2d(dsg.ctx, GGML_TYPE_F32, dw.n_embd, n_tokens);
    ggml_set_name(dsg.draft_embed, "draft_embed");
    ggml_set_input(dsg.draft_embed);

    dsg.positions = ggml_new_tensor_1d(dsg.ctx, GGML_TYPE_I32, n_tokens);
    ggml_set_name(dsg.positions, "positions");
    ggml_set_input(dsg.positions);

    // Attention mask: block tokens attend to context + block (causal).
    // NOTE(review): the key axis is padded to KQ_MASK_PAD here, whereas the
    // target builder pads it to g_kq_stride_pad — confirm the draft graph never
    // needs the larger stride.
    const int kv_len = kv_start + n_tokens;
    const int kv_pad = align_up(kv_len, KQ_MASK_PAD);
    const int q_pad  = align_up(n_tokens, KQ_MASK_PAD);
    dsg.attn_mask = ggml_new_tensor_2d(dsg.ctx, GGML_TYPE_F16, kv_pad, q_pad);
    ggml_set_name(dsg.attn_mask, "draft_attn_mask");
    ggml_set_input(dsg.attn_mask);

    // Graph with room for up to 8192 nodes; no gradient bookkeeping.
    dsg.gf = ggml_new_graph_custom(dsg.ctx, 8192, false);
    dsg.logits = build_gemma4_draft_graph(
        dsg.ctx, dsg.gf, dw, cache,
        dsg.draft_embed, dsg.positions, dsg.attn_mask,
        n_tokens, kv_start);
    if (!dsg.logits) return false;
    ggml_set_output(dsg.logits);
    ggml_build_forward_expand(dsg.gf, dsg.logits);

    // The allocator outlives individual graphs; create it once.
    if (!dsg.alloc) {
        dsg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    }
    return ggml_gallocr_alloc_graph(dsg.alloc, dsg.gf);
}
// ─── Millisecond monotonic clock ──────────────────────────────────────────

// Current std::chrono::steady_clock reading expressed in milliseconds as a
// double. The clock is monotonic and its epoch is unspecified, so the value is
// only meaningful as a difference between two calls (elapsed time); it is not
// a wall-clock timestamp.
static double now_ms() {
    return std::chrono::duration<double, std::milli>(
        std::chrono::steady_clock::now().time_since_epoch()).count();
}
// ─── Binary token file helper (daemon mode) ──────────────────────────────

// Reads a file of raw int32 values (host byte order) and returns them.
// Returns an empty vector if the file cannot be opened, its size cannot be
// determined, or it holds fewer than four bytes. Trailing bytes that do not
// form a whole int32 are ignored; only whole values are requested from the
// stream so the read can never run past the destination buffer. If the stream
// delivers fewer bytes than expected, the result is truncated to the whole
// values actually read.
static std::vector<int32_t> read_int32_file(const std::string & path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) return {};

    // Opened at end-of-file, so tellg() is the size; it yields -1 on failure.
    const std::streamoff end_pos = f.tellg();
    if (end_pos <= 0) return {};

    const size_t n_vals = (size_t)end_pos / sizeof(int32_t);
    std::vector<int32_t> out(n_vals);
    if (n_vals == 0) return out;

    f.seekg(0);
    f.read((char *)out.data(), (std::streamsize)(n_vals * sizeof(int32_t)));
    if (!f) {
        // Short read: keep only the complete values that arrived.
        out.resize((size_t)f.gcount() / sizeof(int32_t));
    }
    return out;
}
std::string::npos : end - pos)); + float t = 0.0f, tp = 1.0f, rp = 1.0f; + int tk = 0; + unsigned long long sd = 0; + int n = std::sscanf(tok.c_str(), "%f,%f,%d,%f,%llu", + &t, &tp, &tk, &rp, &sd); + if (n < 1) return false; + out.temp = t; + out.top_p = tp; + out.top_k = tk; + out.rep_pen = rp; + out.seed = sd; + return true; +} + +// ─── Main ───────────────────────────────────────────────────────────────── + +static void print_usage(const char * prog) { + std::fprintf(stderr, + "usage: %s --model [options]\n" + "\n" + "Options:\n" + " --model path to Gemma4 GGUF (target, required)\n" + " --draft path to z-lab DFlash safetensors directory (optional)\n" + " --prompt input prompt text (default: \"Hello, world!\")\n" + " --tokens comma-separated prompt token IDs (overrides --prompt)\n" + " --tokens-file read comma-separated token IDs from a file (for long prompts)\n" + " --n-predict max tokens to generate (default: 128)\n" + " --ctx-size max context size (default: 4096)\n" + " --kv-k KV cache K type: f16/q8_0/q4_0/tq3_0 (default: q8_0)\n" + " --kv-v KV cache V type: f16/q8_0/q4_0/tq3_0 (default: q8_0)\n" + " --seed RNG seed (default: 0)\n" + " --temp temperature, 0 = greedy (default: 0.0)\n" + " --top-k top-k sampling, 0 = disabled (default: 0)\n" + " --top-p nucleus sampling (default: 1.0)\n" + " --budget DDTree budget for speculative decoding (default: 22)\n" + " --gpu CUDA device index (default: 0)\n" + " --bench benchmark mode: repeat generation, report statistics\n" + " --fa-window sliding attention window for full layers (0 = full, default: 0)\n" + " --pflash use pFlash prefill for prompts >= 4096 tokens\n" + " --pflash-alpha pFlash block selection threshold (default: 0.12)\n" + " --draft-max DFlash draft block cap (0 = model block_size)\n" + " --draft-max-adaptive enable rolling adaptive DFlash draft cap\n" + " --draft-kv-cap override DFlash drafter KV slots\n" + " --draft-swa-trunc enable per-layer SWA truncation in the draft graph\n" + " (also 
// Rolling controller for the draft block cap. Acceptance statistics are
// accumulated over `window_steps` observations; at the end of each window the
// cap is halved when the speculative fill ratio is below 0.35, doubled when it
// is above 0.78, and left alone otherwise, always staying within
// [min_q, max_q]. Every completed window logs one "[adaptive]" line and resets
// the counters.
struct AdaptiveDraftMax {
    bool enabled = false;
    int current = 0;            // cap currently in effect
    int min_q = 1;              // lower bound for `current`
    int max_q = 0;              // upper bound for `current` (set from block_size)
    int window_steps = 8;       // observations per adaptation window
    int window_accepted = 0;    // speculative tokens accepted in this window
    int window_capacity = 0;    // speculative slots offered in this window
    int window_steps_seen = 0;  // observations recorded in this window

    // Sets the bounds and the starting cap. A non-positive `initial` means
    // "start at block_size"; a positive one is limited to block_size. The
    // result is never below min_q.
    void init(bool on, int initial, int block_size) {
        enabled = on;
        max_q   = block_size;

        int start = block_size;
        if (initial > 0 && initial < block_size) {
            start = initial;
        }
        current = (start < min_q) ? min_q : start;
    }

    // Records one verify step; does nothing while the controller is disabled.
    void observe(int accepted, int q_len, int step_no) {
        if (!enabled) return;

        // accept_n includes the pinned current token. Adapt on speculative
        // next-token fill so dm=1 does not look artificially perfect.
        const int spec_hits  = (accepted > 1) ? accepted - 1 : 0;
        const int spec_slots = (q_len > 2) ? q_len - 1 : 1;
        window_accepted += spec_hits;
        window_capacity += spec_slots;
        ++window_steps_seen;

        const bool window_done =
            window_steps_seen >= window_steps && window_capacity > 0;
        if (!window_done) return;

        const double fill   = (double)window_accepted / (double)window_capacity;
        const int    before = current;

        const bool shrink = fill < 0.35 && current > min_q;
        const bool grow   = !shrink && fill > 0.78 && current < max_q;
        if (shrink) {
            current = std::max(min_q, current / 2);
        } else if (grow) {
            current = std::min(max_q, current * 2);
        }

        if (current == before) {
            std::printf("[adaptive] step=%d fill=%.2f draft_max=%d\n",
                        step_no, fill, current);
        } else {
            std::printf("[adaptive] step=%d fill=%.2f draft_max %d -> %d\n",
                        step_no, fill, before, current);
        }

        window_steps_seen = 0;
        window_capacity   = 0;
        window_accepted   = 0;
    }
};
flag); + std::exit(2); + } + return argv[++i]; + }; + + if (std::strcmp(argv[i], "--model") == 0) model_path = require_next("--model"); + else if (std::strcmp(argv[i], "--draft") == 0) draft_path = require_next("--draft"); + else if (std::strcmp(argv[i], "--prompt") == 0) prompt_text = require_next("--prompt"); + else if (std::strcmp(argv[i], "--tokens") == 0) token_ids_str = require_next("--tokens"); + else if (std::strcmp(argv[i], "--tokens-file") == 0) tokens_file = require_next("--tokens-file"); + else if (std::strcmp(argv[i], "--n-predict") == 0) n_predict = std::atoi(require_next("--n-predict")); + else if (std::strcmp(argv[i], "--ctx-size") == 0) ctx_size = std::atoi(require_next("--ctx-size")); + else if (std::strncmp(argv[i], "--ctx-size=", 11) == 0) ctx_size = std::atoi(argv[i] + 11); + else if (std::strcmp(argv[i], "--max-ctx") == 0) ctx_size = std::atoi(require_next("--max-ctx")); + else if (std::strncmp(argv[i], "--max-ctx=", 10) == 0) ctx_size = std::atoi(argv[i] + 10); + else if (std::strcmp(argv[i], "--kv-k") == 0) kv_k_str = require_next("--kv-k"); + else if (std::strcmp(argv[i], "--kv-v") == 0) kv_v_str = require_next("--kv-v"); + else if (std::strcmp(argv[i], "--seed") == 0) sampler.seed = (uint64_t)std::atoll(require_next("--seed")); + else if (std::strcmp(argv[i], "--temp") == 0) sampler.temp = (float)std::atof(require_next("--temp")); + else if (std::strcmp(argv[i], "--ignore-eos")== 0) g_ignore_eos = true; + else if (std::strcmp(argv[i], "--top-k") == 0) sampler.top_k = std::atoi(require_next("--top-k")); + else if (std::strcmp(argv[i], "--top-p") == 0) sampler.top_p = (float)std::atof(require_next("--top-p")); + else if (std::strcmp(argv[i], "--budget") == 0) ddtree_budget = std::atoi(require_next("--budget")); + else if (std::strcmp(argv[i], "--gpu") == 0) gpu = std::atoi(require_next("--gpu")); + else if (std::strcmp(argv[i], "--fa-window") == 0) fa_window = std::atoi(require_next("--fa-window")); + else if (std::strcmp(argv[i], "--bench") 
== 0) bench_mode = true; + else if (std::strcmp(argv[i], "--daemon") == 0) daemon_mode = true; + else if (std::strcmp(argv[i], "--pflash") == 0) use_pflash = true; + else if (std::strcmp(argv[i], "--pflash-alpha") == 0) pflash_alpha = (float)std::atof(require_next("--pflash-alpha")); + else if (std::strcmp(argv[i], "--draft-max") == 0) draft_max = std::atoi(require_next("--draft-max")); + else if (std::strcmp(argv[i], "--draft-max-adaptive") == 0) draft_max_adaptive = true; + else if (std::strcmp(argv[i], "--draft-kv-cap") == 0) draft_kv_cap_override = std::atoi(require_next("--draft-kv-cap")); + else if (std::strcmp(argv[i], "--draft-swa-trunc") == 0) ::setenv("DFLASH_DRAFT_SWA_TRUNC", "1", 1); + else if (std::strcmp(argv[i], "--mem-diag") == 0) mem_diag = true; + else if (std::strcmp(argv[i], "--mtp") == 0) mtp_path = require_next("--mtp"); + else if (std::strcmp(argv[i], "--gamma") == 0) gamma = std::atoi(require_next("--gamma")); + else if (std::strcmp(argv[i], "--mtp-pos-mode") == 0) { + const char * m = require_next("--mtp-pos-mode"); + if (std::strcmp(m, "const") == 0) mtp_pos_mode = 0; + else if (std::strcmp(m, "incr") == 0) mtp_pos_mode = 1; + else { std::fprintf(stderr, "error: unknown --mtp-pos-mode %s (expected const|incr)\n", m); return 1; } + } + else if (std::strcmp(argv[i], "--draft-method") == 0) { + const char * m = require_next("--draft-method"); + if (std::strcmp(m, "none") == 0) draft_method = DraftMethod::None; + else if (std::strcmp(m, "dflash") == 0) draft_method = DraftMethod::Dflash; + else if (std::strcmp(m, "mtp") == 0) draft_method = DraftMethod::Mtp; + else { std::fprintf(stderr, "error: unknown --draft-method %s\n", m); return 1; } + } + else if (std::strncmp(argv[i], "--stream-fd=", 12) == 0) { + stream_fd = std::atoi(argv[i] + 12); + } + // No-op flags forwarded by server.py for Qwen3 compatibility: + else if (std::strcmp(argv[i], "--fast-rollback") == 0) { /* no-op */ } + else if (std::strcmp(argv[i], "--ddtree") == 0) { /* no-op 
*/ } + else if (std::strncmp(argv[i], "--ddtree-budget=", 16) == 0) { /* no-op */ } + else if (std::strncmp(argv[i], "--ddtree-temp=", 14) == 0) { /* no-op */ } + else if (std::strcmp(argv[i], "--ddtree-no-chain-seed") == 0) { /* no-op */ } + else if (std::strcmp(argv[i], "--help") == 0 || + std::strcmp(argv[i], "-h") == 0) { + print_usage(argv[0]); + return 0; + } else { + std::fprintf(stderr, "warning: unknown argument: %s\n", argv[i]); + } + } + + if (model_path.empty()) { + std::fprintf(stderr, "error: --model is required\n"); + print_usage(argv[0]); + return 2; + } + + // ── Resolve Auto draft method ───────────────────────────────────────── + if (draft_method == DraftMethod::Auto) { + if (!draft_path.empty() && !mtp_path.empty()) { + std::fprintf(stderr, "error: both --draft and --mtp provided; use --draft-method to disambiguate\n"); + return 1; + } else if (!mtp_path.empty()) { + draft_method = DraftMethod::Mtp; + } else if (!draft_path.empty()) { + draft_method = DraftMethod::Dflash; + } else { + draft_method = DraftMethod::None; + } + } + if (draft_method == DraftMethod::Mtp && mtp_path.empty()) { + std::fprintf(stderr, "error: --draft-method mtp requires --mtp \n"); + return 1; + } + if (draft_method == DraftMethod::Dflash && draft_path.empty()) { + std::fprintf(stderr, "error: --draft-method dflash requires --draft \n"); + return 1; + } + + // ── γ>1 MTP plumbing (Phase 1 of wild-growing-ember plan) ──────────── + if (gamma < 1 || gamma > 16) { + std::fprintf(stderr, "error: --gamma must be in [1, 16] (got %d)\n", gamma); + return 1; + } + if (gamma > 1 && draft_method != DraftMethod::Mtp) { + std::fprintf(stderr, "error: --gamma > 1 requires --draft-method mtp\n"); + return 1; + } + if (gamma > 1 && sampler.temp != 0.0f) { + std::fprintf(stderr, "error: --gamma > 1 currently requires greedy decoding (--temp 0); stochastic γ>1 needs Leviathan rescaling and is not yet implemented\n"); + return 1; + } + + const bool have_draft = (draft_method == 
DraftMethod::Dflash); + const bool have_mtp = (draft_method == DraftMethod::Mtp); + + // ── Load token IDs from file if --tokens-file was specified ────────── + if (!tokens_file.empty()) { + FILE * f = fopen(tokens_file.c_str(), "r"); + if (!f) { + std::fprintf(stderr, "error: cannot open tokens file: %s\n", tokens_file.c_str()); + return 1; + } + fseek(f, 0, SEEK_END); + long sz = ftell(f); + rewind(f); + std::string content(sz, '\0'); + fread(&content[0], 1, sz, f); + fclose(f); + token_ids_str = content; + } + + // ── KV type env vars (consumed by create_gemma4_cache → resolve_kv_types) ─ + setenv("DFLASH27B_KV_K", kv_k_str.c_str(), 1); + setenv("DFLASH27B_KV_V", kv_v_str.c_str(), 1); + + // After argv parsing, the KV type may have been chosen via --kv-k tq3_0 / --kv-v tq3_0, + // which sets DFLASH27B_KV_K / DFLASH27B_KV_V env vars. Re-check for TQ3 here so + // g_kq_stride_pad matches the chunked-FA driver's align_up(kv_len, 256); otherwise the + // host-built mask is short and the kernel reads past its end. + auto kv_env_is_tq3 = [](const char * name) { + const char * s = std::getenv(name); + if (!s) return false; + std::string lc; + for (const char * p = s; *p; ++p) lc += (char)std::tolower((unsigned char)*p); + return lc.rfind("tq3", 0) == 0; + }; + // Note: also need to bump g_kq_stride_pad to 256 when head_dim >= 512 + // (Dense 31B full-attn). That check is deferred until after target + // weights are loaded — see "head_dim mask-pad gate" below. + if (kv_env_is_tq3("DFLASH27B_KV_K") || kv_env_is_tq3("DFLASH27B_KV_V")) { + g_kq_stride_pad = 256; + } + + // ── CUDA device validation ──────────────────────────────────────────── + int cuda_device_count = 0; + cudaGetDeviceCount(&cuda_device_count); + if (gpu >= cuda_device_count) { + std::fprintf(stderr, "error: --gpu %d out of range (device_count=%d)\n", + gpu, cuda_device_count); + return 2; + } + cudaSetDevice(gpu); + + // Detect <=24 GiB CUDA devices and emit a runtime warning if VMM is enabled. 
+ // Note: GGML_CUDA_NO_VMM is compile-time only (CMake option that adds + // compile_definitions). Setting it via setenv() at runtime has no effect on + // ggml-cuda — it's not read via getenv. The real safeguard is to rebuild + // with `cmake -DGGML_CUDA_NO_VMM=ON ..`. + { + int dev_count = 0; + if (cudaGetDeviceCount(&dev_count) == cudaSuccess) { + for (int i = 0; i < dev_count; ++i) { + cudaDeviceProp prop{}; + if (cudaGetDeviceProperties(&prop, i) != cudaSuccess) continue; + const size_t gib = (size_t)(prop.totalGlobalMem / (1ull << 30)); +#ifndef GGML_CUDA_NO_VMM + if (gib <= 24) { + std::fprintf(stderr, + "[dflash] WARNING: detected CUDA device %d (%s) with %zu GiB VRAM.\n" + "[dflash] Long-context prefill on <=24 GiB cards is significantly\n" + "[dflash] slower with VMM enabled. Consider rebuilding with:\n" + "[dflash] cmake -DGGML_CUDA_NO_VMM=ON ..\n", + i, prop.name, gib); + } +#endif + } + } + } + + std::printf("[cfg] model=%s draft=%s method=%s gpu=%d ctx=%d n_predict=%d kv_k=%s kv_v=%s " + "temp=%.2f top_k=%d top_p=%.2f budget=%d bench=%d fa_window=%d " + "draft_max=%d adaptive=%d draft_kv_cap_override=%d pflash=%d pflash_alpha=%.3f\n", + model_path.c_str(), + draft_path.empty() ? "(none)" : draft_path.c_str(), + draft_method == DraftMethod::Dflash ? "dflash" : + draft_method == DraftMethod::Mtp ? "mtp" : + draft_method == DraftMethod::None ? 
"none" : "auto", + gpu, ctx_size, n_predict, + kv_k_str.c_str(), kv_v_str.c_str(), + sampler.temp, sampler.top_k, sampler.top_p, + ddtree_budget, (int)bench_mode, fa_window, + draft_max, (int)draft_max_adaptive, draft_kv_cap_override, + (int)use_pflash, pflash_alpha); + + // ── Backend init ────────────────────────────────────────────────────── + ggml_backend_t backend = ggml_backend_cuda_init(gpu); + if (!backend) { + std::fprintf(stderr, "error: ggml_backend_cuda_init(%d) failed\n", gpu); + return 1; + } + if (mem_diag) print_mem_diag("after-backend"); + + // Register the pFlash GGML custom kernel so ggml_flash_attn_sparse ops + // dispatched from build_gemma4_graph (full-attention layers, use_pflash=true) + // have a backend implementation available. + if (use_pflash) { + pflash_register_ggml_kernel(); + } + + // ── Load target weights ─────────────────────────────────────────────── + GemmaTargetWeights w; + { + double t0 = now_ms(); + if (!load_gemma4_target_gguf(model_path, backend, w)) { + std::fprintf(stderr, "load_gemma4_target_gguf: %s\n", dflash27b_last_error()); + return 1; + } + double t1 = now_ms(); + std::printf("[target] loaded %d layers, n_embd=%d, vocab=%d (%.1f ms)\n", + w.n_layer, w.n_embd, w.n_vocab, t1 - t0); + if (mem_diag) print_mem_diag("after-target-load"); + } + + // head_dim mask-pad gate: the target graph forces `need_256_pad` on K-view + // when head_dim >= 512 (full-attn layer in Dense 31B), regardless of KV + // dtype. Without bumping g_kq_stride_pad here, the host-built causal mask + // is padded to KQ_MASK_PAD (64) while the K view is padded to 256 — the FA + // kernel reads mask columns past the populated region, attending to + // uninitialised K slots (mask byte often 0x0000 = "attend"). Symptom under + // Q4_0/Q8_0 KV: MTP accept rate collapses to ~0.30 vs ~0.78 expected. + // TQ3 was unaffected only because the kv_env_is_tq3 gate above already + // bumped to 256. 
+ if (w.head_dim >= 512) { + g_kq_stride_pad = 256; + } + + // ── Load draft weights (optional) ──────────────────────────────────── + // Draft state: declared in main scope so they persist across bench iterations + // and are accessible in cleanup. + GemmaDraftWeights dw; + ggml_context * tok_embd_ctx = nullptr; + ggml_backend_buffer_t tok_embd_buf = nullptr; + + if (have_draft) { + double t0 = now_ms(); + // Auto-detect: + // 1. If path ends with .gguf, use GGUF loader directly + // 2. If path is a directory containing draft-q8_0.gguf, prefer it + // (Q8 GGUF is ~2x smaller than the BF16 safetensors and avoids + // a memory-pressure perf trap on Dense + TQ3 KV that drops + // target prefill 20x; see commit notes for details) + // 3. Otherwise fall back to safetensors loader + std::string resolved_draft_path = draft_path; + bool is_gguf = (draft_path.size() >= 5 && + draft_path.compare(draft_path.size() - 5, 5, ".gguf") == 0); + if (!is_gguf) { + // Check if path is a directory with a draft-q8_0.gguf inside + const std::string candidate = draft_path + "/draft-q8_0.gguf"; + std::ifstream probe(candidate.c_str()); + if (probe.good()) { + resolved_draft_path = candidate; + is_gguf = true; + std::fprintf(stderr, + "[draft] auto-selected Q8 GGUF: %s\n" + " (%s also present; Q8 is ~2x smaller and ~20x faster on Dense+TQ3)\n", + candidate.c_str(), + (draft_path + "/model.safetensors").c_str()); + } + } + bool ok = false; + if (is_gguf) { + ok = load_gemma4_draft_gguf(resolved_draft_path, backend, dw); + if (!ok) std::fprintf(stderr, "load_gemma4_draft_gguf: %s\n", dflash27b_last_error()); + } else { + ok = load_gemma4_draft_safetensors(resolved_draft_path, backend, dw); + if (!ok) std::fprintf(stderr, "load_gemma4_draft_safetensors: %s\n", dflash27b_last_error()); + } + if (!ok) return 1; + double t1 = now_ms(); + if (mem_diag) print_mem_diag("after-draft-load"); + + // Upload tok_embd from target embedder to GPU (tied lm_head for draft). 
+ // tw.embedder keeps the bytes CPU-side; we upload once and inject a pointer. + { + ggml_init_params ep{}; + ep.mem_size = ggml_tensor_overhead() * 2; + ep.mem_buffer = nullptr; + ep.no_alloc = true; + tok_embd_ctx = ggml_init(ep); + if (!tok_embd_ctx) { + std::fprintf(stderr, "[draft] ggml_init for tok_embd failed\n"); + return 1; + } + + const ggml_type emb_type = w.embedder.tok_embd_type; + const int64_t n_embd_t = w.embedder.n_embd; + const int64_t n_vocab_t = w.embedder.n_vocab; + + // ggml convention: ne[0] = n_embd (fast axis), ne[1] = n_vocab + ggml_tensor * te = ggml_new_tensor_2d(tok_embd_ctx, emb_type, n_embd_t, n_vocab_t); + ggml_set_name(te, "tok_embd_gpu"); + + tok_embd_buf = ggml_backend_alloc_ctx_tensors(tok_embd_ctx, backend); + if (!tok_embd_buf) { + std::fprintf(stderr, "[draft] ggml_backend_alloc_ctx_tensors for tok_embd failed\n"); + ggml_free(tok_embd_ctx); + tok_embd_ctx = nullptr; + return 1; + } + + const size_t emb_bytes = (size_t)w.embedder.row_bytes * (size_t)n_vocab_t; + ggml_backend_tensor_set(te, w.embedder.tok_embd_bytes, 0, emb_bytes); + std::printf("[tok_embd] uploaded %.1f MiB to GPU (%s [%" PRId64 ", %" PRId64 "])\n", + (double)emb_bytes / (1024.0 * 1024.0), + ggml_type_name(emb_type), n_embd_t, n_vocab_t); + + dw.tok_embd = te; + dw.n_vocab = (int)n_vocab_t; + if (mem_diag) print_mem_diag("after-tok-embd"); + } + + std::printf("[draft] loaded n_layer=%d n_head=%d n_embd=%d n_vocab=%d " + "target_hidden=%d block_size=%d (%.1f ms)\n", + dw.n_layer, dw.n_head, dw.n_embd, dw.n_vocab, + dw.target_hidden, dw.block_size, t1 - t0); + } + + // ── Load MTP weights early when enabled ────────────────────────────── + // Donor target layers must be known before target KV allocation so TQ3 + // donor caches can be forced to Q8_0 and avoid wrap-concat FWHT loss. 
+ MtpDrafterWeights mtp_w; + MtpStepGraph mtp_g; + std::vector mtp_extra_q8_layers; + + if (have_mtp) { + double t0 = now_ms(); + if (!load_gemma4_mtp_assistant(mtp_path, backend, mtp_w)) { + std::fprintf(stderr, "load_gemma4_mtp_assistant: %s\n", dflash27b_last_error()); + return 1; + } + double t1 = now_ms(); + std::printf("[mtp] loaded n_layers=%d n_embd=%d n_embd_backbone=%d (%.1f ms)\n", + (int)mtp_w.layers.size(), mtp_w.n_embd, mtp_w.n_embd_backbone, t1 - t0); + if (mem_diag) print_mem_diag("after-mtp-load"); + + // Re-resolve donor target layers using the actual target SWA pattern. + resolve_mtp_donor_layers(mtp_w, w.swa_layers); + for (const MtpLayerWeights & L : mtp_w.layers) { + if (L.donor_target_layer >= 0 && + std::find(mtp_extra_q8_layers.begin(), mtp_extra_q8_layers.end(), + L.donor_target_layer) == mtp_extra_q8_layers.end()) { + mtp_extra_q8_layers.push_back(L.donor_target_layer); + } + } + } + + // ── Create KV cache ─────────────────────────────────────────────────── + GemmaTargetCache cache; + { + if (mem_diag) print_mem_diag("before-target-kv"); + double t0 = now_ms(); + const int draft_kv_default_cap = have_draft + ? (dw.sliding_window + dw.block_size + 32) + : 0; + const int target_feat_cap_hint = have_draft + ? 
std::max(draft_kv_default_cap, draft_kv_cap_override) + : 0; + if (!create_gemma4_cache(w, ctx_size, backend, cache, mtp_extra_q8_layers, + target_feat_cap_hint, + /*enable_dflash_capture_overrides=*/have_draft)) { + std::fprintf(stderr, "create_gemma4_cache: %s\n", dflash27b_last_error()); + return 1; + } + double t1 = now_ms(); + std::printf("[cache] created max_ctx=%d, kv_layers=%zu (%.1f ms)\n", + cache.max_ctx, cache.attn_k.size(), t1 - t0); + if (mem_diag) print_mem_diag("after-target-kv"); + } + + // ── Allocate draft KV cache (requires cache to already exist) ───────── + if (have_draft) { + if (mem_diag) print_mem_diag("before-draft-kv"); + if (!create_draft_kv_cache(dw, backend, cache, draft_kv_cap_override)) { + std::fprintf(stderr, "create_draft_kv_cache failed\n"); + return 1; + } + std::printf("[draft] KV cache allocated: %d slots%s\n", + cache.draft_kv_cap, + draft_kv_cap_override > 0 ? " (override)" : ""); + if (mem_diag) print_mem_diag("after-draft-kv"); + } + + // ── MTP state + step graph (optional) ──────────────────────────────── + // mtp_h_prev context/buffer: separate small allocation so base_ctx stays + // unmodified and free_gemma4_cache() doesn't double-free it. + ggml_context * mtp_h_prev_ctx = nullptr; + ggml_backend_buffer_t mtp_h_prev_buf = nullptr; + + if (have_mtp) { + // Allocate mtp_h_prev tensor: [n_embd_backbone, 1] f32, GPU-resident, + // persistent across decode steps. Separate context so free_gemma4_cache + // doesn't free it. + // Also allocate mtp_h_prev_batch [n_embd_backbone, 17] for approach B + // (batch capture of all verify rows; eliminates per-chain re-capture forward). + { + // Two tensors: mtp_h_prev [n_embd, 1] + mtp_h_prev_batch [n_embd, 17]. 
+ const int kBatchCols = 17; + ggml_init_params ep{}; + ep.mem_size = 2 * ggml_tensor_overhead() + 512; + ep.mem_buffer = nullptr; + ep.no_alloc = true; + mtp_h_prev_ctx = ggml_init(ep); + if (!mtp_h_prev_ctx) { + std::fprintf(stderr, "[mtp] ggml_init for mtp_h_prev failed\n"); + return 1; + } + cache.mtp_h_prev = ggml_new_tensor_2d(mtp_h_prev_ctx, + GGML_TYPE_F32, + mtp_w.n_embd_backbone, 1); + ggml_set_name(cache.mtp_h_prev, "mtp_h_prev"); + cache.mtp_h_prev_batch = ggml_new_tensor_2d(mtp_h_prev_ctx, + GGML_TYPE_F32, + mtp_w.n_embd_backbone, kBatchCols); + ggml_set_name(cache.mtp_h_prev_batch, "mtp_h_prev_batch"); + mtp_h_prev_buf = ggml_backend_alloc_ctx_tensors(mtp_h_prev_ctx, backend); + if (!mtp_h_prev_buf) { + std::fprintf(stderr, "[mtp] alloc mtp_h_prev failed\n"); + ggml_free(mtp_h_prev_ctx); mtp_h_prev_ctx = nullptr; + return 1; + } + // Zero-initialize + std::vector zeros_f(mtp_w.n_embd_backbone, 0.0f); + ggml_backend_tensor_set(cache.mtp_h_prev, zeros_f.data(), 0, + sizeof(float) * mtp_w.n_embd_backbone); + } + + // Determine last full-attention layer index from swa_layers + cache.mtp_last_full_layer = -1; + for (int il = w.n_layer - 1; il >= 0; il--) { + const bool is_swa = (il < (int)w.swa_layers.size()) && w.swa_layers[il]; + if (!is_swa) { + cache.mtp_last_full_layer = il; + break; + } + } + if (cache.mtp_last_full_layer < 0) { + std::fprintf(stderr, "[mtp] error: no full-attention layer found in target\n"); + return 1; + } + std::printf("[mtp] mtp_last_full_layer=%d\n", cache.mtp_last_full_layer); + + cache.mtp_h_prev_enabled = true; + + // Build the MTP step graph (attn_pos=0 initially; will be rebuilt per step) + if (!build_mtp_step_graph(mtp_w, cache, w, mtp_g, /*attn_pos=*/0)) { + std::fprintf(stderr, "build_mtp_step_graph: %s\n", dflash27b_last_error()); + return 1; + } + std::printf("[mtp] step graph built ok\n"); + } + + // ── RNG ─────────────────────────────────────────────────────────────── + std::mt19937_64 rng(sampler.seed); + + // ── 
Daemon mode: stream token fd write helper ───────────────────────── + auto stream_emit = [&](int32_t tok) { + if (stream_fd < 0) return; + int32_t v = tok; +#ifdef _WIN32 + DWORD written; + WriteFile((HANDLE)(intptr_t)stream_fd, &v, sizeof(v), &written, nullptr); +#else + ssize_t n = ::write(stream_fd, &v, sizeof(v)); + (void)n; +#endif + }; + + // ── Daemon mode ─────────────────────────────────────────────────────── + if (daemon_mode) { + std::printf("[daemon] ready\n"); + std::fflush(stdout); + + StepGraph sg; + DraftStepGraph dsg; + bool daemon_first_iter = true; + std::string line; + + while (std::getline(std::cin, line)) { + // Per-request sampler (reset to CLI defaults each request). + SamplerCfg req_sampler = sampler; + parse_sampler_token(line, req_sampler); + // Always reseed per request so requests are independent. + // seed==0 means "random": use std::random_device for a fresh seed. + uint64_t actual_seed = req_sampler.seed; + if (actual_seed == 0) { + actual_seed = std::random_device{}(); + } + rng.seed(actual_seed); + + // ── Unsupported commands: emit -1 sentinel and continue ──────── + auto starts_with = [](const std::string & s, const char * pre) { + size_t n = std::strlen(pre); + return s.size() >= n && s.compare(0, n, pre) == 0; + }; + bool unsupported = (starts_with(line, "RESTORE") || + starts_with(line, "SNAPSHOT") || + starts_with(line, "FREE_SNAPSHOT") || + starts_with(line, "LIST_SLOTS") || + starts_with(line, "compress ") || + starts_with(line, "park") || + starts_with(line, "unpark") || + line == "free drafter" || + line == "drafter free"); + if (unsupported) { + std::fprintf(stderr, + "[daemon] command not supported in gemma4 daemon: %s\n", + line.c_str()); + std::fflush(stderr); + stream_emit(-1); + continue; + } + + // ── Parse: ────────────────────────── + char ppath[1024] = {0}; + int n_gen = 0; + if (std::sscanf(line.c_str(), "%1023s %d", ppath, &n_gen) != 2 || n_gen <= 0) { + std::fprintf(stderr, "[daemon] bad command line: %s\n", 
line.c_str()); + std::fflush(stderr); + stream_emit(-1); + continue; + } + + // Read binary prompt file (int32 LE token IDs). + std::vector prompt_ids = read_int32_file(ppath); + if (prompt_ids.empty()) { + std::fprintf(stderr, "[daemon] empty or unreadable prompt file: %s\n", ppath); + std::fflush(stderr); + stream_emit(-1); + continue; + } + std::printf("[daemon] prompt=%zu tokens n_gen=%d\n", + prompt_ids.size(), n_gen); + std::fflush(stdout); + + // Reset KV cache between requests. + if (!daemon_first_iter) { + step_graph_free(sg); + reset_gemma4_cache(cache); // also resets draft_kv_pos + if (have_draft) { + draft_step_free(dsg); + } + } + daemon_first_iter = false; + + if ((int)prompt_ids.size() + n_gen > ctx_size) { + std::fprintf(stderr, + "[daemon] prompt (%zu) + n_gen (%d) > ctx_size (%d)\n", + prompt_ids.size(), n_gen, ctx_size); + std::fflush(stderr); + stream_emit(-1); + continue; + } + + // ── Prefill ─────────────────────────────────────────────────── + int last_logit_tok = -1; + { + const int n_prompt = (int)prompt_ids.size(); + const int swa_window = w.swa_window > 0 ? 
w.swa_window : 1024; + const int chunk_size = std::min(n_prompt, swa_window); + + for (int cs = 0; cs < n_prompt; cs += chunk_size) { + const int chunk_n = std::min(chunk_size, n_prompt - cs); + const bool is_last = (cs + chunk_n == n_prompt); + const bool need_mask = (cs + chunk_n > 1); + + if (!build_gemma4_step(sg, w, cache, backend, + cs, chunk_n, need_mask, + /*capture=*/true, + use_pflash, pflash_alpha, + fa_window, + /*last_token_logits_only=*/true)) { + std::fprintf(stderr, "[daemon] prefill build failed at %d\n", cs); + std::fflush(stderr); + break; + } + + if (!embed_tokens_batch(w, prompt_ids.data() + cs, chunk_n, + sg.inp_embed, backend)) { + std::fprintf(stderr, "[daemon] embed_tokens_batch failed\n"); + std::fflush(stderr); + break; + } + + { + std::vector pos(chunk_n); + for (int i = 0; i < chunk_n; i++) pos[i] = cs + i; + ggml_backend_tensor_set(sg.positions, pos.data(), 0, + sizeof(int32_t) * chunk_n); + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = cs + chunk_n; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, chunk_n, cs); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(cs, chunk_n, + swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ cs, + /*n_tokens*/ chunk_n, + /*swa_window*/ swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ cs + chunk_n); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[daemon] prefill compute failed at %d\n", cs); + std::fflush(stderr); + break; + } + + // ── TQ3_0 K-cache write probe ───────────────────────────────────── + if (getenv("DFLASH_TQ3_PROBE_CACHE_WRITE") && + (cs == 0 || cs == chunk_size) && + 
!cache.attn_k.empty()) { + ggml_tensor * cache_k_layer0 = cache.attn_k[0]; + if (cache_k_layer0 && cache_k_layer0->type == GGML_TYPE_TQ3_0) { + // nb[1] is the stride in bytes between successive token slots + const size_t off = (size_t)cache_k_layer0->nb[1] * (size_t)cs; + uint8_t blk[14] = {}; + ggml_backend_tensor_get(cache_k_layer0, blk, off, 14); + std::fprintf(stderr, "[CACHE-WRITE-PROBE] cs=%d off=%zu bytes=", cs, off); + for (int _i = 0; _i < 14; _i++) + std::fprintf(stderr, "%02x ", blk[_i]); + std::fprintf(stderr, "\n"); + std::fflush(stderr); + } + } + // ───────────────────────────────────────────────────────────────── + + cache.cur_pos = cs + chunk_n; + + if (is_last) { + const int vocab = w.n_vocab; + std::vector logits_cpu(vocab); + ggml_backend_tensor_get(sg.logits, logits_cpu.data(), + 0, sizeof(float) * vocab); + last_logit_tok = sample_logits(logits_cpu.data(), vocab, + req_sampler, prompt_ids, rng); + cache.last_tok = last_logit_tok; + } + + step_graph_free(sg); + } + + // Draft KV prefill after target prefill. + if (have_draft && last_logit_tok >= 0) { + const int target_feat_w = dw.n_target_layers * dw.target_hidden; + const int draft_kv_cap = cache.draft_kv_cap > 0 + ? cache.draft_kv_cap + : (int)cache.draft_k[0]->ne[2]; + const int draft_prefill_n = std::min(n_prompt, draft_kv_cap); + const int draft_prefill_skip = n_prompt - draft_prefill_n; + + DraftKVPrefillGraph pkg; + if (build_draft_kv_prefill(pkg, dw, cache, backend, draft_prefill_n)) { + // Ring-buffer aware bf16→f32 conversion via ggml_cpy. 
+ copy_target_feat_bf16_to_f32(backend, cache.target_feat, + pkg.target_feat, + draft_prefill_skip % cache.target_feat_cap, + draft_prefill_n, target_feat_w); + + std::vector pos(draft_prefill_n); + for (int pi = 0; pi < draft_prefill_n; pi++) pos[pi] = draft_prefill_skip + pi; + ggml_backend_tensor_set(pkg.positions, pos.data(), 0, + sizeof(int32_t) * draft_prefill_n); + + auto dst = ggml_backend_graph_compute(backend, pkg.gf); + if (dst != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[daemon] draft KV prefill compute failed\n"); + std::fflush(stderr); + } + cache.draft_kv_pos = draft_prefill_n; + std::fprintf(stderr, + "[daemon] draft KV prefill done: %d positions materialized " + "(skipped %d early tokens, cap=%d, target_feat_cap=%d, dkv_pos=%d)\n", + draft_prefill_n, draft_prefill_skip, draft_kv_cap, + cache.target_feat_cap, cache.draft_kv_pos); + } + draft_kv_prefill_destroy(pkg); + } + } + + if (last_logit_tok < 0) { + std::fprintf(stderr, "[daemon] prefill produced no logit token\n"); + std::fflush(stderr); + stream_emit(-1); + continue; + } + + // ── Decode loop ─────────────────────────────────────────────── + std::vector history(prompt_ids); + int committed = cache.cur_pos; + int32_t cur_tok = last_logit_tok; + int n_generated = 0; + + while (n_generated < n_gen) { + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("[daemon] EOS at step %d\n", n_generated); + std::fflush(stdout); + break; + } + if (committed >= ctx_size - 1) { + std::printf("[daemon] context full\n"); + std::fflush(stdout); + break; + } + + if (!build_gemma4_step(sg, w, cache, backend, + committed, 1, + /*with_mask=*/true, + /*capture=*/false, + /*use_pflash=*/false, pflash_alpha, + fa_window)) { + std::fprintf(stderr, "[daemon] decode build failed at step %d\n", n_generated); + std::fflush(stderr); + break; + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + 1; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, 1, committed); + 
ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, 1, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ 1, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + 1); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + if (!embed_token(w, cur_tok, sg.inp_embed, backend)) { + std::fprintf(stderr, "[daemon] embed_token failed\n"); + std::fflush(stderr); + break; + } + + int32_t pos_val = committed; + ggml_backend_tensor_set(sg.positions, &pos_val, 0, sizeof(int32_t)); + + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[daemon] decode compute failed at step %d\n", n_generated); + std::fflush(stderr); + break; + } + + committed++; + cache.cur_pos = committed; + + const int vocab = w.n_vocab; + std::vector logits_cpu(vocab); + ggml_backend_tensor_get(sg.logits, logits_cpu.data(), 0, + sizeof(float) * vocab); + + const int32_t next_tok = (int32_t)sample_logits( + logits_cpu.data(), vocab, req_sampler, history, rng); + + // Emit current token to stream fd before advancing. + stream_emit(cur_tok); + + history.push_back(cur_tok); + n_generated++; + + cur_tok = next_tok; + cache.last_tok = cur_tok; + + step_graph_free(sg); + } + + // Sentinel: end of stream. 
+ stream_emit(-1); + std::printf("[daemon] generated %d tokens\n", n_generated); + std::fflush(stdout); + } + + // ── Daemon exit: clean up ───────────────────────────────────────── + step_graph_destroy(sg); + draft_step_destroy(dsg); + if (have_draft) { + free_draft_kv_cache(cache); + dw.tok_embd = nullptr; + free_gemma4_draft_weights(dw); + if (tok_embd_buf) ggml_backend_buffer_free(tok_embd_buf); + if (tok_embd_ctx) ggml_free(tok_embd_ctx); + } + free_gemma4_cache(cache); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + return 0; + } + + // ── Non-daemon: tokenize prompt ─────────────────────────────────────── + std::vector prompt_ids; + if (!token_ids_str.empty()) { + prompt_ids = parse_token_ids(token_ids_str); + if (prompt_ids.empty()) { + std::fprintf(stderr, "error: --tokens produced no valid token IDs\n"); + return 2; + } + std::printf("[tokens] using %zu pre-tokenised IDs from --tokens\n", + prompt_ids.size()); + } else { + prompt_ids = tokenize_byte_fallback(prompt_text); + std::printf("[tokens] byte-fallback tokenisation: %zu tokens " + "(pass --tokens for real tokenisation)\n", + prompt_ids.size()); + } + + // ── Ensure BOS is prepended (Gemma4 requires BOS at position 0) ── + if (w.bos_id >= 0 && (prompt_ids.empty() || prompt_ids[0] != w.bos_id)) { + prompt_ids.insert(prompt_ids.begin(), w.bos_id); + std::printf("[tokens] prepended BOS token %d\n", w.bos_id); + } + + if ((int)prompt_ids.size() >= ctx_size) { + std::fprintf(stderr, "error: prompt (%zu tokens) >= ctx_size (%d)\n", + prompt_ids.size(), ctx_size); + return 2; + } + + // ── Benchmark loop outer container ──────────────────────────────────── + const int bench_runs = bench_mode ? 3 : 1; + std::vector bench_tok_per_sec; + + // Declared here (main scope) so step_graph_destroy(sg)/draft_step_destroy(dsg) + // in cleanup is valid. 
+ StepGraph sg; + DraftStepGraph dsg; + + // Speculative decode stats (accumulated across bench iterations when bench_mode) + int total_draft_steps = 0; + int total_accepted = 0; + + for (int bench_iter = 0; bench_iter < bench_runs; bench_iter++) { + + if (bench_runs > 1) { + reset_gemma4_cache(cache); + // Reset draft step state for the new bench iteration + draft_step_free(dsg); + total_draft_steps = 0; + total_accepted = 0; + std::printf("[bench] run %d/%d\n", bench_iter + 1, bench_runs); + } + + // ── Prefill ─────────────────────────────────────────────────────── + // + // Chunked batched prefill: process up to swa_window tokens per chunk. + // Each chunk dispatches a single GPU graph covering all tokens in the + // chunk, which is far cheaper than one dispatch per token. + // + // For a chunk [cs, cs+chunk_n): + // 1. Embed chunk tokens → inp_embed + // 2. Set positions[i] = cs + i + // 3. Build causal mask covering [0, cs+chunk_n) for the chunk rows + // 4. Build SWA mask for sliding-window layers (when cs > 0) + // 5. Compute graph → KV + target_feat (logits discarded except last) + + std::printf("[prefill] %zu tokens ...\n", prompt_ids.size()); + double prefill_t0 = now_ms(); + int last_logit_tok = -1; + + { + const int n_prompt = (int)prompt_ids.size(); + + { + const int swa_window = w.swa_window > 0 ? 
w.swa_window : 1024; + const int chunk_size = std::min(n_prompt, swa_window); + + for (int cs = 0; cs < n_prompt; cs += chunk_size) { + const int chunk_n = std::min(chunk_size, n_prompt - cs); + const bool is_last = (cs + chunk_n == n_prompt); + const bool need_mask = (cs + chunk_n > 1); + + if (!build_gemma4_step(sg, w, cache, backend, + /*kv_start=*/cs, chunk_n, + need_mask, /*capture=*/true, + use_pflash, pflash_alpha, + fa_window, + /*last_token_logits_only=*/true)) { + std::fprintf(stderr, "prefill chunk build failed at offset %d\n", cs); + return 1; + } + + if (!embed_tokens_batch(w, prompt_ids.data() + cs, chunk_n, + sg.inp_embed, backend)) { + return 1; + } + + { + std::vector pos(chunk_n); + for (int i = 0; i < chunk_n; i++) pos[i] = cs + i; + ggml_backend_tensor_set(sg.positions, pos.data(), 0, + sizeof(int32_t) * chunk_n); + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = cs + chunk_n; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, chunk_n, cs); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(cs, chunk_n, + swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ cs, + /*n_tokens*/ chunk_n, + /*swa_window*/ swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ cs + chunk_n); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "prefill compute failed at chunk offset %d\n", cs); + return 1; + } + + cache.cur_pos = cs + chunk_n; + + if (is_last) { + const int vocab = w.n_vocab; + std::vector logits_cpu(vocab); + // last_token_logits_only=true → logits has shape [vocab, 1]; + // read from offset 0 instead of skipping (chunk_n-1)*vocab floats. 
+ ggml_backend_tensor_get(sg.logits, logits_cpu.data(), + 0, + sizeof(float) * vocab); + last_logit_tok = sample_logits(logits_cpu.data(), vocab, + sampler, prompt_ids, rng); + cache.last_tok = last_logit_tok; + } + + step_graph_free(sg); + } + } + } + + double prefill_t1 = now_ms(); + { + const int n_prompt = (int)prompt_ids.size(); + const double prefill_ms = prefill_t1 - prefill_t0; + { + const int swa_window = w.swa_window > 0 ? w.swa_window : 1024; + const int chunk_size = std::min(n_prompt, swa_window); + std::printf("[prefill] %d tokens in %.1f ms (%.1f tok/s) " + "[chunked%s, chunk_size=%d] (last sampled token: %d)\n", + n_prompt, prefill_ms, + prefill_ms > 0.0 ? (double)n_prompt / (prefill_ms / 1000.0) : 0.0, + use_pflash ? "+pflash" : "", chunk_size, last_logit_tok); + } + } + + // ── Draft KV prefill: materialize draft KV for all prompt positions ─ + if (have_draft) { + const int n_prompt = (int)prompt_ids.size(); + const int target_feat_w = dw.n_target_layers * dw.target_hidden; + + // Clamp to draft KV cache capacity. When the prompt is longer than the + // draft cache, we prefill only the LAST draft_prefill_n tokens so that + // the context that matters most (closest to the first decode step) is + // represented in the draft KV cache. + const int draft_kv_cap = cache.draft_kv_cap > 0 + ? cache.draft_kv_cap + : (int)cache.draft_k[0]->ne[2]; + const int draft_prefill_n = std::min(n_prompt, draft_kv_cap); + const int draft_prefill_skip = n_prompt - draft_prefill_n; + + DraftKVPrefillGraph pkg; + if (!build_draft_kv_prefill(pkg, dw, cache, backend, draft_prefill_n)) { + std::fprintf(stderr, "[draft] KV prefill build failed\n"); + return 1; + } + + // Extract target_feat from ring buffer (bf16 → f32) via ggml_cpy. + // The ring buffer stores tokens at slot (pos % cap). + // We want the LAST draft_prefill_n hidden states (positions draft_prefill_skip + // through n_prompt-1). 
+ copy_target_feat_bf16_to_f32(backend, cache.target_feat, + pkg.target_feat, + draft_prefill_skip % cache.target_feat_cap, + draft_prefill_n, target_feat_w); + + // Positions: [draft_prefill_skip, ..., n_prompt-1] + { + std::vector pos(draft_prefill_n); + for (int i = 0; i < draft_prefill_n; i++) pos[i] = draft_prefill_skip + i; + ggml_backend_tensor_set(pkg.positions, pos.data(), 0, sizeof(int32_t) * draft_prefill_n); + } + + auto st = ggml_backend_graph_compute(backend, pkg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[draft] KV prefill compute failed\n"); + draft_kv_prefill_destroy(pkg); + return 1; + } + // draft_kv_pos tracks entries written, bounded by draft_kv_cap. + cache.draft_kv_pos = draft_prefill_n; + + draft_kv_prefill_destroy(pkg); + std::printf("[draft] KV prefill done: %d positions materialized " + "(skipped %d early tokens, cap=%d, target_feat_cap=%d, dkv_pos=%d)\n", + draft_prefill_n, draft_prefill_skip, draft_kv_cap, + cache.target_feat_cap, cache.draft_kv_pos); + } + + // ── Decode loop ─────────────────────────────────────────────────── + + std::vector generated; + generated.reserve(n_predict); + std::vector history(prompt_ids); + + int committed = cache.cur_pos; + int32_t cur_tok = last_logit_tok; + + double decode_t0 = now_ms(); + double first_token_ms = -1.0; + + if (have_draft) { + // ── SPECULATIVE DECODE LOOP ─────────────────────────────────── + // + // Each iteration proposes a block of q_len tokens via the draft + // model, then verifies with a single batched target forward. + // Accepted prefix tokens are committed; the loop advances by + // accept_n tokens per target call instead of 1. + // + // Gemma4 is pure attention (no SSM/conv state), so rollback is + // trivially: just don't advance committed past accepted tokens. + // Stale KV at positions [committed+commit_n..committed+q_len-1] + // will be overwritten by the next verify pass. 
+ + AdaptiveDraftMax adaptive; + adaptive.init(draft_max_adaptive, draft_max, dw.block_size); + if (draft_max_adaptive) { + std::printf("[adaptive] enabled initial=%d max=%d window=%d\n", + adaptive.current, adaptive.max_q, adaptive.window_steps); + } + const int mask_tok = dw.mask_token_id; // 4 + const int target_feat_w = dw.n_target_layers * dw.target_hidden; + const int vocab = w.n_vocab; + const int dkv_cap = cache.draft_kv_cap > 0 + ? cache.draft_kv_cap + : (int)cache.draft_k[0]->ne[2]; + + std::vector noise_ids(dw.block_size); + std::vector noise_embed_buf((size_t)dw.n_embd * dw.block_size); + std::vector draft_tok(dw.block_size); + std::vector target_tok(dw.block_size); + std::vector draft_logits_buf((size_t)vocab * dw.block_size); + std::vector verify_logits_buf((size_t)vocab * dw.block_size); + + while ((int)generated.size() < n_predict) { + int q_len = adaptive.enabled + ? adaptive.current + : ((draft_max > 0 && draft_max < dw.block_size) + ? draft_max : dw.block_size); + q_len = std::min(q_len, std::max(1, ctx_size - committed - 1)); + + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("\n[decode] EOS token %d\n", cur_tok); + break; + } + if (committed >= ctx_size - 1) { + std::printf("\n[decode] context full\n"); + break; + } + + // Not enough context for target_feat extraction yet: + // fall back to single-token target-only decode. 
+ if (committed < q_len) { + if (!build_gemma4_step(sg, w, cache, backend, + committed, /*n_tokens=*/1, + /*with_mask=*/true, + /*capture=*/true, + /*use_pflash=*/false, pflash_alpha, + fa_window)) { + std::fprintf(stderr, "[decode] warmup build failed at step %zu\n", + generated.size()); + return 1; + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + 1; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, 1, committed); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, 1, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ 1, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + 1); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + if (!embed_token(w, cur_tok, sg.inp_embed, backend)) return 1; + + int32_t pos_val = committed; + ggml_backend_tensor_set(sg.positions, &pos_val, 0, sizeof(int32_t)); + + double step_t0 = now_ms(); + auto st = ggml_backend_graph_compute(backend, sg.gf); + double step_t1 = now_ms(); + + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[decode] warmup compute failed at step %zu\n", + generated.size()); + return 1; + } + + committed++; + cache.cur_pos = committed; + + // Draft KV prefill for this warmup token (position committed-1). 
+ { + const int warmup_pos = committed - 1; + const int target_feat_w_w = dw.n_target_layers * dw.target_hidden; + DraftKVPrefillGraph wpkg; + if (!build_draft_kv_prefill(wpkg, dw, cache, backend, 1)) { + std::fprintf(stderr, "[decode] warmup draft KV prefill build failed\n"); + return 1; + } + copy_target_feat_bf16_to_f32(backend, cache.target_feat, + wpkg.target_feat, + warmup_pos % cache.target_feat_cap, + 1, target_feat_w_w); + { + int32_t p = warmup_pos; + ggml_backend_tensor_set(wpkg.positions, &p, 0, sizeof(int32_t)); + } + auto wst = ggml_backend_graph_compute(backend, wpkg.gf); + if (wst != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[decode] warmup draft KV prefill compute failed\n"); + draft_kv_prefill_destroy(wpkg); + return 1; + } + cache.draft_kv_pos = std::min(dkv_cap, cache.draft_kv_pos + 1); + draft_kv_prefill_destroy(wpkg); + } + + const int vocab_inner = w.n_vocab; + std::vector logits_cpu(vocab_inner); + ggml_backend_tensor_get(sg.logits, logits_cpu.data(), 0, + sizeof(float) * vocab_inner); + + const int32_t next_tok = (int32_t)sample_logits( + logits_cpu.data(), vocab_inner, sampler, history, rng); + + generated.push_back(cur_tok); + history.push_back(cur_tok); + + if (first_token_ms < 0.0) { + first_token_ms = step_t1 - step_t0; + } + + std::printf("%d ", cur_tok); + std::fflush(stdout); + + cur_tok = next_tok; + cache.last_tok = cur_tok; + + step_graph_free(sg); + continue; + } + + // ── 1. Build noise block: [cur_tok, MASK, MASK, ..., MASK] + noise_ids[0] = cur_tok; + for (int i = 1; i < q_len; i++) noise_ids[i] = mask_tok; + if (!w.embedder.embed(noise_ids.data(), q_len, noise_embed_buf.data())) { + std::fprintf(stderr, "[spec] embed noise_ids failed\n"); + return 1; + } + + // ── 2. Build draft graph (KV-cached, no target_feat input) + // The draft model operates in its own KV address space bounded by + // draft_kv_cap. 
Use cache.draft_kv_pos (number of entries written into + // the draft KV cache) as kv_start, NOT the absolute committed position. + double refill_ms = 0.0; + if (cache.draft_kv_pos + q_len > dkv_cap) { + // Sliding-window re-prefill: instead of wiping all draft KV context, + // keep the most recent (dkv_cap - q_len) committed tokens by + // re-projecting their target_feat into the beginning of the draft + // KV cache. This preserves the drafter's context continuity across + // ring-buffer wrap points, which is the root cause of acceptance + // collapsing from ~10/16 at 32K to ~1/16 at 64K. + const int keep = dkv_cap - q_len; + if (keep > 0 && committed >= keep) { + // Absolute positions of the (keep) tokens we want to retain: + // [committed - keep, committed). + const int refill_start = committed - keep; + + // Reset draft_kv_pos to 0 so build_draft_kv_prefill_graph writes + // to slot [0, keep) — the ASSERT inside the graph builder requires + // draft_kv_pos + n_tokens <= ne[2]. + cache.draft_kv_pos = 0; + + const double refill_t0 = now_ms(); + DraftKVPrefillGraph rpkg; + if (!build_draft_kv_prefill(rpkg, dw, cache, backend, keep)) { + std::fprintf(stderr, "[spec] draft KV re-prefill build failed\n"); + return 1; + } + + // Copy target_feat for [refill_start, refill_start+keep) from the + // ring buffer (bf16) into rpkg.target_feat (f32) via ggml_cpy. + copy_target_feat_bf16_to_f32(backend, cache.target_feat, + rpkg.target_feat, + refill_start % cache.target_feat_cap, + keep, target_feat_w); + + // Absolute positions for RoPE — must match training. 
+ { + std::vector rpos(keep); + for (int i = 0; i < keep; i++) rpos[i] = refill_start + i; + ggml_backend_tensor_set(rpkg.positions, rpos.data(), 0, + sizeof(int32_t) * keep); + } + + auto rst = ggml_backend_graph_compute(backend, rpkg.gf); + if (rst != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[spec] draft KV re-prefill compute failed\n"); + draft_kv_prefill_destroy(rpkg); + return 1; + } + cache.draft_kv_pos = keep; + draft_kv_prefill_destroy(rpkg); + refill_ms = now_ms() - refill_t0; + + std::fprintf(stderr, + "[spec] draft KV sliding re-prefill: kept %d tokens " + "(positions %d..%d), dkv_cap=%d\n", + keep, refill_start, committed - 1, dkv_cap); + } else { + // Not enough committed history to re-prefill — hard reset. + // This only happens at the very beginning of decode (committed < keep). + cache.draft_kv_pos = 0; + } + } + if (!build_draft_step(dsg, dw, cache, backend, q_len, cache.draft_kv_pos)) { + std::fprintf(stderr, "[spec] draft build failed\n"); + return 1; + } + + // ── 3. Set draft inputs + + // draft_embed: noise embeddings [n_embd, q_len] f32 + ggml_backend_tensor_set(dsg.draft_embed, noise_embed_buf.data(), 0, + sizeof(float) * (size_t)dw.n_embd * q_len); + + // positions: absolute [committed, committed+1, ..., committed+q_len-1] + // (absolute positions are used for RoPE — they must match training) + { + std::vector pos(q_len); + for (int i = 0; i < q_len; i++) pos[i] = committed + i; + ggml_backend_tensor_set(dsg.positions, pos.data(), 0, sizeof(int32_t) * q_len); + } + + // Causal mask: block token i attends to all draft KV context + // [0..draft_kv_pos-1] plus block tokens [0..i]. + // Use draft_kv_pos (draft KV address space), not committed. 
+ if (dsg.attn_mask && dsg.attn_mask->buffer) { + const int dkv_ctx = cache.draft_kv_pos; + const int kv_len = dkv_ctx + q_len; + const int kv_pad = align_up(kv_len, KQ_MASK_PAD); + const int q_pad = align_up(q_len, KQ_MASK_PAD); + std::vector mask((size_t)kv_pad * q_pad, F16_NEG_INF); + for (int q = 0; q < q_len; q++) { + const int max_k = dkv_ctx + q; + for (int k = 0; k <= max_k; k++) { + mask[(size_t)q * kv_pad + k] = F16_ZERO; + } + } + ggml_backend_tensor_set(dsg.attn_mask, mask.data(), 0, + sizeof(uint16_t) * mask.size()); + } + + // ── 4. Draft compute + const double draft_t0 = now_ms(); + { + auto st = ggml_backend_graph_compute(backend, dsg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[spec] draft compute failed: %d\n", (int)st); + return 1; + } + } + const double draft_t1 = now_ms(); + + // ── 5. Read draft logits and argmax + ggml_backend_tensor_get(dsg.logits, draft_logits_buf.data(), 0, + sizeof(float) * (size_t)vocab * q_len); + for (int i = 0; i < q_len; i++) { + draft_tok[i] = argmax_f32(draft_logits_buf.data() + (size_t)i * vocab, vocab); + } + draft_tok[0] = cur_tok; // pin first token (it was cur_tok, not a prediction) + + // ── 6. 
Target verify: batched forward on draft_tok[0..q_len-1] + if (!build_gemma4_step(sg, w, cache, backend, + committed, q_len, + /*with_mask=*/true, /*capture=*/true, + use_pflash, pflash_alpha, fa_window)) { + std::fprintf(stderr, "[spec] verify build failed\n"); + return 1; + } + + if (!embed_tokens_batch(w, draft_tok.data(), q_len, sg.inp_embed, backend)) { + return 1; + } + + // Target positions: [committed, committed+1, ..., committed+q_len-1] + { + std::vector pos(q_len); + for (int i = 0; i < q_len; i++) pos[i] = committed + i; + ggml_backend_tensor_set(sg.positions, pos.data(), 0, sizeof(int32_t) * q_len); + } + + // Causal mask for target verify + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + q_len; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, q_len, committed); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + + // SWA mask for target verify (required when n_tokens > 1) + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, q_len, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ q_len, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + q_len); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + const double verify_t0 = now_ms(); + { + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[spec] verify compute failed: %d\n", (int)st); + return 1; + } + } + const double verify_t1 = now_ms(); + + // ── 7. Read target logits and argmax + ggml_backend_tensor_get(sg.logits, verify_logits_buf.data(), 0, + sizeof(float) * (size_t)vocab * q_len); + for (int i = 0; i < q_len; i++) { + target_tok[i] = argmax_f32(verify_logits_buf.data() + (size_t)i * vocab, vocab); + } + + // ── 8. 
Acceptance: longest prefix match + // draft_tok[0] = cur_tok (accepted unconditionally as the current token) + // target_tok[i] = target's prediction for position committed+i+1 + // Check: draft_tok[i+1] == target_tok[i] (draft proposed the right next token) + int accept_n = 1; + for (int i = 0; i < q_len - 1; i++) { + if (draft_tok[i + 1] == target_tok[i]) accept_n++; + else break; + } + int commit_n = accept_n; + if (commit_n > n_predict - (int)generated.size()) { + commit_n = n_predict - (int)generated.size(); + } + + // ── 9. Commit accepted tokens + bool hit_eos = false; + for (int i = 0; i < commit_n; i++) { + generated.push_back(draft_tok[i]); + history.push_back(draft_tok[i]); + std::printf("%d ", draft_tok[i]); + std::fflush(stdout); + if (IS_EOS_TOK(draft_tok[i], w)) { hit_eos = true; break; } + } + + // ── 10. Draft KV prefill for the committed positions, then advance state. + // The target verify pass (step 6) captured target_feat for positions + // [committed..committed+q_len-1]. We prefill draft KV for the accepted + // prefix [committed..committed+commit_n-1] before advancing committed. + const double commit_t0 = now_ms(); + { + DraftKVPrefillGraph cpkg; + if (!build_draft_kv_prefill(cpkg, dw, cache, backend, commit_n)) { + std::fprintf(stderr, "[spec] draft KV prefill build failed\n"); + return 1; + } + + // Extract target_feat for positions [committed..committed+commit_n-1] + // from the ring buffer (bf16 → f32) via ggml_cpy. 
+ copy_target_feat_bf16_to_f32(backend, cache.target_feat, + cpkg.target_feat, + committed % cache.target_feat_cap, + commit_n, target_feat_w); + + { + std::vector pos(commit_n); + for (int i = 0; i < commit_n; i++) pos[i] = committed + i; + ggml_backend_tensor_set(cpkg.positions, pos.data(), 0, + sizeof(int32_t) * commit_n); + } + + auto cst = ggml_backend_graph_compute(backend, cpkg.gf); + if (cst != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[spec] draft KV prefill compute failed\n"); + draft_kv_prefill_destroy(cpkg); + return 1; + } + cache.draft_kv_pos = std::min(dkv_cap, cache.draft_kv_pos + commit_n); + draft_kv_prefill_destroy(cpkg); + } + const double commit_t1 = now_ms(); + + // Gemma4 is pure attention — no SSM/conv rollback needed. + // Stale KV at positions [committed+commit_n..committed+q_len-1] + // will be overwritten by the next verify pass. + committed += commit_n; + cache.cur_pos = committed; + cur_tok = target_tok[commit_n - 1]; + cache.last_tok = cur_tok; + + total_draft_steps++; + total_accepted += commit_n; + + if (first_token_ms < 0.0) { + first_token_ms = now_ms() - decode_t0; + } + + double avg_accept = (total_draft_steps > 0) + ? (double)total_accepted / total_draft_steps : 0.0; + std::printf("[step %d] accept=%d/%d avg=%.1f " + "draft_ms=%.2f verify_ms=%.2f kv_ms=%.2f refill_ms=%.2f\n", + total_draft_steps, accept_n, q_len, avg_accept, + draft_t1 - draft_t0, verify_t1 - verify_t0, + commit_t1 - commit_t0, refill_ms); + adaptive.observe(accept_n, q_len, total_draft_steps); + + if (hit_eos) break; + + step_graph_free(sg); + draft_step_free(dsg); + } + + } else if (have_mtp) { + + if (gamma > 1) { + // ── γ>1 MTP SPECULATIVE DECODE LOOP ────────────────────────── + // + // Phase 3 of wild-growing-ember plan: chain generation with hoisted + // allocator, batched target verify, greedy longest-prefix accept. + // + // Per-chain flow: + // 1. Rebuild mtp_g ONCE per chain (hoisted outside k-loop). + // 2. 
K MTP steps: feed (seed_tok, h_prev) → draft[k], chain h_post. + // 3. Batched target verify: [cur_tok, draft[0..K-1]] = K+1 tokens. + // 4. Greedy longest-prefix match → accept_drafts + bonus token. + // 5. Commit tokens, advance state. + // 6. Refresh mtp_h_prev by copying row accept_drafts of the hidden + // rows captured during the batched verify (approach B). + // + // Pack convention: + // verify_in[0] = cur_tok at position committed + // verify_in[i+1] = draft[i] at position committed+i+1, i in [0,K) + // target_tok[i] = target's prediction for position committed+i+1 + // accept if draft[i] == target_tok[i] (0-based comparison over K) + // bonus = target_tok[accept_drafts] + // emit_count = accept_drafts + 1 + // new committed = old_committed + accept_drafts + 1 + // + // mtp_h_prev refresh (approach B): + // verify is run with mtp_h_prev_capture_mode = 1 (batch capture) and + // mtp_h_prev_row = -1 (unused in this mode). After the match, the row + // at index accept_drafts is read from cache.mtp_h_prev_batch and written + // into cache.mtp_h_prev host-side; no extra target forward is issued. + + // Stats counters + int mtp_gt1_chains = 0; + int mtp_gt1_accepted = 0; // total drafted tokens accepted + int mtp_gt1_total = 0; // total drafted positions evaluated + + // Allocate a persistent mtp_galloc for the chain loop. + // build_mtp_step_graph needs a fresh ggml context per chain, but we + // reuse the same ggml_gallocr_t to avoid repeated VRAM alloc/free. 
+ ggml_gallocr_t mtp_galloc = ggml_gallocr_new( + ggml_backend_get_default_buffer_type(backend)); + + const int K = gamma; + const int vocab = w.n_vocab; + + while ((int)generated.size() < n_predict) { + + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("\n[mtp-gt1] EOS token %d at step %zu\n", + cur_tok, generated.size()); + break; + } + if (committed >= ctx_size - (K + 2)) { + std::printf("\n[mtp-gt1] context nearly full at step %zu\n", + generated.size()); + break; + } + + // ── Phase 3a: Build mtp_g ONCE for this chain ────────────── + // attn_pos = committed for all K steps (const mode, Google ref). + // incr mode: in_pos is updated per step inside the k-loop below. + free_mtp_step_graph(mtp_g); + if (!build_mtp_step_graph(mtp_w, cache, w, mtp_g, committed)) { + std::fprintf(stderr, "[mtp-gt1] build_mtp_step_graph failed: %s\n", + dflash27b_last_error()); + ggml_gallocr_free(mtp_galloc); + return 1; + } + if (!ggml_gallocr_alloc_graph(mtp_galloc, mtp_g.gf)) { + std::fprintf(stderr, "[mtp-gt1] gallocr_alloc_graph failed\n"); + ggml_gallocr_free(mtp_galloc); + return 1; + } + + // ── Phase 3a: Chain generation (K steps) ────────────────── + std::vector draft(K); + + for (int k = 0; k < K; ++k) { + // Seed token for step k + const int32_t seed_tok = (k == 0) ? cur_tok : draft[k - 1]; + + // in_tok_embd: pre-dequantised F32 embedding of seed_tok + if (!embed_token(w, seed_tok, mtp_g.in_tok_embd, backend)) { + std::fprintf(stderr, "[mtp-gt1] embed_token failed for tok=%d k=%d\n", + seed_tok, k); + ggml_gallocr_free(mtp_galloc); + return 1; + } + + // in_h_prev: at k=0 use target's captured hidden; at k>0 chain from prev step + if (k == 0) { + ggml_backend_tensor_copy(cache.mtp_h_prev, mtp_g.in_h_prev); + } else { + ggml_backend_tensor_copy(mtp_g.out_h_post, mtp_g.in_h_prev); + } + + // in_pos: const=committed for all k (Google ref), incr=committed+k (A/B) + { + int32_t p = (mtp_pos_mode == 0) ? 
committed : (committed + k); + ggml_backend_tensor_set(mtp_g.in_pos, &p, 0, sizeof(int32_t)); + } + + // FA mask for TQ3_0 / head_dim>=512 layers + if (mtp_g.fa_mask && mtp_g.fa_mask->buffer) { + const int64_t mask_n = mtp_g.fa_mask->ne[0]; + const int64_t kv_seq = mtp_g.fa_mask_kv_seq_len; + std::vector mask_buf(mask_n); + for (int64_t i = 0; i < mask_n; i++) { + mask_buf[i] = (i < kv_seq) ? 0x0000u : 0xFC00u; + } + ggml_backend_tensor_set(mtp_g.fa_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_n); + } + + // Compute + { + auto st = ggml_backend_graph_compute(backend, mtp_g.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[mtp-gt1] MTP compute failed at k=%d\n", k); + ggml_gallocr_free(mtp_galloc); + return 1; + } + } + + // Read draft token from in-graph argmax + int32_t tok_out = 0; + ggml_backend_tensor_get(mtp_g.out_argmax, &tok_out, 0, sizeof(int32_t)); + draft[k] = tok_out; + } + + // ── Phase 3b: Batched target verify ──────────────────────── + // Pack: verify_in = [cur_tok, draft[0..K-1]] = K+1 tokens + // at positions [committed .. committed+K]. + std::vector verify_in; + verify_in.reserve(K + 1); + verify_in.push_back(cur_tok); + for (int i = 0; i < K; ++i) verify_in.push_back(draft[i]); + + const int verify_n = K + 1; + const int old_committed = committed; + + // Approach B: capture all K+1 rows in the verify pass so we + // can pick the right one host-side after greedy match. 
+ cache.mtp_h_prev_row = -1; // unused in batch mode + cache.mtp_h_prev_capture_mode = 1; // enable batch capture + + if (!build_gemma4_step(sg, w, cache, backend, + committed, verify_n, + /*with_mask=*/true, + /*capture=*/false, // no target_feat needed for MTP path + /*use_pflash=*/false, pflash_alpha, + fa_window)) { + std::fprintf(stderr, "[mtp-gt1] verify build failed at step %zu\n", + generated.size()); + ggml_gallocr_free(mtp_galloc); + return 1; + } + + // Embed verify_in batch + if (!embed_tokens_batch(w, verify_in.data(), verify_n, sg.inp_embed, backend)) { + ggml_gallocr_free(mtp_galloc); + return 1; + } + + // Positions: [committed .. committed+K] + { + std::vector pos(verify_n); + for (int i = 0; i < verify_n; ++i) pos[i] = committed + i; + ggml_backend_tensor_set(sg.positions, pos.data(), 0, + sizeof(int32_t) * verify_n); + } + + // Causal mask for batched verify + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + verify_n; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, verify_n, committed); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + + // SWA mask for batched verify + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, verify_n, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ verify_n, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + verify_n); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + { + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[mtp-gt1] verify compute failed\n"); + ggml_gallocr_free(mtp_galloc); + return 1; + } + } + + // Read [vocab, verify_n] logits → target_tok[0..K] + std::vector verify_logits_buf((size_t)vocab * verify_n); + 
ggml_backend_tensor_get(sg.logits, verify_logits_buf.data(), 0, + sizeof(float) * (size_t)vocab * verify_n); + + std::vector target_tok(verify_n); + for (int i = 0; i < verify_n; ++i) { + target_tok[i] = (int32_t)argmax_f32( + verify_logits_buf.data() + (size_t)i * vocab, vocab); + } + + step_graph_free(sg); + + // ── Phase 3c: Greedy longest-prefix accept + commit ───────── + // draft[i] == target_tok[i] means the MTP chain correctly + // predicted what target would say at position committed+i+1. + int accept_drafts = 0; + for (int i = 0; i < K; ++i) { + if (draft[i] == target_tok[i]) accept_drafts++; + else break; + } + + // Bonus token: target's prediction at the first mismatch position + // (or free prediction after full match). + const int32_t bonus = target_tok[accept_drafts]; + + // Emit accepted draft tokens then bonus + bool hit_eos = false; + for (int i = 0; i < accept_drafts && (int)generated.size() < n_predict; ++i) { + generated.push_back(draft[i]); + history.push_back(draft[i]); + std::printf("%d ", draft[i]); + std::fflush(stdout); + if (IS_EOS_TOK(draft[i], w)) { hit_eos = true; break; } + } + if (!hit_eos && (int)generated.size() < n_predict) { + generated.push_back(bonus); + history.push_back(bonus); + std::printf("%d ", bonus); + std::fflush(stdout); + if (IS_EOS_TOK(bonus, w)) hit_eos = true; + } + + committed = old_committed + accept_drafts + 1; + cache.cur_pos = committed; + cur_tok = bonus; + cache.last_tok = cur_tok; + + if (first_token_ms < 0.0) { + first_token_ms = now_ms() - decode_t0; + } + + // ── mtp_h_prev refresh (approach B) ─────────────────────── + // The verify ran with mtp_h_prev_capture_mode=1, so the target + // graph wrote all K+1 rows into mtp_h_prev_batch. We pick the + // column at accept_drafts host-side (21 KB staging copy) and + // write it into mtp_h_prev. No extra GPU forward needed. 
+ { + const size_t n_embd_hp = (size_t)cache.mtp_h_prev->ne[0]; + const size_t col_bytes = n_embd_hp * sizeof(float); + std::vector staging(n_embd_hp); + ggml_backend_tensor_get(cache.mtp_h_prev_batch, staging.data(), + /* offset = */ (size_t)accept_drafts * col_bytes, col_bytes); + ggml_backend_tensor_set(cache.mtp_h_prev, staging.data(), 0, col_bytes); + } + cache.mtp_h_prev_capture_mode = 0; // reset for safety + + // ── Stats ────────────────────────────────────────────────── + mtp_gt1_chains++; + mtp_gt1_accepted += accept_drafts; + mtp_gt1_total += K; + + std::printf("[mtp-gt1] chain k=%d accepted=%d bonus=%d " + "total_acc=%d pos_mode=%s\n", + mtp_gt1_chains, accept_drafts, bonus, + mtp_gt1_accepted, + mtp_pos_mode == 0 ? "const" : "incr"); + std::fflush(stdout); + + if (hit_eos) break; + + } // while generated < n_predict + + ggml_gallocr_free(mtp_galloc); + + if (mtp_gt1_chains > 0) { + const double mean_accept = (double)mtp_gt1_accepted / mtp_gt1_chains; + const double accept_rate = (double)mtp_gt1_accepted / mtp_gt1_total; + std::printf("\n[mtp-gt1] chains=%d total_accepted=%d mean_accept=%.2f " + "accept_rate=%.3f gamma=%d pos_mode=%s\n", + mtp_gt1_chains, mtp_gt1_accepted, mean_accept, + accept_rate, K, + mtp_pos_mode == 0 ? "const" : "incr"); + } + + } else { // gamma == 1 + // ── MTP SPECULATIVE DECODE LOOP (γ=1 v1) ───────────────────── + // + // Each iteration: + // 1. Run target forward for cur_tok at position `committed`, + // capturing mtp_h_prev from the last full-attention layer. + // 2. Rebuild MTP step graph with current attn_pos = committed+1. + // 3. Feed (cur_tok, mtp_h_prev) into MTP graph → draft_tok. + // 4. Run target verify forward for draft_tok at position committed+1. + // 5. Accept draft_tok if target agrees; otherwise accept target's + // token instead (standard single-draft acceptance). + // γ=1: one MTP draft per step. Correctness gate before γ>1. 
+ + int mtp_steps = 0; + int mtp_accepted = 0; + + while ((int)generated.size() < n_predict) { + + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("\n[mtp] EOS token %d at step %zu\n", + cur_tok, generated.size()); + break; + } + if (committed >= ctx_size - 2) { + std::printf("\n[mtp] context full at step %zu\n", + generated.size()); + break; + } + + // ── 1. Target forward for cur_tok (captures mtp_h_prev) ── + if (!build_gemma4_step(sg, w, cache, backend, + committed, /*n_tokens=*/1, + /*with_mask=*/true, + /*capture=*/false, + /*use_pflash=*/false, pflash_alpha, + fa_window)) { + std::fprintf(stderr, "[mtp] target build failed at step %zu\n", + generated.size()); + return 1; + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + 1; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, 1, committed); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, 1, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ 1, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + 1); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + if (!embed_token(w, cur_tok, sg.inp_embed, backend)) return 1; + { + int32_t pos_val = committed; + ggml_backend_tensor_set(sg.positions, &pos_val, 0, sizeof(int32_t)); + } + { + auto st = ggml_backend_graph_compute(backend, sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[mtp] target compute failed\n"); + return 1; + } + } + committed++; + cache.cur_pos = committed; + + // Read target logits to get target's own prediction at position committed-1 + const int vocab = w.n_vocab; + std::vector logits_cpu(vocab); + ggml_backend_tensor_get(sg.logits, logits_cpu.data(), 0, + 
sizeof(float) * vocab); + const int32_t target_next = (int32_t)sample_logits( + logits_cpu.data(), vocab, sampler, history, rng); + + step_graph_free(sg); + + // ── 2. Rebuild MTP step graph with attn_pos = committed ── + free_mtp_step_graph(mtp_g); + if (!build_mtp_step_graph(mtp_w, cache, w, mtp_g, committed)) { + std::fprintf(stderr, "[mtp] build_mtp_step_graph failed: %s\n", + dflash27b_last_error()); + return 1; + } + + // Allocate MTP graph (needs gallocr; build_mtp_step_graph creates + // the ggml context but not the backend buffers) + ggml_gallocr_t mtp_alloc = ggml_gallocr_new( + ggml_backend_get_default_buffer_type(backend)); + bool mtp_alloc_ok = ggml_gallocr_alloc_graph(mtp_alloc, mtp_g.gf); + if (!mtp_alloc_ok) { + std::fprintf(stderr, "[mtp] gallocr_alloc_graph failed\n"); + ggml_gallocr_free(mtp_alloc); + return 1; + } + + // ── 3. Set MTP inputs and compute ──────────────────────── + // in_tok_embd: pre-dequantised F32 embedding of cur_tok. + // embed_token dequantises via w.embedder.embed() on CPU, avoiding + // ggml_get_rows on a Q4_K source (unsupported in CUDA get_rows). + if (!embed_token(w, cur_tok, mtp_g.in_tok_embd, backend)) { + std::fprintf(stderr, "[mtp] embed_token failed for tok=%d\n", cur_tok); + ggml_gallocr_free(mtp_alloc); + return 1; + } + // in_h_prev: captured by target graph into cache.mtp_h_prev + ggml_backend_tensor_copy(cache.mtp_h_prev, mtp_g.in_h_prev); + // in_pos: position of the draft token (= committed, 0-based) + { + int32_t p = committed; + ggml_backend_tensor_set(mtp_g.in_pos, &p, 0, sizeof(int32_t)); + } + + // Fill the FA mask for TQ3_0 + head_dim>=512 cross-attention layers. + // Real positions [0..kv_seq_len-1]: 0x0000 (F16 0.0 = admit). + // Padding positions [kv_seq_len..mask_width-1]: 0xFC00 (F16 -inf = exclude). 
+ if (mtp_g.fa_mask && mtp_g.fa_mask->buffer) { + const int64_t mask_n = mtp_g.fa_mask->ne[0]; // total mask width + const int64_t kv_seq = mtp_g.fa_mask_kv_seq_len; // admitted positions + std::vector mask_buf(mask_n); + for (int64_t i = 0; i < mask_n; i++) { + mask_buf[i] = (i < kv_seq) ? 0x0000u : 0xFC00u; + } + ggml_backend_tensor_set(mtp_g.fa_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_n); + } + + { + auto st = ggml_backend_graph_compute(backend, mtp_g.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[mtp] MTP compute failed\n"); + ggml_gallocr_free(mtp_alloc); + return 1; + } + } + + // Read draft token from in-graph argmax + int32_t draft_tok = -1; + ggml_backend_tensor_get(mtp_g.out_argmax, &draft_tok, 0, sizeof(int32_t)); + + ggml_gallocr_free(mtp_alloc); + + // Emit the current token (already committed by target step above) + generated.push_back(cur_tok); + history.push_back(cur_tok); + std::printf("%d ", cur_tok); + std::fflush(stdout); + + if (first_token_ms < 0.0) { + first_token_ms = now_ms() - decode_t0; + } + + mtp_steps++; + + // ── 4+5. Check if draft matches target's greedy token ─── + if (mtp_steps <= 8) { + std::printf("[mtp-dbg] step=%d draft=%d target=%d %s\n", + mtp_steps, draft_tok, target_next, + draft_tok == target_next ? "MATCH" : "miss"); + std::fflush(stdout); + } + if (draft_tok == target_next) { + // MTP was right: accept draft token as next cur_tok + mtp_accepted++; + cur_tok = draft_tok; + } else { + // MTP was wrong: use target's token + cur_tok = target_next; + } + cache.last_tok = cur_tok; + + if ((int)generated.size() % 8 == 0) { + std::printf("[mtp-step %d] accept_rate=%.2f\n", + mtp_steps, + mtp_steps > 0 ? 
(float)mtp_accepted / mtp_steps : 0.0f); + } + + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("\n[mtp] EOS token %d\n", cur_tok); + break; + } + } + + if (mtp_steps > 0) { + std::printf("\n[mtp] steps=%d accepted=%d accept_rate=%.2f\n", + mtp_steps, mtp_accepted, + (float)mtp_accepted / mtp_steps); + } + + } // end gamma == 1 + + } else { + // ── TARGET-ONLY DECODE LOOP ─────────────────────────────────── + // + // Single-token autoregressive path. + // Each iteration: + // 1. Feed `cur_tok` through the target at position `committed`. + // 2. Sample the next token from logits. + // 3. Append to generated sequence. + // 4. Stop if EOS or n_predict reached. + + while ((int)generated.size() < n_predict) { + + if (IS_EOS_TOK(cur_tok, w)) { + std::printf("\n[decode] EOS token %d at step %zu\n", + cur_tok, generated.size()); + break; + } + + if (committed >= ctx_size - 1) { + std::printf("\n[decode] context full at step %zu\n", + generated.size()); + break; + } + + // Build single-token decode graph + if (!build_gemma4_step(sg, w, cache, backend, + committed, /*n_tokens=*/1, + /*with_mask=*/true, + /*capture=*/false, + /*use_pflash=*/false, pflash_alpha, + fa_window)) { + std::fprintf(stderr, "[decode] build failed at step %zu\n", + generated.size()); + return 1; + } + + if (sg.attn_mask && sg.attn_mask->buffer) { + const int kv_len = committed + 1; + std::vector mask_buf; + build_causal_mask(mask_buf, kv_len, 1, committed); + ggml_backend_tensor_set(sg.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + } + if (sg.swa_mask && sg.swa_mask->buffer) { + const SwaView swa_view = compute_swa_view(committed, 1, + w.swa_window, cache.swa_ctx_alloc); + std::vector swa_buf; + build_swa_causal_mask(swa_buf, + /*kv_start*/ committed, + /*n_tokens*/ 1, + /*swa_window*/ w.swa_window, + /*ring_size*/ swa_view.effective_win_len, + /*kv_end*/ committed + 1); + ggml_backend_tensor_set(sg.swa_mask, swa_buf.data(), 0, + sizeof(uint16_t) * swa_buf.size()); + } + + if 
(!embed_token(w, cur_tok, sg.inp_embed, backend)) return 1; + + int32_t pos_val = committed; + ggml_backend_tensor_set(sg.positions, &pos_val, 0, sizeof(int32_t)); + + double step_t0 = now_ms(); + auto st = ggml_backend_graph_compute(backend, sg.gf); + double step_t1 = now_ms(); + + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[decode] compute failed at step %zu\n", + generated.size()); + return 1; + } + + committed++; + cache.cur_pos = committed; + + // Fetch logits and sample + const int vocab = w.n_vocab; + std::vector logits_cpu(vocab); + ggml_backend_tensor_get(sg.logits, logits_cpu.data(), 0, + sizeof(float) * vocab); + + const int32_t next_tok = (int32_t)sample_logits( + logits_cpu.data(), vocab, sampler, history, rng); + + // Debug: check logits on first decode step + if (generated.empty()) { + float maxl = logits_cpu[0]; int maxi = 0; + for (int i = 1; i < vocab; i++) { + if (logits_cpu[i] > maxl) { maxl = logits_cpu[i]; maxi = i; } + } + std::printf("[tgt-only-dbg] logits[0..3]: %.3f %.3f %.3f %.3f max=%.3f@%d next=%d\n", + logits_cpu[0], logits_cpu[1], logits_cpu[2], logits_cpu[3], maxl, maxi, next_tok); + std::fflush(stdout); + } + + generated.push_back(cur_tok); + history.push_back(cur_tok); + + if (first_token_ms < 0.0 && !generated.empty()) { + first_token_ms = step_t1 - step_t0; + } + + // Print token id (a proper decoder would map id -> string here) + std::printf("%d ", cur_tok); + std::fflush(stdout); + + cur_tok = next_tok; + cache.last_tok = cur_tok; + + step_graph_free(sg); + } + } + + double decode_t1 = now_ms(); + const double decode_ms = decode_t1 - decode_t0; + const int n_gen = (int)generated.size(); + const double tps = (decode_ms > 0.0 && n_gen > 0) + ? 
n_gen / (decode_ms / 1000.0) + : 0.0; + + bench_tok_per_sec.push_back(tps); + + std::printf("\n"); + std::printf("[stats] generated=%d decode_ms=%.1f tok/s=%.2f " + "first_tok_ms=%.2f\n", + n_gen, decode_ms, tps, first_token_ms); + std::printf("[stats] prefill=%zu tokens context_used=%d/%d\n", + prompt_ids.size(), committed, ctx_size); + + if (have_draft && total_draft_steps > 0) { + std::printf("[spec] draft_steps=%d total_accepted=%d avg_accept=%.2f\n", + total_draft_steps, total_accepted, + (double)total_accepted / total_draft_steps); + } + + // ── Memory stats ────────────────────────────────────────────────── + { + size_t free_bytes = 0, total_bytes = 0; + cudaMemGetInfo(&free_bytes, &total_bytes); + const double used_gb = (total_bytes - free_bytes) / (1024.0 * 1024.0 * 1024.0); + const double total_gb = total_bytes / (1024.0 * 1024.0 * 1024.0); + std::printf("[mem] VRAM used=%.2f GB total=%.2f GB\n", + used_gb, total_gb); + } + + } // bench loop + + // ── Benchmark summary ───────────────────────────────────────────────── + if (bench_mode && bench_tok_per_sec.size() > 1) { + std::sort(bench_tok_per_sec.begin(), bench_tok_per_sec.end()); + const double median = bench_tok_per_sec[bench_tok_per_sec.size() / 2]; + const double best = bench_tok_per_sec.back(); + std::printf("\n[bench] median=%.2f tok/s best=%.2f tok/s runs=%zu\n", + median, best, bench_tok_per_sec.size()); + } + + // ── Cleanup ─────────────────────────────────────────────────────────── + step_graph_destroy(sg); + draft_step_destroy(dsg); + if (have_draft) { + free_draft_kv_cache(cache); + dw.tok_embd = nullptr; // prevent double-free (tok_embd lives in tok_embd_buf) + free_gemma4_draft_weights(dw); + if (tok_embd_buf) ggml_backend_buffer_free(tok_embd_buf); + if (tok_embd_ctx) ggml_free(tok_embd_ctx); + } + if (have_mtp) { + free_mtp_step_graph(mtp_g); + free_gemma4_mtp_assistant(mtp_w); + // mtp_h_prev lives in mtp_h_prev_buf/ctx (not base_ctx). 
+ // Null out the pointer in cache before free_gemma4_cache to avoid + // dangling reference (cache struct is stack-allocated; the pointer + // would otherwise reference freed memory). + cache.mtp_h_prev = nullptr; + cache.mtp_h_prev_batch = nullptr; + cache.mtp_h_prev_enabled = false; + cache.mtp_h_prev_capture_mode = 0; + if (mtp_h_prev_buf) { ggml_backend_buffer_free(mtp_h_prev_buf); mtp_h_prev_buf = nullptr; } + if (mtp_h_prev_ctx) { ggml_free(mtp_h_prev_ctx); mtp_h_prev_ctx = nullptr; } + } + free_gemma4_cache(cache); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + + return 0; +} diff --git a/dflash/test/gemma4/test_gemma4_kv_tq3.cpp b/dflash/test/gemma4/test_gemma4_kv_tq3.cpp new file mode 100644 index 000000000..495fc869b --- /dev/null +++ b/dflash/test/gemma4/test_gemma4_kv_tq3.cpp @@ -0,0 +1,179 @@ +// Smoke test: create a Gemma4 KV cache with TQ3_0 quantization and validate +// the resulting cache structure, alignment, and layer-to-KV-index mappings. +// +// Usage: test_gemma4_kv_tq3 + +#include "internal.h" +#include "gemma4.h" + +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +# define setenv(name, value, overwrite) _putenv_s(name, value) +#endif + +using namespace dflash27b; + +static void fail(const char * msg) { + std::fprintf(stderr, "FAIL: %s\n", msg); + std::exit(1); +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { std::fprintf(stderr, "cuda init failed\n"); return 1; } + + GemmaTargetWeights w; + if (!load_gemma4_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_gemma4_target_gguf: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + std::printf("[target] n_layer=%d n_embd=%d n_head_kv=%d head_dim=%d " + "n_layer_kv=%d n_capture_layers=%d\n", + 
w.n_layer, w.n_embd, w.n_head_kv, w.head_dim, + w.n_layer_kv, w.n_capture_layers); + + // Set KV type environment variables to tq3_0 before cache creation + setenv("DFLASH27B_KV_K", "tq3_0", 1); + setenv("DFLASH27B_KV_V", "tq3_0", 1); + + const int max_ctx = 1024; + GemmaTargetCache cache; + if (!create_gemma4_cache(w, max_ctx, backend, cache)) { + std::fprintf(stderr, "create_gemma4_cache: %s\n", dflash27b_last_error()); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + return 1; + } + std::printf("[cache] created max_ctx=%d kv_slots=%zu\n", + cache.max_ctx, cache.attn_k.size()); + + // Assert KV types resolved correctly + if (cache.kv_k_type != GGML_TYPE_TQ3_0) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "kv_k_type=%s expected tq3_0", ggml_type_name(cache.kv_k_type)); + fail(buf); + } + if (cache.kv_v_type != GGML_TYPE_TQ3_0) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "kv_v_type=%s expected tq3_0", ggml_type_name(cache.kv_v_type)); + fail(buf); + } + std::printf("[types] kv_k=%s kv_v=%s OK\n", + ggml_type_name(cache.kv_k_type), + ggml_type_name(cache.kv_v_type)); + + // Validate layer_to_kv_idx mapping + if ((int)cache.layer_to_kv_idx.size() != w.n_layer) { + char buf[64]; + std::snprintf(buf, sizeof(buf), + "layer_to_kv_idx.size()=%zu expected %d", + cache.layer_to_kv_idx.size(), w.n_layer); + fail(buf); + } + + const int n_kv_slots = (int)cache.attn_k.size(); + int n_shared_layers = 0; + for (int il = 0; il < w.n_layer; il++) { + const int idx = cache.layer_to_kv_idx[il]; + if (idx == -1) { + n_shared_layers++; + } else if (idx < 0 || idx >= n_kv_slots) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "layer_to_kv_idx[%d]=%d out of range [0, %d)", + il, idx, n_kv_slots); + fail(buf); + } + } + std::printf("[kv_idx] n_kv_slots=%d n_shared_layers=%d n_layer=%d\n", + n_kv_slots, n_shared_layers, w.n_layer); + + // Validate layer_to_donor_kv: shared layers must have a valid donor + if ((int)cache.layer_to_donor_kv.size() != 
w.n_layer) { + fail("layer_to_donor_kv.size() != n_layer"); + } + for (int il = 0; il < w.n_layer; il++) { + if (cache.layer_to_kv_idx[il] == -1) { + // This is a shared layer — must have a valid donor + const int donor = cache.layer_to_donor_kv[il]; + if (donor < 0 || donor >= n_kv_slots) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "layer_to_donor_kv[%d]=%d invalid for shared layer (n_kv_slots=%d)", + il, donor, n_kv_slots); + fail(buf); + } + } + } + std::printf("[donor_kv] all shared layers have valid donors OK\n"); + + // Validate TQ3_0 alignment: for TQ3_0, KV tensors must have ne[1] % 256 == 0 + // (create_gemma4_cache rounds max_ctx_alloc up to a multiple of 256 for TQ3_0). + for (int i = 0; i < n_kv_slots; i++) { + const ggml_tensor * K = cache.attn_k[i]; + const ggml_tensor * V = cache.attn_v[i]; + if (!K) { char buf[32]; std::snprintf(buf, sizeof(buf), "attn_k[%d] is null", i); fail(buf); } + if (!V) { char buf[32]; std::snprintf(buf, sizeof(buf), "attn_v[%d] is null", i); fail(buf); } + if (K->ne[1] % 256 != 0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "attn_k[%d]: ne[1]=%" PRId64 " not a multiple of 256 (TQ3_0 alignment)", + i, K->ne[1]); + fail(buf); + } + if (V->ne[1] % 256 != 0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "attn_v[%d]: ne[1]=%" PRId64 " not a multiple of 256 (TQ3_0 alignment)", + i, V->ne[1]); + fail(buf); + } + } + std::printf("[alignment] all %d KV tensors are 256-aligned OK\n", n_kv_slots); + + // Validate target_feat tensor + if (!cache.target_feat) fail("target_feat is null"); + const int64_t expected_feat_ne0 = (int64_t)w.n_capture_layers * w.n_embd; + if (cache.target_feat->ne[0] != expected_feat_ne0) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "target_feat->ne[0]=%" PRId64 " expected %" PRId64 + " (n_capture_layers=%d * n_embd=%d)", + cache.target_feat->ne[0], expected_feat_ne0, + w.n_capture_layers, w.n_embd); + fail(buf); + } + std::printf("[target_feat] ne=[%" PRId64 ", %" 
PRId64 "] type=%s OK\n", + cache.target_feat->ne[0], cache.target_feat->ne[1], + ggml_type_name(cache.target_feat->type)); + + // Print cache stats + std::printf("[stats] n_kv_slots=%d max_ctx=%d kv_seq_dim=%" PRId64 + " target_feat_cap=%d\n", + n_kv_slots, cache.max_ctx, + cache.attn_k[0]->ne[1], + cache.target_feat_cap); + + free_gemma4_cache(cache); + free_gemma4_target_weights(w); + ggml_backend_free(backend); + std::printf("PASS\n"); + return 0; +} diff --git a/dflash/test/gemma4/test_mtp_graph_shapes.cpp b/dflash/test/gemma4/test_mtp_graph_shapes.cpp new file mode 100644 index 000000000..b0da76baa --- /dev/null +++ b/dflash/test/gemma4/test_mtp_graph_shapes.cpp @@ -0,0 +1,298 @@ +// Phase 3a shape test: MTP step graph builds without crash and output tensor +// shapes match the contract: +// out_logits : F32 [n_vocab, 1] +// out_h_post : F32 [n_embd_backbone, 1] +// out_argmax : I32 [1] +// +// We stub GemmaTargetCache and GemmaTargetWeights with zero-initialised tensors +// of the correct shapes. No actual inference is performed — this is a graph +// construction smoke test only. +// +// Run: +// MTP_GGUF=/path/to/gemma-4-31B-it-assistant.Q4_K_M.gguf \ +// ./build/test_mtp_graph_shapes +// +// Requires MTP_GGUF to be set; exits 77 (autotools skip) if absent. + +#include "../src/internal.h" +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static int fail(const char * msg) { + std::fprintf(stderr, "[FAIL] %s\n", msg); + return 1; +} + +// Build a minimal stub GemmaTargetWeights with tok_embd of the right shape. +// The stub does NOT allocate GPU memory for embedding data; graph construction +// only needs the tensor *metadata* (ne[], type), not data. 
+static bool build_stub_target_weights(ggml_backend_t backend,
+                                      int n_vocab,
+                                      int n_embd_backbone,
+                                      int n_layer,
+                                      const std::vector<bool> & swa_layers,
+                                      GemmaTargetWeights & out) {
+    // Fills `out` with the hyper-parameters and a named tok_embd tensor so the
+    // MTP graph builder has something to read. Returns false (leaving out.ctx
+    // and out.tok_embd null) if the ggml context or the backend buffer cannot
+    // be created.
+    //
+    // Metadata-only context (no_alloc=true): room for tok_embd plus headroom
+    // for per-layer rope_freqs (optional), out_norm and output.
+    const size_t n_tensors_est = (size_t)(n_layer + 8);
+    ggml_init_params ip{};
+    ip.mem_size   = n_tensors_est * ggml_tensor_overhead() + 4096;
+    ip.mem_buffer = nullptr;
+    ip.no_alloc   = true;
+    out.ctx = ggml_init(ip);
+    if (!out.ctx) return false;
+
+    // tok_embd: [n_embd_backbone, n_vocab] (ggml ne[0]=embedding_dim, ne[1]=n_vocab)
+    out.tok_embd = ggml_new_tensor_2d(out.ctx, GGML_TYPE_F32, n_embd_backbone, n_vocab);
+    ggml_set_name(out.tok_embd, "stub_tok_embd");
+
+    // Populate fields needed by build_mtp_step_graph
+    // Dense 31B: head_dim=256 (from GGUF "gemma4.attention.key_length")
+    out.n_embd         = n_embd_backbone;
+    out.n_head         = 32;
+    out.n_head_kv      = 8;
+    out.head_dim       = 256;
+    out.head_dim_swa   = 256;
+    out.n_layer        = n_layer;
+    out.rope_theta     = 1000000.0f;
+    out.rope_theta_swa = 1000000.0f;
+    out.attn_scale     = 1.0f;
+    out.logit_softcap  = 30.0f;
+    out.swa_layers     = swa_layers;
+
+    // Populate minimal per-layer structs (only rope_freqs is accessed by MTP graph
+    // for full-attention donor layers). rope_freqs stays nullptr for all layers:
+    // proportional RoPE freq_factors are optional and nullptr falls back to base
+    // rope_theta scaling.
+    out.layers.resize((size_t)n_layer);
+
+    out.backend = backend;
+    // NOTE(review): ggml_backend_alloc_ctx_tensors gives every tensor in the
+    // context backend storage, so tok_embd gets a real device buffer of
+    // n_embd_backbone * n_vocab F32 elements even though its data is unused.
+    out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend);
+    if (!out.buf) {
+        ggml_free(out.ctx);
+        out.ctx      = nullptr;
+        out.tok_embd = nullptr;  // pointed into the context that was just freed
+        return false;
+    }
+
+    // No upload here: the previous zero-byte ggml_backend_tensor_set() with a
+    // null data pointer transferred nothing, and this test only constructs the
+    // graph, so the buffer contents are not initialised by this helper.
+
+    return true;
+}
+
+// Build a minimal stub GemmaTargetCache with KV tensors of the right shapes.
+//   attn_k[i]: [head_dim_kv, max_ctx, n_head_kv]
+//   attn_v[i]: [head_dim_kv, max_ctx, n_head_kv]
+// head_dim_kv_swa and head_dim_kv_full allow different head_dims per attention type.
+// Returns false if the ggml context or the backend buffer cannot be created; in
+// that case out.base_ctx is null and no KV tensor pointer is left dangling.
+static bool build_stub_target_cache(ggml_backend_t backend,
+                                    int n_layer,
+                                    int n_kv_per_layer,    // n_head_kv for KV cache
+                                    int head_dim_kv_swa,   // head_dim for SWA layers
+                                    int head_dim_kv_full,  // head_dim for full-attn layers
+                                    int max_ctx,
+                                    const std::vector<bool> & swa_layers,
+                                    GemmaTargetCache & out) {
+    // Count KV-owning layers (non-shared). For stub, all layers own a KV slot.
+    const int n_kv_slots = n_layer;  // stub: one per layer (no sharing)
+
+    // Metadata-only context: one K and one V tensor per slot plus headroom.
+    const size_t n_tensors_est = (size_t)(2 * n_kv_slots + 4);
+    ggml_init_params ip{};
+    ip.mem_size   = n_tensors_est * ggml_tensor_overhead() + 4096;
+    ip.mem_buffer = nullptr;
+    ip.no_alloc   = true;
+    out.base_ctx = ggml_init(ip);
+    if (!out.base_ctx) return false;
+
+    out.layer_to_kv_idx.resize((size_t)n_layer);
+    out.layer_to_donor_kv.resize((size_t)n_layer, -1);
+    out.attn_k.resize((size_t)n_kv_slots, nullptr);
+    out.attn_v.resize((size_t)n_kv_slots, nullptr);
+
+    for (int il = 0; il < n_layer; il++) {
+        out.layer_to_kv_idx[il] = il;  // one-to-one for stub
+
+        // Use different head_dim per attention type; layers beyond the end of
+        // swa_layers are treated as full attention.
+        const bool is_swa = (il < (int)swa_layers.size()) && swa_layers[il];
+        const int layer_head_dim = is_swa ? head_dim_kv_swa : head_dim_kv_full;
+        ggml_tensor * K = ggml_new_tensor_3d(out.base_ctx, GGML_TYPE_F16,
+                                             layer_head_dim, max_ctx, n_kv_per_layer);
+        ggml_tensor * V = ggml_new_tensor_3d(out.base_ctx, GGML_TYPE_F16,
+                                             layer_head_dim, max_ctx, n_kv_per_layer);
+        char name[64];
+        std::snprintf(name, sizeof(name), "stub_k_%d", il);
+        ggml_set_name(K, name);
+        std::snprintf(name, sizeof(name), "stub_v_%d", il);
+        ggml_set_name(V, name);
+        out.attn_k[il] = K;
+        out.attn_v[il] = V;
+    }
+
+    out.backend       = backend;
+    out.max_ctx       = max_ctx;
+    out.cur_pos       = 16;  // pretend we have 16 committed tokens
+    out.swa_ctx_alloc = max_ctx;
+
+    out.base_buf = ggml_backend_alloc_ctx_tensors(out.base_ctx, backend);
+    if (!out.base_buf) {
+        // The tensors lived in base_ctx; drop the pointers before freeing it.
+        out.attn_k.assign(out.attn_k.size(), nullptr);
+        out.attn_v.assign(out.attn_v.size(), nullptr);
+        ggml_free(out.base_ctx);
+        out.base_ctx = nullptr;
+        return false;
+    }
+
+    // KV contents are not initialised here: the test only builds the graph and
+    // never computes it, so the data is never read.
+
+    return true;
+}
+
+// Entry point: loads the MTP assistant named by $MTP_GGUF, builds stub target
+// weights/cache, constructs the MTP step graph and checks tensor shapes/types.
+// Exit codes: 0 = pass, 1 = failure, 77 = skipped (MTP_GGUF unset).
+int main() {
+    const char * mtp_path = std::getenv("MTP_GGUF");
+    if (!mtp_path) {
+        std::fprintf(stderr, "[skip] MTP_GGUF not set; skipping test_mtp_graph_shapes\n");
+        return 77;  // autotools skip code
+    }
+
+    ggml_backend_t backend = ggml_backend_cuda_init(0);
+    if (!backend) {
+        return fail("ggml_backend_cuda_init(0) failed");
+    }
+
+    // ── Load MTP weights ─────────────────────────────────────────────────────
+    MtpDrafterWeights mtp{};
+    if (!load_gemma4_mtp_assistant(std::string(mtp_path), backend, mtp)) {
+        std::fprintf(stderr, " loader error: %s\n", gemma4_last_error());
+        ggml_backend_free(backend);
+        return fail("load_gemma4_mtp_assistant failed");
+    }
+
+    // Everything created from here on is released through cleanup(), so each
+    // failure path frees exactly what has been built so far (previously the
+    // MTP weights, and on shape failures also the graph and stubs, leaked).
+    GemmaTargetWeights stub_target{};
+    GemmaTargetCache   stub_cache{};
+    MtpStepGraph       graph{};
+    bool have_target = false;
+    bool have_cache  = false;
+    bool have_graph  = false;
+
+    auto cleanup = [&]() {
+        if (have_graph) free_mtp_step_graph(graph);
+        if (have_cache) {
+            // Stub cache: manual teardown since we bypassed create_gemma4_cache
+            if (stub_cache.base_buf) ggml_backend_buffer_free(stub_cache.base_buf);
+            if (stub_cache.base_ctx) ggml_free(stub_cache.base_ctx);
+        }
+        if (have_target) free_gemma4_target_weights(stub_target);
+        free_gemma4_mtp_assistant(mtp);
+        ggml_backend_free(backend);
+    };
+    auto bail = [&](const char * msg) -> int {
+        cleanup();
+        return fail(msg);
+    };
+
+    const int n_embd_backbone = mtp.n_embd_backbone;  // e.g. 5376
+    const int n_vocab         = 262144;               // Dense 31B vocab
+    const int n_target_layers = 60;                   // Dense 31B
+    const int max_ctx         = 64;                   // small stub context
+
+    // Dense 31B SWA pattern: odd-indexed = SWA, even = full attention
+    std::vector<bool> target_swa(n_target_layers, false);
+    for (int il = 0; il < n_target_layers; il++) {
+        target_swa[il] = ((il % 2) == 1);
+    }
+
+    // ── Build stub target structures ─────────────────────────────────────────
+    if (!build_stub_target_weights(backend, n_vocab, n_embd_backbone,
+                                   n_target_layers, target_swa, stub_target)) {
+        return bail("build_stub_target_weights failed");
+    }
+    have_target = true;
+
+    // KV head_dim: derived from MTP weight shapes (attn_q_norm->ne[0] gives per-head Q dim,
+    // which must equal the target KV head_dim for flash_attn_ext to work).
+    // Dense 31B: SWA layers use head_dim=256, full-attn layers use head_dim=512
+    // (derived from mtp.layers[0].attn_q_norm->ne[0]=256 for SWA, [3].attn_q_norm->ne[0]=512 for full).
+    // Guard the indices first: layers[0] and layers[3] are dereferenced below.
+    if (mtp.layers.size() < 4 ||
+        !mtp.layers[0].attn_q_norm || !mtp.layers[3].attn_q_norm) {
+        return bail("MTP weights need >= 4 layers with attn_q_norm to derive head_dim");
+    }
+    const int head_dim_swa_stub  = (int)mtp.layers[0].attn_q_norm->ne[0];  // SWA layers 0-2
+    const int head_dim_full_stub = (int)mtp.layers[3].attn_q_norm->ne[0];  // Full-attn layer 3
+    std::fprintf(stderr, "[shape_test] MTP Q head_dim: SWA=%d, full=%d\n",
+                 head_dim_swa_stub, head_dim_full_stub);
+    if (!build_stub_target_cache(backend, n_target_layers,
+                                 /*n_kv_per_layer=*/8,
+                                 head_dim_swa_stub, head_dim_full_stub,
+                                 max_ctx, target_swa, stub_cache)) {
+        return bail("build_stub_target_cache failed");
+    }
+    have_cache = true;
+
+    // ── Build MTP step graph ─────────────────────────────────────────────────
+    const int attn_pos = stub_cache.cur_pos;  // = 16
+
+    if (!build_mtp_step_graph(mtp, stub_cache, stub_target, graph, attn_pos)) {
+        std::fprintf(stderr, " build error: %s\n", gemma4_last_error());
+        // have_graph stays false: as before, a graph whose build failed is not
+        // passed to free_mtp_step_graph.
+        return bail("build_mtp_step_graph failed");
+    }
+    have_graph = true;
+
+    // ── Shape assertions ─────────────────────────────────────────────────────
+
+    // 1. Input shapes
+    if (!graph.in_tok || graph.in_tok->ne[0] != 1 ||
+        graph.in_tok->type != GGML_TYPE_I32) {
+        return bail("in_tok shape/type mismatch: expected I32[1]");
+    }
+
+    if (!graph.in_h_prev ||
+        graph.in_h_prev->ne[0] != (int64_t)n_embd_backbone ||
+        graph.in_h_prev->ne[1] != 1 ||
+        graph.in_h_prev->type != GGML_TYPE_F32) {
+        std::fprintf(stderr, " in_h_prev->ne = [%lld, %lld]\n",
+                     (long long)(graph.in_h_prev ? graph.in_h_prev->ne[0] : -1),
+                     (long long)(graph.in_h_prev ? graph.in_h_prev->ne[1] : -1));
+        return bail("in_h_prev shape/type mismatch: expected F32[n_embd_backbone, 1]");
+    }
+
+    if (!graph.in_pos || graph.in_pos->ne[0] != 1 ||
+        graph.in_pos->type != GGML_TYPE_I32) {
+        return bail("in_pos shape/type mismatch: expected I32[1]");
+    }
+
+    // 2. out_h_post: F32 [n_embd_backbone, 1]
+    if (!graph.out_h_post ||
+        graph.out_h_post->ne[0] != (int64_t)n_embd_backbone ||
+        graph.out_h_post->ne[1] != 1 ||
+        graph.out_h_post->type != GGML_TYPE_F32) {
+        std::fprintf(stderr, " out_h_post->ne = [%lld, %lld], type=%s\n",
+                     (long long)(graph.out_h_post ? graph.out_h_post->ne[0] : -1),
+                     (long long)(graph.out_h_post ? graph.out_h_post->ne[1] : -1),
+                     graph.out_h_post ? ggml_type_name(graph.out_h_post->type) : "null");
+        return bail("out_h_post shape mismatch: expected F32[n_embd_backbone, 1]");
+    }
+
+    // 3. out_logits: F32 [n_vocab, 1]
+    if (!graph.out_logits ||
+        graph.out_logits->ne[0] != (int64_t)n_vocab ||
+        graph.out_logits->ne[1] != 1 ||
+        graph.out_logits->type != GGML_TYPE_F32) {
+        std::fprintf(stderr, " out_logits->ne = [%lld, %lld], type=%s\n",
+                     (long long)(graph.out_logits ? graph.out_logits->ne[0] : -1),
+                     (long long)(graph.out_logits ? graph.out_logits->ne[1] : -1),
+                     graph.out_logits ? ggml_type_name(graph.out_logits->type) : "null");
+        return bail("out_logits shape mismatch: expected F32[n_vocab, 1]");
+    }
+
+    // 4. out_argmax: I32 [1]
+    if (!graph.out_argmax ||
+        graph.out_argmax->ne[0] != 1 ||
+        graph.out_argmax->type != GGML_TYPE_I32) {
+        std::fprintf(stderr, " out_argmax->ne[0]=%lld type=%s\n",
+                     (long long)(graph.out_argmax ? graph.out_argmax->ne[0] : -1),
+                     graph.out_argmax ? ggml_type_name(graph.out_argmax->type) : "null");
+        return bail("out_argmax shape/type mismatch: expected I32[1]");
+    }
+
+    std::fprintf(stderr, "[PASS] all shape assertions passed for MTP step graph\n");
+    std::fprintf(stderr, " n_embd_backbone=%d, n_vocab=%d, n_layers=%zu, attn_pos=%d\n",
+                 n_embd_backbone, n_vocab, mtp.layers.size(), attn_pos);
+
+    cleanup();
+
+    return 0;
+}
diff --git a/dflash/test/gemma4/test_mtp_loader.cpp b/dflash/test/gemma4/test_mtp_loader.cpp
new file mode 100644
index 000000000..ea44ca947
--- /dev/null
+++ b/dflash/test/gemma4/test_mtp_loader.cpp
@@ -0,0 +1,128 @@
+// Phase 2 RED test: Gemma4 MTP loader (load_gemma4_mtp_assistant)
+//
+// Should NOT compile today — MtpDrafterWeights and load_gemma4_mtp_assistant
+// do not yet exist in internal.h. Once Phase 2 GREEN lands, the test compiles
+// and 7 assertions verify the loader contract per
+// .sisyphus/notes/mtp-spike-2026-05-09.md (sections "Contract — Phase 2").
+//
+// Run:
+//   cd dflash && cmake --build build --target test_mtp_loader && \
+//   MTP_GGUF=$ROOT/models/gemma4-mtp-31B/gemma-4-31B-it-assistant.Q4_K_M.gguf \
+//   ./build/test_mtp_loader
+
+#include "../src/internal.h"
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-cuda.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#include <vector>
+
+using namespace dflash27b;
+
+// Prints a failure line to stderr and returns the process exit code (1).
+static int fail(const char *msg) {
+    std::fprintf(stderr, "[red] FAIL: %s\n", msg);
+    return 1;
+}
+
+// Loads the MTP assistant named by $MTP_GGUF and checks seven loader-contract
+// assertions. Exit codes: 0 = pass, 1 = failure, 77 = skipped (MTP_GGUF unset).
+int main() {
+    const char *p = std::getenv("MTP_GGUF");
+    if (!p) {
+        std::fprintf(stderr, "[skip] MTP_GGUF env not set; expected:\n");
+        std::fprintf(stderr, " /home/peppi/Dev/lucebox-hub/models/gemma4-mtp-31B/gemma-4-31B-it-assistant.Q4_K_M.gguf\n");
+        return 77;  // autotools skip
+    }
+
+    // Backend init (reuse the pattern from test_gemma4_dflash.cpp)
+    ggml_backend_t backend = ggml_backend_cuda_init(0);
+    if (!backend) {
+        return fail("ggml_backend_cuda_init failed");
+    }
+
+    // The function under test (Phase 2 GREEN must define this)
+    MtpDrafterWeights mtp;
+    bool ok = load_gemma4_mtp_assistant(std::string(p), backend, mtp);
+    if (!ok) {
+        ggml_backend_free(backend);
+        return fail("load_gemma4_mtp_assistant returned false");
+    }
+
+    // After a successful load every exit path must release the weights as well
+    // as the backend (previously `mtp` was never freed).
+    auto bail = [&](const char * msg) -> int {
+        free_gemma4_mtp_assistant(mtp);
+        ggml_backend_free(backend);
+        return fail(msg);
+    };
+
+    // Assertion 1: n_embd_backbone matches target hidden (Dense 31B = 5376)
+    if (mtp.n_embd_backbone != 5376) {
+        std::fprintf(stderr, " n_embd_backbone=%d expected 5376\n", mtp.n_embd_backbone);
+        return bail("n_embd_backbone mismatch");
+    }
+
+    // Assertion 2: requires_target_arch == "gemma4" (vLLM #41789 guard)
+    if (mtp.requires_target_arch != "gemma4") {
+        std::fprintf(stderr, " requires_target_arch=\"%s\" expected \"gemma4\"\n",
+                     mtp.requires_target_arch.c_str());
+        return bail("requires_target_arch mismatch");
+    }
+
+    // Assertion 3: 4 MTP transformer blocks (per MTP.md spec)
+    if (mtp.layers.size() != 4) {
+        std::fprintf(stderr, " layers.size()=%zu expected 4\n", mtp.layers.size());
+        return bail("MTP block count mismatch");
+    }
+
+    // Assertion 4: attention_k_eq_v=true (Gemma4 quirk; V always read from cache)
+    if (!mtp.attention_k_eq_v) {
+        return bail("attention_k_eq_v should be true for Gemma4");
+    }
+
+    // Assertion 5: pre_projection tensor shape [2*n_embd_backbone, n_embd_mtp]
+    // pre_projection concatenates [tok_embd(n_embd_backbone) + h_prev(n_embd_backbone)]
+    // and projects to MTP's own hidden size n_embd.
+    // ne[0] = 2*n_embd_backbone = 10752, ne[1] = mtp.n_embd (the MTP model's hidden size)
+    if (!mtp.pre_projection ||
+        mtp.pre_projection->ne[0] != 2 * (int64_t)mtp.n_embd_backbone) {
+        std::fprintf(stderr, " pre_projection->ne[0]=%lld expected %d\n",
+                     (long long)(mtp.pre_projection ? mtp.pre_projection->ne[0] : -1),
+                     2 * mtp.n_embd_backbone);
+        return bail("pre_projection shape mismatch (ne[0] != 2*n_embd_backbone)");
+    }
+
+    // Assertion 6: post_projection tensor shape [n_embd_mtp, n_embd_backbone]
+    // Projects MTP hidden back to target backbone dimension.
+    // ne[0] = mtp.n_embd, ne[1] = n_embd_backbone = 5376
+    if (!mtp.post_projection ||
+        mtp.post_projection->ne[1] != (int64_t)mtp.n_embd_backbone) {
+        std::fprintf(stderr, " post_projection->ne[1]=%lld expected %d\n",
+                     (long long)(mtp.post_projection ? mtp.post_projection->ne[1] : -1),
+                     mtp.n_embd_backbone);
+        return bail("post_projection shape mismatch (ne[1] != n_embd_backbone)");
+    }
+
+    // Assertion 7: per-MTP-layer donor KV resolution (NOT global pair).
+    // For Dense 31B (60 target layers, SWA pattern from gemma4_target_graph):
+    //   even-indexed target layers = full attention, last = 58
+    //   odd-indexed target layers = SWA attention, last = 59
+    // Each MTP layer's donor_target_layer must be exactly 58 (full) or 59 (SWA)
+    // depending on that layer's own attention type. A bounds-only check would
+    // accept any value in [0, 60), which misses wrong-type assignments.
+    for (size_t il = 0; il < mtp.layers.size(); ++il) {
+        const int32_t got  = mtp.layers[il].donor_target_layer;
+        const int32_t want = mtp.layers[il].is_swa ? 59 : 58;  // last SWA / last full-attn
+        if (got != want) {
+            std::fprintf(stderr,
+                         " layer %zu is_swa=%d donor_target_layer=%d expected %d\n",
+                         il, (int)mtp.layers[il].is_swa, got, want);
+            return bail("donor_target_layer does not point to last matching-type target layer");
+        }
+    }
+
+    free_gemma4_mtp_assistant(mtp);
+    ggml_backend_free(backend);
+    std::fprintf(stderr, "[red->green] all 7 assertions PASS\n");
+    return 0;
+}
diff --git a/dflash/test/test_flash_attn_sparse.cpp b/dflash/test/test_flash_attn_sparse.cpp
new file mode 100644
index 000000000..f38c470f0
--- /dev/null
+++ b/dflash/test/test_flash_attn_sparse.cpp
@@ -0,0 +1,201 @@
+#include "ggml.h"
+#include "ggml-cuda.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "../src/pflash_ggml_adapter.h"
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+
+// Compare dense FA output vs sparse FA output
+// At alpha=1.0 (select all blocks), sparse should match dense exactly.
+// Returns true when every output element is finite and max |dense - sparse| < 1.0;
+// returns false on any setup/compute failure as well.
+static bool test_sparse_matches_dense(ggml_backend_t backend, int S, int H, int Hk, int D) {
+    // Use no_alloc=true so tensors are NOT pre-allocated in CPU memory.
+    // The gallocr will allocate them in the CUDA backend buffer instead,
+    // which is required for ggml_backend_tensor_set/get to work.
+    const size_t ctx_size = 256 * 1024 * 1024;
+    ggml_init_params params = { ctx_size, nullptr, /*no_alloc=*/true };
+    ggml_context * ctx = ggml_init(params);
+    if (!ctx) {
+        fprintf(stderr, "[test] ggml_init failed\n");
+        return false;
+    }
+
+    // Q must be F32: the CUDA FA kernel asserts Q->type == GGML_TYPE_F32
+    // K and V can be F16; the kernel converts them internally if needed
+    // ggml FA convention: ne[0]=D, ne[1]=S, ne[2]=H
+    ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, S, H);
+    // K [D, S, Hk]
+    ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, D, S, Hk);
+    // V [D, S, Hk]
+    ggml_tensor * V = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, D, S, Hk);
+
+    // Mark Q, K, V as graph inputs so gallocr allocates persistent backend buffers for them
+    ggml_set_input(Q);
+    ggml_set_input(K);
+    ggml_set_input(V);
+
+    // Causal mask for dense FA: ne[0]=KV_len, ne[1]=Q_len.
+    // The kernel indexes it as mask[q * ne[0] + kv], so mask[q][kv] = (kv <= q) ? 0 : -inf.
+    // pFlash applies causal masking at block granularity, so we give dense FA the same mask.
+    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S, S);
+    ggml_set_input(mask);
+
+    const float scale = 1.0f / sqrtf((float)D);
+
+    // Dense FA
+    ggml_tensor * dense_out = ggml_flash_attn_ext(ctx, Q, K, V, mask, scale, 0.0f, 0.0f);
+
+    // Sparse FA (alpha=1.0 = select all blocks = should match dense)
+    ggml_tensor * sparse_out = ggml_flash_attn_sparse(ctx, Q, K, V, scale, 1.0f);
+
+    // Mark outputs so gallocr never frees/overwrites them before readback
+    ggml_set_output(dense_out);
+    ggml_set_output(sparse_out);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, dense_out);
+    ggml_build_forward_expand(gf, sparse_out);
+
+    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) {
+        // Without backend buffers the tensor_set calls below would write nowhere.
+        fprintf(stderr, "[test] graph allocation failed (S=%d H=%d Hk=%d D=%d)\n", S, H, Hk, D);
+        if (alloc) ggml_gallocr_free(alloc);
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Fill Q (F32), K (F16), V (F16) with random data; fixed seed and fixed
+    // fill order (Q, K, V) keep the inputs reproducible.
+    srand(42);
+
+    std::vector<float> q_buf((size_t)D * S * H);
+    for (auto & x : q_buf) x = (float)(rand() % 1000 - 500) / 500.0f;
+    ggml_backend_tensor_set(Q, q_buf.data(), 0, ggml_nbytes(Q));
+
+    std::vector<ggml_fp16_t> buf((size_t)D * S * Hk);
+    for (auto & x : buf) x = ggml_fp32_to_fp16((float)(rand() % 1000 - 500) / 500.0f);
+    ggml_backend_tensor_set(K, buf.data(), 0, ggml_nbytes(K));
+
+    // V has the same shape as K, so the staging buffer is reused as-is.
+    for (auto & x : buf) x = ggml_fp32_to_fp16((float)(rand() % 1000 - 500) / 500.0f);
+    ggml_backend_tensor_set(V, buf.data(), 0, ggml_nbytes(V));
+
+    // Fill causal mask: mask[q * S + kv] = (kv <= q) ? 0.0f : -INFINITY
+    {
+        std::vector<ggml_fp16_t> mask_data((size_t)S * S);
+        for (int q = 0; q < S; q++) {
+            for (int kv = 0; kv < S; kv++) {
+                float val = (kv <= q) ? 0.0f : -INFINITY;
+                mask_data[(size_t)q * S + kv] = ggml_fp32_to_fp16(val);
+            }
+        }
+        ggml_backend_tensor_set(mask, mask_data.data(), 0,
+                                mask_data.size() * sizeof(ggml_fp16_t));
+    }
+
+    // A failed compute leaves the outputs undefined; report it instead of
+    // comparing garbage.
+    const auto status = ggml_backend_graph_compute(backend, gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        fprintf(stderr, "[test] graph compute failed (S=%d H=%d Hk=%d D=%d)\n", S, H, Hk, D);
+        ggml_gallocr_free(alloc);
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Compare outputs (dense_out is GGML_TYPE_F32, use ggml_nelements for element count)
+    const size_t n_elems   = ggml_nelements(dense_out);
+    const size_t out_bytes = n_elems * sizeof(float);
+    std::vector<float> dense_data(n_elems);
+    std::vector<float> sparse_data(n_elems);
+    ggml_backend_tensor_get(dense_out, dense_data.data(), 0, out_bytes);
+    ggml_backend_tensor_get(sparse_out, sparse_data.data(), 0, out_bytes);
+
+    float max_diff = 0.0f;
+    bool any_nonfinite = false;
+    for (size_t i = 0; i < dense_data.size(); i++) {
+        if (!std::isfinite(sparse_data[i]) || !std::isfinite(dense_data[i])) {
+            any_nonfinite = true;
+            break;
+        }
+        float diff = fabsf(dense_data[i] - sparse_data[i]);
+        if (diff > max_diff) max_diff = diff;
+    }
+
+    // NOTE(review): inputs are drawn from [-1, 1), so a 1.0 tolerance is very
+    // loose for an "exact match" test — confirm the intended bound.
+    const bool pass = max_diff < 1.0f && !any_nonfinite;
+    printf("[test] S=%d H=%d Hk=%d D=%d max_diff=%.6f nonfinite=%s %s\n",
+           S, H, Hk, D, max_diff,
+           any_nonfinite ? "YES" : "no",
+           pass ? "PASS" : "FAIL");
+
+    ggml_gallocr_free(alloc);
+    ggml_free(ctx);
+    return pass;
+}
+
+// Sanity-check sparse attention at alpha < 1.0:
+// The output should not be all zeros (basic liveness check).
+// With alpha < 1.0 outputs will differ from dense FA — that is expected and not tested here.
+static bool test_sparse_alpha(ggml_backend_t backend, int S, int H, int Hk, int D, float alpha) {
+    // Liveness check for the sparse kernel: passes when the output contains only
+    // finite values and at least one element with |x| > 1e-6. Returns false on
+    // any setup/compute failure.
+    const size_t ctx_size = 256 * 1024 * 1024;
+    ggml_init_params params = { ctx_size, nullptr, /*no_alloc=*/true };
+    ggml_context * ctx = ggml_init(params);
+    if (!ctx) {
+        fprintf(stderr, "[test_sparse_alpha] ggml_init failed\n");
+        return false;
+    }
+
+    // ggml FA convention: ne[0]=D, ne[1]=S, ne[2]=H
+    ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, S, H);
+    ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, D, S, Hk);
+    ggml_tensor * V = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, D, S, Hk);
+
+    ggml_set_input(Q);
+    ggml_set_input(K);
+    ggml_set_input(V);
+
+    ggml_tensor * sparse_out = ggml_flash_attn_sparse(ctx, Q, K, V, 1.0f/sqrtf((float)D), alpha);
+    ggml_set_output(sparse_out);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, sparse_out);
+
+    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) {
+        // Without backend buffers the tensor_set calls below would write nowhere.
+        fprintf(stderr, "[test_sparse_alpha] graph allocation failed (alpha=%.2f S=%d)\n", alpha, S);
+        if (alloc) ggml_gallocr_free(alloc);
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Fixed seed and fixed fill order (Q, K, V) keep the inputs reproducible.
+    srand(42);
+
+    std::vector<float> q_buf((size_t)D * S * H);
+    for (auto & x : q_buf) x = (float)(rand() % 1000 - 500) / 500.0f;
+    ggml_backend_tensor_set(Q, q_buf.data(), 0, ggml_nbytes(Q));
+
+    std::vector<ggml_fp16_t> buf((size_t)D * S * Hk);
+    for (auto & x : buf) x = ggml_fp32_to_fp16((float)(rand() % 1000 - 500) / 500.0f);
+    ggml_backend_tensor_set(K, buf.data(), 0, ggml_nbytes(K));
+
+    // V has the same shape as K, so the staging buffer is reused as-is.
+    for (auto & x : buf) x = ggml_fp32_to_fp16((float)(rand() % 1000 - 500) / 500.0f);
+    ggml_backend_tensor_set(V, buf.data(), 0, ggml_nbytes(V));
+
+    // A failed compute leaves the output undefined; report it instead of
+    // inspecting garbage.
+    const auto status = ggml_backend_graph_compute(backend, gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        fprintf(stderr, "[test_sparse_alpha] graph compute failed (alpha=%.2f S=%d)\n", alpha, S);
+        ggml_gallocr_free(alloc);
+        ggml_free(ctx);
+        return false;
+    }
+
+    const size_t n_elems   = ggml_nelements(sparse_out);
+    const size_t out_bytes = n_elems * sizeof(float);
+    std::vector<float> out_data(n_elems);
+    ggml_backend_tensor_get(sparse_out, out_data.data(), 0, out_bytes);
+
+    // Basic sanity: output must be finite and must not be all zeros. Inf would
+    // otherwise satisfy the magnitude check, and NaN would be silently skipped.
+    float max_abs = 0.0f;
+    bool any_nonfinite = false;
+    for (size_t i = 0; i < out_data.size(); i++) {
+        if (!std::isfinite(out_data[i])) {
+            any_nonfinite = true;
+            break;
+        }
+        float v = fabsf(out_data[i]);
+        if (v > max_abs) max_abs = v;
+    }
+
+    const bool pass = !any_nonfinite && max_abs > 1e-6f;
+    const char * verdict = pass ? "PASS"
+                         : any_nonfinite ? "FAIL (non-finite)"
+                                         : "FAIL (all zeros)";
+    printf("[test_sparse_alpha] alpha=%.2f S=%d H=%d Hk=%d D=%d max_abs=%.6f %s\n",
+           alpha, S, H, Hk, D, max_abs, verdict);
+
+    ggml_gallocr_free(alloc);
+    ggml_free(ctx);
+    return pass;
+}
+
+// Runs the dense-vs-sparse comparison at three sequence lengths and the
+// liveness check at two sparsity levels. Exit code 0 only if all pass.
+int main() {
+    ggml_backend_t backend = ggml_backend_cuda_init(0);
+    if (!backend) {
+        fprintf(stderr, "CUDA backend not available\n");
+        return 1;
+    }
+
+    pflash_register_ggml_kernel();
+
+    bool ok = true;
+    ok &= test_sparse_matches_dense(backend, 256, 16, 8, 128);   // small
+    ok &= test_sparse_matches_dense(backend, 1024, 16, 8, 128);  // medium
+    ok &= test_sparse_matches_dense(backend, 4096, 16, 8, 128);  // large
+
+    // Alpha < 1.0: pFlash kernel with moderate and aggressive sparsity
+    ok &= test_sparse_alpha(backend, 1024, 16, 8, 128, 0.5f);   // moderate sparsity
+    ok &= test_sparse_alpha(backend, 4096, 16, 8, 128, 0.12f);  // aggressive sparsity (default alpha)
+
+    ggml_backend_free(backend);
+    printf("\n%s\n", ok ? "ALL TESTS PASSED" : "SOME TESTS FAILED");
+    return ok ? 0 : 1;
+}