From a5da2f52d3180833a916e6a5353a75e16dee5dcf Mon Sep 17 00:00:00 2001 From: Javier Pazo Date: Sat, 9 May 2026 12:37:09 +0200 Subject: [PATCH 1/2] feat(dflash): wire caller-provided SWA mask through draft graph MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the Qwen3.6 draft graph is built for a context that exceeds the SWA window, the caller-provided sliding-window attention mask must reach `ggml_flash_attn_ext` on the SWA layers. Previously the mask was constructed and then nullified post-construction in qwen3_dflash_graph, so SWA layers ran without the intended visibility constraint at long contexts. This change makes the wiring explicit and pinable as a contract: * `dflash_graph.h` — `DraftGraphInputs` gains an optional `attn_mask` field. Documented as caller-owned, type F16, with shape `[kv_len, q_len]` (or padded `[kv_pad, q_pad]`), values `0` for visible positions and `-inf` for masked positions. Two helpers added so callers do not reimplement the same logic: bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len); void build_draft_swa_mask(std::vector & out, int ctx_len, int q_len, int swa_window); `lm_head` is normalised to a default-null member at the same time (small consistency fix; layout unchanged). * `qwen3_dflash_graph.cpp` — when `total_k > swa_window` on a SWA layer, the graph wires the caller-provided mask through to `ggml_flash_attn_ext` and stops nullifying it. Layers that are not SWA still ignore the mask, as before. * `smoke_draft_graph.cpp` and `test_vs_oracle.cpp` — small alignment so the existing tests can build / fill the draft SWA mask when `ctx_len + q_len > swa_window`. No new test scaffolding is added in this commit; the focused regression test lives in a separate PR (`test(dflash): contract test for draft SWA mask wiring`) so each PR keeps one concern. Validation: * Built and ran `smoke_draft_graph` and `test_vs_oracle` on RTX 6000 Ada (sm_89), Heretic Q4_K_M target, FP16 safetensors drafter, FA_WINDOW=0. Both tests pass before and after; the behaviour at ctx_len <= swa_window is unchanged (mask not needed and not consumed). * At long context the SWA layers now respect the caller mask. Verification vs existing community PRs: COMP-COMPL with PR #94 (open, "support Qwen3.6-27B-DFlash draft (SWA layers)", Quitetall) and PR #129 (open Draft, "sliding window attention for Qwen3.6 draft model", howard0su). * PR #94 wires SWA via masks (same family as this PR). * PR #129 wires SWA via per-layer K/V truncation instead. The interface added here (caller-mask field + helpers) is small enough that it can survive either approach landing first. If PR #94 lands first, this commit should rebase cleanly because it formalises the mask path #94 already needs internally; if PR #129 lands first, the mask path here remains useful for callers that prefer mask semantics. Maintainers, happy to coordinate ordering. Author: Javier Pazo --- dflash/src/dflash_graph.h | 15 ++++++- dflash/src/qwen3_dflash_graph.cpp | 66 ++++++++++++++++++++++++++++++- dflash/test/smoke_draft_graph.cpp | 12 ++++++ dflash/test/test_vs_oracle.cpp | 12 ++++++ 4 files changed, 103 insertions(+), 2 deletions(-) diff --git a/dflash/src/dflash_graph.h b/dflash/src/dflash_graph.h index 304ff8e39..5b5204b6b 100644 --- a/dflash/src/dflash_graph.h +++ b/dflash/src/dflash_graph.h @@ -1,6 +1,9 @@ // Shared inputs/outputs for the DFlash draft graph builder. #pragma once +#include +#include + #include "ggml.h" namespace dflash27b { @@ -13,11 +16,15 @@ struct DraftGraphInputs { ggml_tensor * target_hidden_cat;// [5*hidden, ctx_len, 1] f32 ggml_tensor * positions_q; // [q_len] i32 values [ctx_len..ctx_len+q_len-1] ggml_tensor * positions_k; // [ctx_len+q_len] i32 values [0..ctx_len+q_len-1] + // Optional SWA mask for long-context sliding-attention layers. + // Shape [kv_len, q_len] or padded [kv_pad, q_pad], type F16, values + // 0 for visible positions and -inf for masked positions. + ggml_tensor * attn_mask = nullptr; // Optional: if non-null, the graph projects final hidden states through // this LM head (shape [hidden, vocab]) and returns logits instead of // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. - ggml_tensor * lm_head; + ggml_tensor * lm_head = nullptr; }; struct DraftGraphOutputs { @@ -30,4 +37,10 @@ DraftGraphOutputs build_draft_graph( const DraftWeights & w, const DraftGraphInputs & in); +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len); +void build_draft_swa_mask(std::vector & out, + int ctx_len, + int q_len, + int swa_window); + } // namespace dflash27b diff --git a/dflash/src/qwen3_dflash_graph.cpp b/dflash/src/qwen3_dflash_graph.cpp index 638bd8735..2060d00ee 100644 --- a/dflash/src/qwen3_dflash_graph.cpp +++ b/dflash/src/qwen3_dflash_graph.cpp @@ -31,10 +31,46 @@ #include "internal.h" #include "dflash_graph.h" +#include #include +#include namespace dflash27b { +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len) { + if (w.swa_window <= 0) { + return false; + } + const int total_k = ctx_len + DFLASH27B_DRAFT_BLOCK_SIZE; + if (total_k <= w.swa_window) { + return false; + } + for (int il = 0; il < w.n_layer; ++il) { + if (w.layers[il].is_swa) { + return true; + } + } + return false; +} + +void build_draft_swa_mask(std::vector & out, + int ctx_len, + int q_len, + int swa_window) { + static constexpr uint16_t F16_ZERO = 0x0000; + static constexpr uint16_t F16_NEG_INF = 0xFC00; + + const int total_k = ctx_len + q_len; + out.assign((size_t)total_k * q_len, F16_NEG_INF); + for (int q = 0; q < q_len; ++q) { + const int abs_q = ctx_len + q; + const int min_k = std::max(0, abs_q - swa_window); + for (int k = min_k; k < total_k; ++k) { + out[(size_t)q * total_k + k] = F16_ZERO; + } + } +} + DraftGraphOutputs build_draft_graph( ggml_context * ctx, const DraftWeights & w, @@ -118,8 +154,36 @@ DraftGraphOutputs build_draft_graph( V = ggml_cont (ctx, V); // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // For SWA layers (Qwen3.6 draft): apply sliding window mask + // limiting context K/V to the last `swa_window` positions. const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * attn_mask = nullptr; + if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) { + if (!in.attn_mask) { + set_last_error("build_draft_graph: SWA layer requires a non-null attn_mask"); + return {}; + } + if (in.attn_mask->type != GGML_TYPE_F16) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask must be F16, got %s", + ggml_type_name(in.attn_mask->type)); + set_last_error(buf); + return {}; + } + if (in.attn_mask->ne[0] < total_k || in.attn_mask->ne[1] < q_len) { + char buf[160]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask too small (%lld x %lld, need >= %d x %d)", + (long long)in.attn_mask->ne[0], + (long long)in.attn_mask->ne[1], + total_k, q_len); + set_last_error(buf); + return {}; + } + attn_mask = in.attn_mask; + } + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/test/smoke_draft_graph.cpp b/dflash/test/smoke_draft_graph.cpp index 166722168..8a1fa02e5 100644 --- a/dflash/test/smoke_draft_graph.cpp +++ b/dflash/test/smoke_draft_graph.cpp @@ -85,6 +85,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, fc_in, ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, ctx_len + q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, ctx_len + q_len, q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -101,6 +107,7 @@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) { std::fprintf(stderr, "build_draft_graph returned null\n"); return 1; } @@ -141,6 +148,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < ctx_len + q_len; i++) pk[i] = i; ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); } + if (attn_mask) { + std::vector mask; + build_draft_swa_mask(mask, ctx_len, q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // ── 7. Compute auto status = ggml_backend_graph_compute(backend, gf); diff --git a/dflash/test/test_vs_oracle.cpp b/dflash/test/test_vs_oracle.cpp index b0c247d94..352139bc9 100644 --- a/dflash/test/test_vs_oracle.cpp +++ b/dflash/test/test_vs_oracle.cpp @@ -117,6 +117,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, m.fc_in, m.ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.ctx_len + m.q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, m.ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, m.ctx_len + m.q_len, m.q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -132,6 +138,7 @@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) return 1; ggml_set_output(go.hidden_states); @@ -154,6 +161,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < m.ctx_len + m.q_len; i++) pk[i] = i; ggml_backend_tensor_set(pos_q, pq.data(), 0, sizeof(int32_t) * pq.size()); ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); + if (attn_mask) { + std::vector mask; + build_draft_swa_mask(mask, m.ctx_len, m.q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // Compute auto status = ggml_backend_graph_compute(backend, gf); From 8614d0bab68cd720639eae1d8fd3ec7149aa78c5 Mon Sep 17 00:00:00 2001 From: Javier Pazo Date: Sat, 9 May 2026 12:37:58 +0200 Subject: [PATCH 2/2] test(dflash): contract test for draft SWA mask wiring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a focused regression test that pins the draft-side SWA mask wiring as a contract, independent of which SWA implementation eventually lands. What it tests: * When `draft_graph_needs_swa_mask(weights, ctx_len)` returns true, the graph reads `attn_mask` from `DraftGraphInputs` and propagates it to `ggml_flash_attn_ext` on SWA layers. * When SWA is not active, or `total_k <= swa_window`, the mask is not consumed. * `build_draft_swa_mask` produces the documented shape and values (`0.0` for visible positions and `-inf` for masked positions, F16 storage). Why a separate PR: * Keeps "one concern per PR" per CONTRIBUTING — the wiring is feature, this is test. * The test stays useful regardless of the approach that lands upstream: - PR #94 (mask-style) — exercises exactly the wiring this test pins. - PR #129 (per-layer K/V truncation) — the contract still prevents regressions on the mask code path that callers rely on. Depends on the companion `feat(dflash): wire caller-provided SWA mask through draft graph` (PR to be opened in parallel). The test needs the helpers exported in that PR's `dflash_graph.h`. Build registration: Adds the test to `dflash/CMakeLists.txt` next to `test_vs_oracle` with the same `EXISTS`-guard pattern, so a `cmake --build ... --target test_draft_swa_mask_contract` builds and runs it on a clean checkout. No effect on other targets. Validation: * `cmake --build dflash/build/Release --target test_draft_swa_mask_contract` succeeds on RTX 6000 Ada (sm_89), Windows MSVC + CUDA 12.x. * All assertions green: SWA-active + long-ctx case consumes the mask; non-SWA / short-ctx cases do not; mask helper output matches the documented shape and values. * No regressions on `smoke_draft_graph` or `test_vs_oracle` (untouched here). Author: Javier Pazo --- dflash/CMakeLists.txt | 5 + dflash/test/test_draft_swa_mask_contract.cpp | 177 +++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 dflash/test/test_draft_swa_mask_contract.cpp diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index 8ac910092..00d13e576 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -277,6 +277,11 @@ if(DFLASH27B_TESTS) target_include_directories(test_vs_oracle PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) target_link_libraries(test_vs_oracle PRIVATE dflash27b ggml ggml-cuda) endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_swa_mask_contract.cpp") + add_executable(test_draft_swa_mask_contract test/test_draft_swa_mask_contract.cpp) + target_include_directories(test_draft_swa_mask_contract PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_draft_swa_mask_contract PRIVATE dflash27b ggml ggml-cuda) + endif() if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_load_target.cpp") add_executable(smoke_load_target test/smoke_load_target.cpp) target_include_directories(smoke_load_target PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) diff --git a/dflash/test/test_draft_swa_mask_contract.cpp b/dflash/test/test_draft_swa_mask_contract.cpp new file mode 100644 index 000000000..460ba35b3 --- /dev/null +++ b/dflash/test/test_draft_swa_mask_contract.cpp @@ -0,0 +1,177 @@ +#include "dflash_graph.h" +#include "internal.h" + +#include "ggml.h" + +#include +#include + +using namespace dflash27b; + +namespace { + +struct GraphCase { + bool is_swa = false; + int swa_window = 0; + int ctx_len = 0; + bool provide_mask = false; + bool expect_mask = false; + const char * label = ""; +}; + +ggml_tensor * new_vec(ggml_context * ctx, int64_t n) { + return ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); +} + +ggml_tensor * new_mat(ggml_context * ctx, int64_t ne0, int64_t ne1) { + return ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1); +} + +bool run_case(const GraphCase & tc) { + ggml_init_params ip{}; + ip.mem_size = 2 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + std::fprintf(stderr, "FAIL %s: ggml_init failed\n", tc.label); + return false; + } + + constexpr int hidden = 8; + constexpr int n_head = 2; + constexpr int n_kv = 1; + constexpr int head_dim = 4; + constexpr int q_len = DFLASH27B_DRAFT_BLOCK_SIZE; + constexpr int inter = 12; + constexpr int fc_in = 5 * hidden; + const int total_k = tc.ctx_len + q_len; + + DraftWeights w{}; + w.n_layer = 1; + w.n_head = n_head; + w.n_head_kv = n_kv; + w.head_dim = head_dim; + w.swa_window = tc.swa_window; + w.layers.resize(1); + + w.fc = new_mat(ctx, fc_in, hidden); + w.hidden_norm = new_vec(ctx, hidden); + w.out_norm = new_vec(ctx, hidden); + + DraftLayer & layer = w.layers[0]; + layer.attn_norm = new_vec(ctx, hidden); + layer.ffn_norm = new_vec(ctx, hidden); + layer.wq = new_mat(ctx, hidden, n_head * head_dim); + layer.wk = new_mat(ctx, hidden, n_kv * head_dim); + layer.wv = new_mat(ctx, hidden, n_kv * head_dim); + layer.wo = new_mat(ctx, n_head * head_dim, hidden); + layer.q_norm = new_vec(ctx, head_dim); + layer.k_norm = new_vec(ctx, head_dim); + layer.w_gate = new_mat(ctx, hidden, inter); + layer.w_up = new_mat(ctx, hidden, inter); + layer.w_down = new_mat(ctx, inter, hidden); + layer.is_swa = tc.is_swa; + + ggml_tensor * noise_embed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden, q_len, 1); + ggml_tensor * target_hidden_cat = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, fc_in, tc.ctx_len, 1); + ggml_tensor * positions_q = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, q_len); + ggml_tensor * positions_k = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, total_k); + ggml_tensor * attn_mask = nullptr; + if (tc.provide_mask) { + attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, total_k, q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } + + ggml_set_input(noise_embed); + ggml_set_input(target_hidden_cat); + ggml_set_input(positions_q); + ggml_set_input(positions_k); + + DraftGraphInputs in{}; + in.ctx_len = tc.ctx_len; + in.noise_embed = noise_embed; + in.target_hidden_cat = target_hidden_cat; + in.positions_q = positions_q; + in.positions_k = positions_k; + in.attn_mask = attn_mask; + + DraftGraphOutputs out = build_draft_graph(ctx, w, in); + if (!out.hidden_states) { + std::fprintf(stderr, "FAIL %s: build_draft_graph failed: %s\n", + tc.label, dflash27b_last_error()); + ggml_free(ctx); + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 256, false); + ggml_build_forward_expand(gf, out.hidden_states); + + ggml_tensor * flash = nullptr; + for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) { + ggml_tensor * node = ggml_graph_node(gf, i); + if (node && node->op == GGML_OP_FLASH_ATTN_EXT) { + flash = node; + break; + } + } + + if (!flash) { + std::fprintf(stderr, "FAIL %s: no flash_attn_ext node found\n", tc.label); + ggml_free(ctx); + return false; + } + + const bool got_mask = flash->src[3] != nullptr; + if (got_mask != tc.expect_mask) { + std::fprintf(stderr, "FAIL %s: expected mask=%d got mask=%d\n", + tc.label, tc.expect_mask ? 1 : 0, got_mask ? 1 : 0); + ggml_free(ctx); + return false; + } + if (tc.expect_mask && flash->src[3] != attn_mask) { + std::fprintf(stderr, "FAIL %s: flash_attn_ext did not use caller mask tensor\n", tc.label); + ggml_free(ctx); + return false; + } + + std::printf("PASS %s\n", tc.label); + ggml_free(ctx); + return true; +} + +} // namespace + +int main() { + std::vector cases(3); + cases[0].is_swa = true; + cases[0].swa_window = 8; + cases[0].ctx_len = 12; + cases[0].provide_mask = true; + cases[0].expect_mask = true; + cases[0].label = "swa-long-context-wires-mask"; + + cases[1].is_swa = false; + cases[1].swa_window = 8; + cases[1].ctx_len = 12; + cases[1].provide_mask = true; + cases[1].expect_mask = false; + cases[1].label = "non-swa-layer-ignores-mask"; + + cases[2].is_swa = true; + cases[2].swa_window = 64; + cases[2].ctx_len = 12; + cases[2].provide_mask = true; + cases[2].expect_mask = false; + cases[2].label = "short-context-keeps-full-attn"; + + int failed = 0; + for (const GraphCase & tc : cases) { + if (!run_case(tc)) { + ++failed; + } + } + + return failed == 0 ? 0 : 1; +}