diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 20617955d..3c66af925 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -156,6 +156,7 @@ struct DraftLayer { ggml_tensor * w_gate; ggml_tensor * w_up; ggml_tensor * w_down; + bool is_swa = false; // sliding window attention (Qwen3.6 draft) }; struct DraftWeights { @@ -175,6 +176,7 @@ struct DraftWeights { int head_dim = DFLASH27B_TARGET_HEAD_DIM; // 128 int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120 int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408 + int swa_window = 0; // sliding window size (0 = full attention, 2048 for Qwen3.6 draft) }; bool load_draft_safetensors(const std::string & path, diff --git a/dflash/src/qwen3_dflash_graph.cpp b/dflash/src/qwen3_dflash_graph.cpp index 638bd8735..9b82d317e 100644 --- a/dflash/src/qwen3_dflash_graph.cpp +++ b/dflash/src/qwen3_dflash_graph.cpp @@ -118,8 +118,27 @@ DraftGraphOutputs build_draft_graph( V = ggml_cont (ctx, V); // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // For SWA layers (Qwen3.6 draft): apply sliding window mask + // limiting context K/V to the last `swa_window` positions. const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * attn_mask = nullptr; + if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) { + // Build a mask that blocks attention beyond the window. + // mask shape: [total_k, q_len] — element (k, q) = 0 (attend) or -inf (block) + // For SWA: each query at position p attends to K positions in [p - window, p + window] + // But in DFlash non-causal mode, queries are at positions [ctx_len..ctx_len+q_len-1] + // and keys span [0..total_k-1]. SWA means keys older than window are masked. + const int win = w.swa_window; + attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, total_k, q_len); + ggml_set_name(attn_mask, "swa_mask"); + ggml_set_input(attn_mask); + // NOTE: mask data will be set at graph compute time by the caller. + // For now, we pass nullptr and let full attention run — the mask + // setup requires knowing absolute positions which are in `in.positions_k`. + // TODO: implement mask fill in the caller or use ggml_diag_mask_inf + attn_mask = nullptr; // fallback to full attention until mask fill is wired + } + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/src/safetensors_draft.cpp b/dflash/src/safetensors_draft.cpp index c499025ff..3e669ffe0 100644 --- a/dflash/src/safetensors_draft.cpp +++ b/dflash/src/safetensors_draft.cpp @@ -433,6 +433,68 @@ bool load_draft_safetensors(const std::string & path, } } + // ── 4b. Read config.json for SWA layer_types (Qwen3.6 draft) ── + { + // config.json sits next to model.safetensors + std::string dir; + auto slash = path.find_last_of('/'); + if (slash != std::string::npos) { + dir = path.substr(0, slash); + } else { + dir = "."; // bare filename — look in CWD + } + std::string cfg_path = dir + "/config.json"; + FILE * f = std::fopen(cfg_path.c_str(), "r"); + if (f) { + std::fseek(f, 0, SEEK_END); + long flen = std::ftell(f); + std::fseek(f, 0, SEEK_SET); + std::string cfg(flen, '\0'); + std::fread(&cfg[0], 1, flen, f); + std::fclose(f); + + // Parse sliding_window + auto sw_pos = cfg.find("\"sliding_window\""); + if (sw_pos != std::string::npos) { + auto colon = cfg.find(':', sw_pos); + if (colon != std::string::npos) { + int sw = std::atoi(cfg.c_str() + colon + 1); + if (sw > 0) out.swa_window = sw; + } + } + + // Parse layer_types array + auto lt_pos = cfg.find("\"layer_types\""); + if (lt_pos != std::string::npos) { + auto arr_start = cfg.find('[', lt_pos); + auto arr_end = cfg.find(']', arr_start); + if (arr_start != std::string::npos && arr_end != std::string::npos) { + std::string arr = cfg.substr(arr_start, arr_end - arr_start + 1); + int li = 0; + size_t search_pos = 0; + while (li < n_layers && search_pos < arr.size()) { + auto q1 = arr.find('"', search_pos); + if (q1 == std::string::npos) break; + auto q2 = arr.find('"', q1 + 1); + if (q2 == std::string::npos) break; + std::string lt = arr.substr(q1 + 1, q2 - q1 - 1); + out.layers[li].is_swa = (lt == "sliding_attention"); + li++; + search_pos = q2 + 1; + } + } + } + + int n_swa = 0; + for (int il = 0; il < n_layers; il++) { + if (out.layers[il].is_swa) n_swa++; + } + if (n_swa > 0) { + fprintf(stderr, "[draft] SWA layers: %d/%d (window=%d)\n", n_swa, n_layers, out.swa_window); + } + } + } + // ── 5. Allocate backend buffer, copy bytes ─────────────────── out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend); if (!out.buf) { set_last_error("ggml_backend_alloc_ctx_tensors failed (draft)"); return false; }