diff --git a/README.md b/README.md index 4b7c69ec..e9e8c9ba 100644 --- a/README.md +++ b/README.md @@ -680,6 +680,34 @@ support the CPU backend for reference/debug use and share the same KV session and snapshot format as Metal and CUDA, but normal inference should use Metal or CUDA. +### Metal TurboQuant KV cache + +The Metal graph can store compressed attention KV rows with the TurboQuant +PolarQuant/WHT formats instead of the default FP8 rows: + +```sh +DS4_KV_TURBO=4 ./ds4-server --ctx 384000 ... +DS4_KV_TURBO=3 ./ds4-server --ctx 384000 ... +``` + +`DS4_KV_TURBO=4` is the preferred quality/speed mode. `DS4_KV_TURBO=3` +uses a smaller 3-bit cache and is mainly useful when memory pressure is the +first constraint. Unset `DS4_KV_TURBO`, or set it to `0`, to keep the FP8 +path. + +For ratio-4 indexed decode, the default compressed-row top-k remains 512. +Fast non-quality runs can lower it with `DS4_METAL_DECODE_INDEXER_TOP_K`; +values are capped at 512 and rounded down to a power of two: + +```sh +DS4_KV_TURBO=4 DS4_METAL_DECODE_INDEXER_TOP_K=128 ./ds4-server --ctx 384000 ... +``` + +`--quality` keeps the 512-row path. The diagnostic switches +`DS4_METAL_DISABLE_TURBO_DIRECT_ATTN=1` and +`DS4_METAL_DISABLE_TURBO_SELECTED_F16=1` restore the older materialized +attention paths for comparisons. + ## Steering This project supports steering with single-vector activation directions; see the diff --git a/ds4.c b/ds4.c index 51410e33..6a081028 100644 --- a/ds4.c +++ b/ds4.c @@ -8078,6 +8078,9 @@ typedef struct { * the row counters whenever a checkpoint is saved or partially rewound. */ ds4_gpu_tensor *layer_raw_cache[DS4_N_LAYER]; ds4_gpu_tensor *layer_attn_comp_cache[DS4_N_LAYER]; + ds4_gpu_tensor *layer_attn_turbo_cache[DS4_N_LAYER]; /* turbo-compressed comp cache */ + ds4_gpu_tensor *attn_comp_scratch; /* shared f32 dequant/output scratch in turbo mode */ + ds4_gpu_tensor *attn_comp_selected_f16; /* shared f16 selected-row scratch in turbo decode */ ds4_gpu_tensor *layer_attn_state_kv[DS4_N_LAYER]; ds4_gpu_tensor *layer_attn_state_score[DS4_N_LAYER]; ds4_gpu_tensor *layer_index_comp_cache[DS4_N_LAYER]; @@ -8110,6 +8113,7 @@ typedef struct { * layer compression ratio instead of pessimistically using the ratio-4 cap * for every ratio-128 layer. */ uint32_t layer_comp_cap[DS4_N_LAYER]; + int kv_quant_type; /* DS4_KV_QUANT_FP8=0, DS4_KV_QUANT_TURBO3=1, DS4_KV_QUANT_TURBO4=2 */ /* Per-layer work tensors. They are reused in place by every layer instead * of allocating a generic graph arena. 
This is why the code is verbose but @@ -8121,6 +8125,7 @@ typedef struct { ds4_gpu_tensor *indexer_scores; ds4_gpu_tensor *comp_mask; ds4_gpu_tensor *comp_selected; + ds4_gpu_tensor *comp_selected_identity; ds4_gpu_tensor *heads; ds4_gpu_tensor *attn_low; ds4_gpu_tensor *attn_out; @@ -8300,6 +8305,7 @@ static void metal_graph_free(ds4_gpu_graph *g) { ds4_gpu_tensor_free(g->comp_kv_cur); ds4_gpu_tensor_free(g->comp_mask); ds4_gpu_tensor_free(g->comp_selected); + ds4_gpu_tensor_free(g->comp_selected_identity); ds4_gpu_tensor_free(g->indexer_scores); ds4_gpu_tensor_free(g->indexer_weights); ds4_gpu_tensor_free(g->indexer_q); @@ -8309,6 +8315,11 @@ static void metal_graph_free(ds4_gpu_graph *g) { for (uint32_t il = 0; il < DS4_N_LAYER; il++) { ds4_gpu_tensor_free(g->layer_attn_comp_cache[il]); } + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + ds4_gpu_tensor_free(g->layer_attn_turbo_cache[il]); + } + ds4_gpu_tensor_free(g->attn_comp_scratch); + ds4_gpu_tensor_free(g->attn_comp_selected_f16); for (uint32_t il = 0; il < DS4_N_LAYER; il++) { ds4_gpu_tensor_free(g->layer_attn_state_kv[il]); } @@ -8550,6 +8561,26 @@ static bool metal_graph_ensure_batch_ffn_out(ds4_gpu_graph *g) { return g->batch_ffn_out != NULL; } +static bool metal_graph_turbo_enabled(const ds4_gpu_graph *g); +static uint64_t metal_graph_turbo_row_bytes(const ds4_gpu_graph *g); + +static int metal_graph_requested_kv_quant_type(bool warn_non_metal) { + const char *turbo_env = getenv("DS4_KV_TURBO"); + int type = DS4_KV_QUANT_FP8; +#ifdef __APPLE__ + (void)warn_non_metal; + if (turbo_env) { + if (!strcmp(turbo_env, "3")) type = DS4_KV_QUANT_TURBO3; + else if (!strcmp(turbo_env, "4")) type = DS4_KV_QUANT_TURBO4; + } +#else + if (warn_non_metal && turbo_env && turbo_env[0] && strcmp(turbo_env, "0") != 0) { + fprintf(stderr, "ds4: DS4_KV_TURBO is Metal-only; using FP8 KV cache on this build\n"); + } +#endif + return type; +} + /* ========================================================================= * Metal Release Graph Allocation. * ========================================================================= */ @@ -8566,6 +8597,10 @@ static bool metal_graph_alloc_raw_cap( bool enable_mtp) { memset(g, 0, sizeof(*g)); g->mtp_enabled = enable_mtp; + + /* DS4_KV_TURBO: "3" = turbo3_0, "4" = turbo4_0, unset/"0" = FP8. */ + g->kv_quant_type = metal_graph_requested_kv_quant_type(true); + if (raw_cap == 0) raw_cap = 1; if (ctx_size == 0) ctx_size = raw_cap; if (prefill_cap == 0) prefill_cap = 1; @@ -8629,6 +8664,11 @@ static bool metal_graph_alloc_raw_cap( g->q = ds4_gpu_tensor_alloc(q_dim * sizeof(float)); g->kv_raw = ds4_gpu_tensor_alloc((uint64_t)DS4_N_HEAD_DIM * sizeof(float)); g->kv = ds4_gpu_tensor_alloc((uint64_t)DS4_N_HEAD_DIM * sizeof(float)); + if (metal_graph_turbo_enabled(g)) { + g->attn_comp_scratch = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * DS4_N_HEAD_DIM * sizeof(float)); + g->attn_comp_selected_f16 = ds4_gpu_tensor_alloc((uint64_t)DS4_N_INDEXER_TOP_K * + DS4_N_HEAD_DIM * sizeof(uint16_t)); + } bool state_init_ok = true; for (uint32_t il = 0; il < DS4_N_LAYER; il++) { g->layer_raw_cache[il] = ds4_gpu_tensor_alloc((uint64_t)raw_cap * DS4_N_HEAD_DIM * sizeof(float)); @@ -8637,7 +8677,12 @@ static bool metal_graph_alloc_raw_cap( const uint32_t coff = ratio == 4 ? 
2u : 1u; const uint64_t attn_width = (uint64_t)coff * DS4_N_HEAD_DIM; const uint64_t attn_rows = (uint64_t)coff * ratio; - g->layer_attn_comp_cache[il] = ds4_gpu_tensor_alloc((uint64_t)g->layer_comp_cap[il] * DS4_N_HEAD_DIM * sizeof(float)); + if (metal_graph_turbo_enabled(g)) { + const uint64_t row_bytes = metal_graph_turbo_row_bytes(g); + g->layer_attn_turbo_cache[il] = ds4_gpu_tensor_alloc((uint64_t)g->layer_comp_cap[il] * row_bytes); + } else { + g->layer_attn_comp_cache[il] = ds4_gpu_tensor_alloc((uint64_t)g->layer_comp_cap[il] * DS4_N_HEAD_DIM * sizeof(float)); + } g->layer_attn_state_kv[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); g->layer_attn_state_score[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); if (enable_mtp) { @@ -8686,6 +8731,8 @@ static bool metal_graph_alloc_raw_cap( g->comp_mask = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * pc * sizeof(float)); g->comp_selected = ds4_gpu_tensor_alloc((uint64_t)(DS4_N_INDEXER_TOP_K ? DS4_N_INDEXER_TOP_K : 1u) * pc * sizeof(uint32_t)); + g->comp_selected_identity = ds4_gpu_tensor_alloc((uint64_t)(DS4_N_INDEXER_TOP_K ? DS4_N_INDEXER_TOP_K : 1u) * + pc * sizeof(uint32_t)); g->heads = ds4_gpu_tensor_alloc(q_dim * sizeof(float)); g->attn_low = ds4_gpu_tensor_alloc(low_dim * sizeof(float)); g->attn_out = ds4_gpu_tensor_alloc((uint64_t)DS4_N_EMBD * sizeof(float)); @@ -8776,7 +8823,12 @@ static bool metal_graph_alloc_raw_cap( layer_cache_ok = g->layer_raw_cache[il] != NULL; const uint32_t ratio = ds4_layer_compress_ratio(il); if (layer_cache_ok && ratio != 0) { - layer_cache_ok = g->layer_attn_comp_cache[il] != NULL && + const bool have_attn_cache = metal_graph_turbo_enabled(g) + ? (g->attn_comp_scratch != NULL && + g->attn_comp_selected_f16 != NULL && + g->layer_attn_turbo_cache[il] != NULL) + : (g->layer_attn_comp_cache[il] != NULL); + layer_cache_ok = have_attn_cache && g->layer_attn_state_kv[il] != NULL && g->layer_attn_state_score[il] != NULL && (!enable_mtp || @@ -8804,7 +8856,7 @@ static bool metal_graph_alloc_raw_cap( g->q && g->kv_raw && g->kv && g->comp_kv_cur && g->comp_sc_cur && g->indexer_q && g->indexer_weights && g->indexer_scores && - g->comp_mask && g->comp_selected && + g->comp_mask && g->comp_selected && g->comp_selected_identity && g->heads && g->attn_low && g->attn_out && g->after_attn_hc && g->ffn_cur && g->ffn_norm && g->shared_gate && g->shared_up && g->shared_mid && @@ -8876,6 +8928,179 @@ static uint32_t metal_graph_raw_start_for_span( return first_raw_pos % g->raw_cap; } +static bool metal_graph_turbo_enabled(const ds4_gpu_graph *g) { +#ifdef __APPLE__ + return g && (g->kv_quant_type == DS4_KV_QUANT_TURBO3 || + g->kv_quant_type == DS4_KV_QUANT_TURBO4); +#else + (void)g; + return false; +#endif +} + +static bool metal_graph_kv_quant_type_is_turbo(int type) { +#ifdef __APPLE__ + return type == DS4_KV_QUANT_TURBO3 || type == DS4_KV_QUANT_TURBO4; +#else + (void)type; + return false; +#endif +} + +static uint32_t metal_graph_turbo_n_blocks(void) { + const uint32_t n_nope = DS4_N_HEAD_DIM - DS4_N_ROT; + return n_nope / 32u; +} + +static uint64_t metal_graph_turbo_row_bytes_for_type(int kv_quant_type) { + if (!metal_graph_kv_quant_type_is_turbo(kv_quant_type)) return 0; + const uint64_t block_bytes = kv_quant_type == DS4_KV_QUANT_TURBO3 + ? 
DS4_TURBO3_BLOCK_BYTES + : DS4_TURBO4_BLOCK_BYTES; + return (uint64_t)metal_graph_turbo_n_blocks() * block_bytes + + (uint64_t)DS4_N_ROT * sizeof(float); +} + +static uint64_t metal_graph_turbo_row_bytes(const ds4_gpu_graph *g) { + if (!metal_graph_turbo_enabled(g)) return 0; + return metal_graph_turbo_row_bytes_for_type(g->kv_quant_type); +} + +static ds4_gpu_tensor *metal_graph_attn_comp_cache(ds4_gpu_graph *g, uint32_t il) { + if (!g) return NULL; + return metal_graph_turbo_enabled(g) ? g->attn_comp_scratch : g->layer_attn_comp_cache[il]; +} + +static bool metal_graph_turbo_quantize_rows( + ds4_gpu_graph *g, + uint32_t il, + const ds4_gpu_tensor *src, + uint32_t dst_row, + uint32_t n_rows) { + if (!metal_graph_turbo_enabled(g) || n_rows == 0) return true; +#ifdef __APPLE__ + if (!src || !g->layer_attn_turbo_cache[il]) return false; + if (((DS4_N_HEAD_DIM - DS4_N_ROT) % 32u) != 0) return false; + + const uint64_t row_bytes = metal_graph_turbo_row_bytes(g); + ds4_gpu_tensor *dst = ds4_gpu_tensor_view(g->layer_attn_turbo_cache[il], + (uint64_t)dst_row * row_bytes, + (uint64_t)n_rows * row_bytes); + if (!dst) return false; + const int ok = g->kv_quant_type == DS4_KV_QUANT_TURBO3 + ? ds4_gpu_turbo3_kv_quantize_tensor(dst, src, n_rows, 1, DS4_N_HEAD_DIM, DS4_N_ROT) + : ds4_gpu_turbo4_kv_quantize_tensor(dst, src, n_rows, 1, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_gpu_tensor_free(dst); + return ok != 0; +#else + (void)il; + (void)src; + (void)dst_row; + return false; +#endif +} + +static bool metal_graph_turbo_dequant_cache(ds4_gpu_graph *g, uint32_t il, uint32_t n_rows) { + if (!metal_graph_turbo_enabled(g) || n_rows == 0) return true; +#ifdef __APPLE__ + ds4_gpu_tensor *dst = metal_graph_attn_comp_cache(g, il); + if (!dst || !g->layer_attn_turbo_cache[il]) return false; + const uint32_t n_blocks = metal_graph_turbo_n_blocks(); + const int ok = g->kv_quant_type == DS4_KV_QUANT_TURBO3 + ? ds4_gpu_turbo3_dequant_f32_tensor(dst, g->layer_attn_turbo_cache[il], + n_blocks, DS4_N_ROT, n_rows) + : ds4_gpu_turbo4_dequant_f32_tensor(dst, g->layer_attn_turbo_cache[il], + n_blocks, DS4_N_ROT, n_rows); + return ok != 0; +#else + (void)il; + return false; +#endif +} + +static bool metal_graph_turbo_direct_attention_enabled(const ds4_gpu_graph *g) { + return metal_graph_turbo_enabled(g) && + getenv("DS4_METAL_DISABLE_TURBO_DIRECT_ATTN") == NULL; +} + +static bool metal_graph_turbo_selected_f16_enabled(const ds4_gpu_graph *g) { + return metal_graph_turbo_direct_attention_enabled(g) && + getenv("DS4_METAL_DISABLE_TURBO_SELECTED_F16") == NULL; +} + +static bool metal_graph_turbo_selected_dequant( + ds4_gpu_graph *g, + uint32_t il, + const ds4_gpu_tensor *topk, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens, + bool f16_dst) { + if (!metal_graph_turbo_enabled(g) || + !g->comp_selected_identity || !g->layer_attn_turbo_cache[il] || + !topk || n_comp == 0 || top_k == 0 || n_tokens == 0 || top_k > n_comp) { + return false; + } + if (f16_dst) { + if (!g->attn_comp_selected_f16 || n_tokens != 1u) return false; + } else { + if (!g->attn_comp_scratch) return false; + } +#ifdef __APPLE__ + const uint32_t n_blocks = metal_graph_turbo_n_blocks(); + int ok = 0; + if (f16_dst) { + ok = g->kv_quant_type == DS4_KV_QUANT_TURBO3 + ? 
ds4_gpu_turbo3_dequant_selected_f16_tensor(g->attn_comp_selected_f16, + g->comp_selected_identity, + g->layer_attn_turbo_cache[il], + topk, + n_blocks, + DS4_N_ROT, + n_comp, + top_k, + n_tokens) + : ds4_gpu_turbo4_dequant_selected_f16_tensor(g->attn_comp_selected_f16, + g->comp_selected_identity, + g->layer_attn_turbo_cache[il], + topk, + n_blocks, + DS4_N_ROT, + n_comp, + top_k, + n_tokens); + } else { + ok = g->kv_quant_type == DS4_KV_QUANT_TURBO3 + ? ds4_gpu_turbo3_dequant_selected_f32_tensor(g->attn_comp_scratch, + g->comp_selected_identity, + g->layer_attn_turbo_cache[il], + topk, + n_blocks, + DS4_N_ROT, + n_comp, + top_k, + n_tokens) + : ds4_gpu_turbo4_dequant_selected_f32_tensor(g->attn_comp_scratch, + g->comp_selected_identity, + g->layer_attn_turbo_cache[il], + topk, + n_blocks, + DS4_N_ROT, + n_comp, + top_k, + n_tokens); + } + return ok != 0; +#else + (void)il; + (void)n_comp; + (void)top_k; + (void)n_tokens; + (void)f16_dst; + return false; +#endif +} + /* Capture the verifier prefix after the first speculative token. * * Exact MTP speculation is only profitable if partial accepts are cheap. The @@ -8911,8 +9136,39 @@ static bool metal_graph_capture_prefix1_index_state(ds4_gpu_graph *g, uint32_t i } static uint32_t metal_graph_decode_indexer_top_k(const ds4_gpu_graph *g) { - (void)g; - return DS4_N_INDEXER_TOP_K; + if (g && g->quality) return DS4_N_INDEXER_TOP_K; + + static int initialized; + static uint32_t cached; + if (!initialized) { + cached = DS4_N_INDEXER_TOP_K; + const char *env = getenv("DS4_METAL_DECODE_INDEXER_TOP_K"); + if (env && env[0]) { + char *end = NULL; + errno = 0; + unsigned long requested = strtoul(env, &end, 10); + if (errno == 0 && end != env && requested != 0) { + uint32_t k = requested > DS4_N_INDEXER_TOP_K + ? DS4_N_INDEXER_TOP_K + : (uint32_t)requested; + uint32_t rounded = 1u; + while ((rounded << 1u) != 0 && (rounded << 1u) <= k) { + rounded <<= 1u; + } + cached = rounded; + if ((unsigned long)cached != requested) { + fprintf(stderr, + "ds4: DS4_METAL_DECODE_INDEXER_TOP_K=%lu using %u " + "(power-of-two cap %u)\n", + requested, + cached, + (uint32_t)DS4_N_INDEXER_TOP_K); + } + } + } + initialized = 1; + } + return cached; } /* ========================================================================= @@ -9281,11 +9537,12 @@ static bool metal_graph_encode_decode_layer( g->attn_norm, 1) != 0; } const uint32_t comp_row = g->layer_n_comp[il]; + ds4_gpu_tensor *attn_comp_cache = metal_graph_attn_comp_cache(g, il); if (ok) ok = ds4_gpu_compressor_update_tensor(g->comp_kv_cur, g->comp_sc_cur, g->layer_attn_state_kv[il], g->layer_attn_state_score[il], - g->layer_attn_comp_cache[il], + attn_comp_cache, model->map, model->size, layer->attn_compressor_ape->abs_offset, @@ -9307,13 +9564,15 @@ static bool metal_graph_encode_decode_layer( DS4_RMS_EPS) != 0; if (ok && emit) { ds4_gpu_tensor *comp_row_view = ds4_gpu_tensor_view( - g->layer_attn_comp_cache[il], + attn_comp_cache, (uint64_t)comp_row * DS4_N_HEAD_DIM * sizeof(float), (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); if (!comp_row_view) { ok = false; } else { - ok = ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; + ok = metal_graph_turbo_enabled(g) + ? 
metal_graph_turbo_quantize_rows(g, il, comp_row_view, comp_row, 1) + : ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; if (ok) { metal_graph_debug_dump_tensor("KVcompress", comp_row_view, DS4_N_HEAD_DIM, il, pos); } @@ -9480,33 +9739,93 @@ static bool metal_graph_encode_decode_layer( } n_comp = g->layer_n_comp[il]; - comp_cache = g->layer_attn_comp_cache[il]; + comp_cache = attn_comp_cache; + const bool turbo_selected_indexed = + metal_graph_turbo_direct_attention_enabled(g) && + !g->quality && + comp_selected != NULL && + n_selected != 0; + if (ok && !turbo_selected_indexed) ok = metal_graph_turbo_dequant_cache(g, il, n_comp); } DS4_METAL_PROFILE_DECODE_STAGE("compressor_indexer"); if (ok) { const uint32_t raw_start = metal_graph_raw_start_for_span(g, pos, n_raw); if (n_comp != 0 && comp_selected != NULL && n_selected != 0) { - ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor( - g->heads, - model->map, - model->size, - layer->attn_sinks->abs_offset, - g->q, - raw_cache, - comp_cache, - comp_selected, - 1, - pos, - n_raw, - raw_cap, - raw_start, - n_comp, - n_selected, - g->raw_window, - ds4_layer_compress_ratio(il), - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + if (metal_graph_turbo_direct_attention_enabled(g) && !g->quality) { + const bool use_selected_f16 = metal_graph_turbo_selected_f16_enabled(g); + ok = metal_graph_turbo_selected_dequant(g, + il, + comp_selected, + n_comp, + n_selected, + 1, + use_selected_f16); + if (ok && use_selected_f16) { + ok = ds4_gpu_attention_indexed_mixed_comp_f16_batch_heads_tensor( + g->heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->q, + raw_cache, + g->attn_comp_selected_f16, + g->comp_selected_identity, + 1, + pos, + n_raw, + raw_cap, + raw_start, + n_selected, + n_selected, + g->raw_window, + ds4_layer_compress_ratio(il), + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } else if (ok) { + ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + g->heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->q, + raw_cache, + g->attn_comp_scratch, + g->comp_selected_identity, + 1, + pos, + n_raw, + raw_cap, + raw_start, + n_selected, + n_selected, + g->raw_window, + ds4_layer_compress_ratio(il), + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } + } else { + ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + g->heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->q, + raw_cache, + comp_cache, + comp_selected, + 1, + pos, + n_raw, + raw_cap, + raw_start, + n_comp, + n_selected, + g->raw_window, + ds4_layer_compress_ratio(il), + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } if (ok && decode_index_stage_profile) { ok = metal_graph_indexer_stage_profile_boundary("decode_attention", il, @@ -11015,6 +11334,7 @@ static bool metal_graph_encode_layer_attention_batch( uint32_t *comp_counts = compressed ? xcalloc(n_tokens, sizeof(comp_counts[0])) : NULL; uint32_t *index_counts = ratio == 4 ? xcalloc(n_tokens, sizeof(index_counts[0])) : NULL; const bool qkv_rms_fused = !metal_graph_use_reference_qkv_norm(); + ds4_gpu_tensor *attn_comp_cache = compressed ? 
metal_graph_attn_comp_cache(g, il) : NULL; ds4_gpu_tensor *hc_mix_view = ds4_gpu_tensor_view( g->batch_hc_mix, 0, (uint64_t)n_tokens * mix_hc * sizeof(float)); ds4_gpu_tensor *hc_split_view = ds4_gpu_tensor_view( @@ -11023,7 +11343,8 @@ static bool metal_graph_encode_layer_attention_batch( g->batch_attn_cur, 0, (uint64_t)n_tokens * DS4_N_EMBD * sizeof(float)); ds4_gpu_tensor *after_attn_hc_view = ds4_gpu_tensor_view( g->batch_after_attn_hc, 0, (uint64_t)n_tokens * hc_dim * sizeof(float)); - bool ok = hc_mix_view && hc_split_view && attn_cur_view && after_attn_hc_view; + bool ok = hc_mix_view && hc_split_view && attn_cur_view && after_attn_hc_view && + (!compressed || attn_comp_cache != NULL); if (ok) ok = ds4_gpu_rms_norm_plain_rows_tensor(g->batch_flat_hc, g->batch_cur_hc, (uint32_t)hc_dim, @@ -11354,7 +11675,7 @@ static bool metal_graph_encode_layer_attention_batch( ok = false; } if (ok) { - ok = ds4_gpu_compressor_prefill_tensor(g->layer_attn_comp_cache[il], + ok = ds4_gpu_compressor_prefill_tensor(attn_comp_cache, g->layer_attn_state_kv[il], g->layer_attn_state_score[il], g->batch_comp_kv, @@ -11371,7 +11692,7 @@ static bool metal_graph_encode_layer_attention_batch( n_tokens, DS4_N_ROT, compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, - true, + (g->kv_quant_type == DS4_KV_QUANT_FP8), /* skip FP8 when turbo */ freq_base, freq_scale, ext_factor, @@ -11379,6 +11700,8 @@ static bool metal_graph_encode_layer_attention_batch( DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW, DS4_RMS_EPS) != 0; + if (ok) ok = metal_graph_turbo_quantize_rows(g, il, attn_comp_cache, 0, n_comp); + if (ok) ok = metal_graph_turbo_dequant_cache(g, il, n_comp); if (ok && ratio == 4) { ok = metal_graph_refresh_ratio4_compressor_state(g, model, @@ -11398,9 +11721,9 @@ static bool metal_graph_encode_layer_attention_batch( for (uint32_t t = 0; t < n_tokens; t++) { comp_counts[t] = (pos0 + t + 1u) / ratio; } - if (n_comp != 0) { - metal_graph_debug_dump_tensor("KVcompress", - g->layer_attn_comp_cache[il], + if (n_comp != 0) { + metal_graph_debug_dump_tensor("KVcompress", + attn_comp_cache, (uint64_t)n_comp * DS4_N_HEAD_DIM, il, pos0); @@ -11427,7 +11750,7 @@ static bool metal_graph_encode_layer_attention_batch( } ds4_gpu_tensor *comp_view = NULL; if (ok) { - comp_view = ds4_gpu_tensor_view(g->layer_attn_comp_cache[il], + comp_view = ds4_gpu_tensor_view(attn_comp_cache, (uint64_t)comp_before * DS4_N_HEAD_DIM * sizeof(float), (uint64_t)comp_chunk * DS4_N_HEAD_DIM * sizeof(float)); ok = comp_view != NULL; @@ -11447,11 +11770,11 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_compressor_norm->type, DS4_N_HEAD_DIM, pos0, - n_tokens, - DS4_N_ROT, - compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, - true, - freq_base, + n_tokens, + DS4_N_ROT, + compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, + !metal_graph_turbo_enabled(g), + freq_base, freq_scale, ext_factor, attn_factor, @@ -11474,11 +11797,11 @@ static bool metal_graph_encode_layer_attention_batch( DS4_N_HEAD_DIM, ratio, pos0, - n_tokens, - DS4_N_ROT, - compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, - true, - freq_base, + n_tokens, + DS4_N_ROT, + compressed ? 
(uint32_t)DS4_ROPE_ORIG_CTX : 0, + !metal_graph_turbo_enabled(g), + freq_base, freq_scale, ext_factor, attn_factor, @@ -11486,8 +11809,8 @@ static bool metal_graph_encode_layer_attention_batch( DS4_ROPE_YARN_BETA_SLOW, DS4_RMS_EPS) != 0; } - if (ok && ratio == 4) { - ok = metal_graph_refresh_ratio4_compressor_state(g, + if (ok && ratio == 4) { + ok = metal_graph_refresh_ratio4_compressor_state(g, model, g->layer_attn_state_kv[il], g->layer_attn_state_score[il], @@ -11498,8 +11821,10 @@ static bool metal_graph_encode_layer_attention_batch( comp_width, pos0, n_tokens); - } - if (ok) { + } + if (ok) ok = metal_graph_turbo_quantize_rows(g, il, comp_view, comp_before, comp_chunk); + if (ok) ok = metal_graph_turbo_dequant_cache(g, il, comp_before + comp_chunk); + if (ok) { g->layer_n_comp[il] = comp_before + comp_chunk; if (comp_counts) { for (uint32_t t = 0; t < n_tokens; t++) { @@ -11538,9 +11863,9 @@ static bool metal_graph_encode_layer_attention_batch( ok = kv_view && sc_view && ds4_gpu_compressor_update_tensor(kv_view, sc_view, - g->layer_attn_state_kv[il], - g->layer_attn_state_score[il], - g->layer_attn_comp_cache[il], + g->layer_attn_state_kv[il], + g->layer_attn_state_score[il], + attn_comp_cache, model->map, model->size, layer->attn_compressor_ape->abs_offset, @@ -11560,16 +11885,18 @@ static bool metal_graph_encode_layer_attention_batch( DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW, DS4_RMS_EPS) != 0; - if (ok && emit) { - ds4_gpu_tensor *comp_row_view = ds4_gpu_tensor_view( - g->layer_attn_comp_cache[il], - (uint64_t)comp_row * DS4_N_HEAD_DIM * sizeof(float), - (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); - ok = comp_row_view && - ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, - 1, - DS4_N_HEAD_DIM, - DS4_N_ROT) != 0; + if (ok && emit) { + ds4_gpu_tensor *comp_row_view = ds4_gpu_tensor_view( + attn_comp_cache, + (uint64_t)comp_row * DS4_N_HEAD_DIM * sizeof(float), + (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); + ok = comp_row_view && + (metal_graph_turbo_enabled(g) + ? metal_graph_turbo_quantize_rows(g, il, comp_row_view, comp_row, 1) + : ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, + 1, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0); if (ok) { metal_graph_debug_dump_tensor("KVcompress", comp_row_view, @@ -11584,10 +11911,11 @@ static bool metal_graph_encode_layer_attention_batch( if (ok && t == 0) ok = metal_graph_capture_prefix1_attn_state(g, il); ds4_gpu_tensor_free(sc_view); ds4_gpu_tensor_free(kv_view); - } - } - n_comp = g->layer_n_comp[il]; - } + } + } + n_comp = g->layer_n_comp[il]; + if (ok) ok = metal_graph_turbo_dequant_cache(g, il, n_comp); + } DS4_METAL_PROFILE_ATTN_STAGE("compressor"); if (ok && ratio == 4) { @@ -11936,10 +12264,10 @@ static bool metal_graph_encode_layer_attention_batch( ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, model->map, model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], + attn_comp_cache, g->comp_selected, n_tokens, pos0, @@ -11964,10 +12292,10 @@ static bool metal_graph_encode_layer_attention_batch( ok = ds4_gpu_attention_decode_mixed_batch_heads_tensor(g->batch_heads, model->map, model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], + attn_comp_cache, use_comp_mask ? 
g->comp_mask : NULL, use_comp_mask, n_tokens, @@ -12048,10 +12376,10 @@ static bool metal_graph_encode_layer_attention_batch( ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, model->map, model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], + attn_comp_cache, g->comp_selected, n_tokens, pos0, @@ -12079,10 +12407,10 @@ static bool metal_graph_encode_layer_attention_batch( ok = ds4_gpu_attention_prefill_static_mixed_heads_tensor(g->batch_heads, model->map, model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->batch_kv, - g->layer_attn_comp_cache[il], + layer->attn_sinks->abs_offset, + g->batch_q, + g->batch_kv, + attn_comp_cache, n_tokens, n_comp, g->raw_window, @@ -12196,10 +12524,10 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, q_view, g->layer_raw_cache[il], - n_raw, - g->raw_cap, - raw_start, - cur_comp ? g->layer_attn_comp_cache[il] : NULL, + n_raw, + g->raw_cap, + raw_start, + cur_comp ? attn_comp_cache : NULL, cur_comp, comp_mask, n_selected, @@ -13810,6 +14138,8 @@ ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size m.comp_cap = ctx / min_ratio + 2u; if (m.comp_cap < 2u) m.comp_cap = 2u; + const int kv_quant_type = metal_graph_requested_kv_quant_type(false); + const uint64_t turbo_row_bytes = metal_graph_turbo_row_bytes_for_type(kv_quant_type); m.raw_bytes = (uint64_t)DS4_N_LAYER * m.raw_cap * DS4_N_HEAD_DIM * @@ -13818,9 +14148,13 @@ ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio == 0) continue; const uint32_t layer_comp_cap = ctx / ratio + 2u; - m.compressed_bytes += (uint64_t)layer_comp_cap * - DS4_N_HEAD_DIM * - sizeof(float); + if (turbo_row_bytes != 0) { + m.compressed_bytes += (uint64_t)layer_comp_cap * turbo_row_bytes; + } else { + m.compressed_bytes += (uint64_t)layer_comp_cap * + DS4_N_HEAD_DIM * + sizeof(float); + } if (ratio == 4) { m.compressed_bytes += (uint64_t)layer_comp_cap * DS4_N_INDEXER_HEAD_DIM * @@ -13831,6 +14165,9 @@ ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size m.comp_cap * m.prefill_cap * sizeof(float); + if (turbo_row_bytes != 0) { + m.scratch_bytes += (uint64_t)m.comp_cap * DS4_N_HEAD_DIM * sizeof(float); + } } else { m.raw_cap = ds4_default_raw_cap(ctx); m.raw_bytes = (uint64_t)DS4_N_LAYER * @@ -13959,11 +14296,13 @@ static int metal_graph_prompt_logits_test( free(gpu_raw_phys); } - const uint32_t n_comp = cpu_cache.layer[il].n_comp; - if (n_comp == 0) continue; - const uint64_t n = (uint64_t)n_comp * DS4_N_HEAD_DIM; - float *gpu_comp = xmalloc((size_t)n * sizeof(float)); - if (ds4_gpu_tensor_read(g.layer_attn_comp_cache[il], 0, gpu_comp, n * sizeof(float)) != 0) { + const uint32_t n_comp = cpu_cache.layer[il].n_comp; + if (n_comp == 0) continue; + const uint64_t n = (uint64_t)n_comp * DS4_N_HEAD_DIM; + float *gpu_comp = xmalloc((size_t)n * sizeof(float)); + ds4_gpu_tensor *attn_comp_cache = metal_graph_attn_comp_cache(&g, il); + if (metal_graph_turbo_dequant_cache(&g, il, n_comp) && + ds4_gpu_tensor_read(attn_comp_cache, 0, gpu_comp, n * sizeof(float)) != 0) { fprintf(stderr, "ds4: comp trace layer %u n=%u attn_max=%g attn_rms=%g\n", il, n_comp, @@ -16049,14 +16388,22 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) errlen); } const uint32_t ratio = 
ds4_layer_compress_ratio(il); - if (rc != 0 || ratio == 0) continue; - /* Compressed rows are append-only from row zero, so the live prefix is - * contiguous. The two compressor state tensors hold the partial window - * that will become the next compressed row. */ - rc = payload_write_tensor_span(fp, - g->layer_attn_comp_cache[il], - 0, - (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + if (rc != 0 || ratio == 0) continue; + /* Compressed rows are append-only from row zero, so the live prefix is + * contiguous. The two compressor state tensors hold the partial window + * that will become the next compressed row. */ + ds4_gpu_tensor *attn_payload_cache = metal_graph_attn_comp_cache(g, il); + if (metal_graph_turbo_enabled(g) && g->layer_n_comp[il] != 0 && + !metal_graph_turbo_dequant_cache(g, il, g->layer_n_comp[il])) + { + payload_set_err(err, errlen, "failed to dequantize Metal turbo cache for session payload"); + rc = 1; + continue; + } + rc = payload_write_tensor_span(fp, + attn_payload_cache, + 0, + (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), buf, DS4_SESSION_IO_CHUNK, err, @@ -16358,18 +16705,25 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c err, errlen); } - const uint32_t ratio = ds4_layer_compress_ratio(il); - if (rc != 0 || ratio == 0) continue; - rc = payload_read_tensor_span(fp, - g->layer_attn_comp_cache[il], - 0, - (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), - buf, - DS4_SESSION_IO_CHUNK, - &remaining, - err, - errlen); - if (rc == 0) rc = payload_read_tensor_span(fp, + const uint32_t ratio = ds4_layer_compress_ratio(il); + if (rc != 0 || ratio == 0) continue; + ds4_gpu_tensor *attn_payload_cache = metal_graph_attn_comp_cache(g, il); + rc = payload_read_tensor_span(fp, + attn_payload_cache, + 0, + (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + buf, + DS4_SESSION_IO_CHUNK, + &remaining, + err, + errlen); + if (rc == 0 && metal_graph_turbo_enabled(g) && n_comp[il] != 0 && + !metal_graph_turbo_quantize_rows(g, il, attn_payload_cache, 0, n_comp[il])) + { + payload_set_err(err, errlen, "failed to rebuild Metal turbo cache from session payload"); + rc = 1; + } + if (rc == 0) rc = payload_read_tensor_span(fp, g->layer_attn_state_kv[il], 0, layer_attn_state_bytes(ratio), diff --git a/ds4_gpu.h b/ds4_gpu.h index 2d16c9c9..15f0bb77 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -248,6 +248,104 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( uint32_t head_dim, uint32_t n_rot); +/* ========================================================================= + * TurboQuant KV Cache Compression. + * ========================================================================= + * + * TurboQuant (arXiv 2504.19874) compresses KV cache rows with PolarQuant + + * Walsh-Hadamard rotation. 3-bit (turbo3) and 4-bit (turbo4) formats are + * supported. The Metal integration stores compressed rows in per-layer turbo + * caches and dequantizes them back to float32 scratch before passing them to + * the existing attention kernels. + */ + +/* TurboQuant types for kv_quant_type configuration. */ +#define DS4_KV_QUANT_FP8 0 +#define DS4_KV_QUANT_TURBO3 1 +#define DS4_KV_QUANT_TURBO4 2 + +/* Block sizes (bytes) for buffer allocation. */ +#define DS4_TURBO3_BLOCK_BYTES 14 /* sizeof(block_turbo3_0): qs[8] + signs[4] + norm(2) */ +#define DS4_TURBO4_BLOCK_BYTES 18 /* sizeof(block_turbo4_0): qs[16] + norm(2) */ + +/* Quantize a float32 KV tensor to TurboQuant blocks. The input tensor is laid + * out as [n_tok, n_head, head_dim]. 
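+ * (Illustrative sizing only, assuming n_rot = 64 for the head_dim = 512 rows used
+ * here: n_blocks = (512 - 64) / 32 = 14, so one turbo4 row takes
+ * 14 * DS4_TURBO4_BLOCK_BYTES + 64 * sizeof(float) = 14 * 18 + 256 = 508 bytes and one
+ * turbo3 row takes 14 * 14 + 256 = 452 bytes, versus 2048 bytes for the float32 row it
+ * replaces.)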
The output receives packed blocks for the + * non-RoPE prefix, followed by the RoPE tail as raw float32. */ +int ds4_gpu_turbo3_kv_quantize_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + +int ds4_gpu_turbo4_kv_quantize_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + +/* Dequantize TurboQuant KV blocks to float32 for flash attention. */ +int ds4_gpu_turbo3_dequant_f32_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_rows); + +int ds4_gpu_turbo4_dequant_f32_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_rows); + +int ds4_gpu_turbo3_dequant_selected_f32_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens); + +int ds4_gpu_turbo4_dequant_selected_f32_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens); + +int ds4_gpu_turbo3_dequant_selected_f16_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens); + +int ds4_gpu_turbo4_dequant_selected_f16_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens); + int ds4_gpu_rope_tail_tensor( ds4_gpu_tensor *x, uint32_t n_tok, @@ -494,6 +592,51 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( uint32_t n_head, uint32_t head_dim); +int ds4_gpu_attention_indexed_mixed_comp_f16_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim); + +int ds4_gpu_attention_indexed_mixed_turbo_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_turbo_kv, + const ds4_gpu_tensor *topk, + uint32_t kv_quant_type, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim); + int ds4_gpu_attention_prefill_static_mixed_heads_tensor( ds4_gpu_tensor *heads, const void *model_map, diff --git a/ds4_metal.m b/ds4_metal.m index 70349ca4..dc8529c2 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -83,6 +83,10 @@ static id g_dsv4_fp8_kv_quantize_pipeline; static id g_dsv4_kv_fp8_store_pipeline; static id g_dsv4_ratio4_shift_pipeline; +static id g_turbo3_quantize_pipeline; +static id g_turbo4_quantize_pipeline; +static 
id g_turbo3_dequant_f32_pipeline; +static id g_turbo4_dequant_f32_pipeline; static id g_dsv4_softmax_pool_pipeline; static id g_soft_max_f32_pipeline; static id g_soft_max_f32_4_pipeline; @@ -1231,6 +1235,7 @@ void ds4_gpu_set_quality(bool quality) { @[@"DS4_METAL_NORM_SOURCE", @"metal/norm.metal"], @[@"DS4_METAL_BIN_SOURCE", @"metal/bin.metal"], @[@"DS4_METAL_SET_ROWS_SOURCE", @"metal/set_rows.metal"], + @[@"DS4_METAL_TURBO_QUANT_SOURCE", @"metal/turbo_quant.metal"], ]; NSMutableString *source = [NSMutableString stringWithString:base]; @@ -2487,6 +2492,41 @@ static int ds4_gpu_encode_rope_tail_inplace( uint32_t ape_type; } ds4_gpu_dsv4_compressor_store_one_args; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + int32_t n_rot; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ds4_metal_args_turbo_quantize; + +typedef struct { + int32_t n_blocks; + int32_t n_rot; + int32_t n_rows; + uint64_t src_stride; + uint64_t dst_stride; +} ds4_metal_args_turbo_dequant; + +typedef struct { + int32_t n_blocks; + int32_t n_rot; + int32_t n_comp; + int32_t top_k; + int32_t n_tokens; + uint64_t src_stride; + uint64_t dst_stride; + uint64_t topk_token_stride; +} ds4_metal_args_turbo_select_dequant; + typedef struct { int64_t ne00; int64_t ne01; @@ -2634,6 +2674,8 @@ static int ds4_gpu_encode_rope_tail_inplace( uint64_t dst_token_stride; uint64_t dst_head_stride; float scale; + uint32_t turbo_n_blocks; + uint32_t turbo_n_rot; } ds4_gpu_dsv4_indexed_attention_args; typedef struct { @@ -2917,6 +2959,70 @@ int ds4_gpu_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_turbo3_quantize_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_turbo3_quantize_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_turbo3_quantize_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_turbo3_quantize_pipeline) { + fprintf(stderr, "ds4: Metal kernel_turbo3_quantize_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_turbo4_quantize_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_turbo4_quantize_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_turbo4_quantize_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_turbo4_quantize_pipeline) { + fprintf(stderr, "ds4: Metal kernel_turbo4_quantize_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_turbo3_dequant_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_turbo3_dequant_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_turbo3_dequant_f32_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_turbo3_dequant_f32_pipeline) { + fprintf(stderr, "ds4: Metal kernel_turbo3_dequant_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_turbo4_dequant_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_turbo4_dequant_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_turbo4_dequant_f32_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + 
if (!g_turbo4_dequant_f32_pipeline) { + fprintf(stderr, "ds4: Metal kernel_turbo4_dequant_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_swiglu_f32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_swiglu_f32 function not found\n"); @@ -4014,6 +4120,10 @@ void ds4_gpu_cleanup(void) { g_dsv4_fp8_kv_quantize_pipeline = nil; g_dsv4_kv_fp8_store_pipeline = nil; g_dsv4_ratio4_shift_pipeline = nil; + g_turbo3_quantize_pipeline = nil; + g_turbo4_quantize_pipeline = nil; + g_turbo3_dequant_f32_pipeline = nil; + g_turbo4_dequant_f32_pipeline = nil; g_dsv4_softmax_pool_pipeline = nil; g_soft_max_f32_pipeline = nil; g_soft_max_f32_4_pipeline = nil; @@ -5817,6 +5927,412 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( return 1; } +/* ========================================================================= + * TurboQuant dispatch functions. + * ========================================================================= */ + +int ds4_gpu_turbo3_kv_quantize_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!x || !out || n_tok == 0 || n_head == 0 || head_dim == 0 || n_rot >= head_dim) return 0; + + const int32_t n_nope = (int32_t)head_dim - (int32_t)n_rot; + const int32_t n_blocks = n_nope / 32; + if ((n_nope % 32) != 0) return 0; + if (n_blocks == 0) return 1; + + @autoreleasepool { + id srcbuf = ds4_gpu_tensor_buffer(x); + id dstbuf = ds4_gpu_tensor_buffer(out); + const uint64_t src_bytes = (uint64_t)n_tok * n_head * head_dim * sizeof(float); + const uint64_t dst_bytes = (uint64_t)n_tok * n_head * + ((uint64_t)n_blocks * DS4_TURBO3_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)); + if (!srcbuf || ds4_gpu_tensor_bytes(x) < src_bytes) { + fprintf(stderr, "ds4: Metal turbo3 quantize undersized source buffer\n"); + return 0; + } + if (!dstbuf || ds4_gpu_tensor_bytes(out) < dst_bytes) { + fprintf(stderr, "ds4: Metal turbo3 quantize undersized dest buffer (%llu < %llu)\n", + dstbuf ? 
ds4_gpu_tensor_bytes(out) : 0ULL, dst_bytes); + return 0; + } + + ds4_metal_args_turbo_quantize args = { + .ne00 = (int32_t)head_dim, + .ne01 = (int32_t)n_tok, + .ne02 = (int32_t)n_head, + .ne03 = 1, + .n_rot = (int32_t)n_rot, + .nb00 = sizeof(float), + .nb01 = (uint64_t)n_head * head_dim * sizeof(float), + .nb02 = (uint64_t)head_dim * sizeof(float), + .nb03 = (uint64_t)n_tok * n_head * head_dim * sizeof(float), + .nb0 = DS4_TURBO3_BLOCK_BYTES, + .nb1 = (uint64_t)n_blocks * DS4_TURBO3_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float), + .nb2 = (uint64_t)n_tok * ((uint64_t)n_blocks * DS4_TURBO3_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)), + .nb3 = (uint64_t)n_tok * n_head * ((uint64_t)n_blocks * DS4_TURBO3_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)), + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_turbo3_quantize_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:srcbuf offset:ds4_gpu_tensor_offset(x) atIndex:1]; + [enc setBuffer:dstbuf offset:ds4_gpu_tensor_offset(out) atIndex:2]; + [enc setThreadgroupMemoryLength:32u * sizeof(float) atIndex:0]; + + const NSUInteger total_blocks = (NSUInteger)n_tok * n_head * n_blocks; + [enc dispatchThreadgroups:MTLSizeMake(total_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Turbo3 KV quantize")) return 0; + } + return 1; +} + +int ds4_gpu_turbo4_kv_quantize_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!x || !out || n_tok == 0 || n_head == 0 || head_dim == 0 || n_rot >= head_dim) return 0; + + const int32_t n_nope = (int32_t)head_dim - (int32_t)n_rot; + const int32_t n_blocks = n_nope / 32; + if ((n_nope % 32) != 0) return 0; + if (n_blocks == 0) return 1; + + @autoreleasepool { + id srcbuf = ds4_gpu_tensor_buffer(x); + id dstbuf = ds4_gpu_tensor_buffer(out); + const uint64_t src_bytes = (uint64_t)n_tok * n_head * head_dim * sizeof(float); + const uint64_t dst_bytes = (uint64_t)n_tok * n_head * + ((uint64_t)n_blocks * DS4_TURBO4_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)); + if (!srcbuf || ds4_gpu_tensor_bytes(x) < src_bytes) { + fprintf(stderr, "ds4: Metal turbo4 quantize undersized source buffer\n"); + return 0; + } + if (!dstbuf || ds4_gpu_tensor_bytes(out) < dst_bytes) { + fprintf(stderr, "ds4: Metal turbo4 quantize undersized dest buffer\n"); + return 0; + } + + ds4_metal_args_turbo_quantize args = { + .ne00 = (int32_t)head_dim, + .ne01 = (int32_t)n_tok, + .ne02 = (int32_t)n_head, + .ne03 = 1, + .n_rot = (int32_t)n_rot, + .nb00 = sizeof(float), + .nb01 = (uint64_t)n_head * head_dim * sizeof(float), + .nb02 = (uint64_t)head_dim * sizeof(float), + .nb03 = (uint64_t)n_tok * n_head * head_dim * sizeof(float), + .nb0 = DS4_TURBO4_BLOCK_BYTES, + .nb1 = (uint64_t)n_blocks * DS4_TURBO4_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float), + .nb2 = (uint64_t)n_tok * ((uint64_t)n_blocks * DS4_TURBO4_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)), + .nb3 = (uint64_t)n_tok * n_head * ((uint64_t)n_blocks * DS4_TURBO4_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float)), + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_turbo4_quantize_pipeline]; + [enc setBytes:&args 
length:sizeof(args) atIndex:0]; + [enc setBuffer:srcbuf offset:ds4_gpu_tensor_offset(x) atIndex:1]; + [enc setBuffer:dstbuf offset:ds4_gpu_tensor_offset(out) atIndex:2]; + [enc setThreadgroupMemoryLength:32u * sizeof(float) atIndex:0]; + + const NSUInteger total_blocks = (NSUInteger)n_tok * n_head * n_blocks; + [enc dispatchThreadgroups:MTLSizeMake(total_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Turbo4 KV quantize")) return 0; + } + return 1; +} + +int ds4_gpu_turbo3_dequant_f32_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_rows) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!x || !out || n_rows == 0 || n_blocks == 0) return 0; + + @autoreleasepool { + id srcbuf = ds4_gpu_tensor_buffer(x); + id dstbuf = ds4_gpu_tensor_buffer(out); + const uint64_t src_row_bytes = (uint64_t)n_blocks * DS4_TURBO3_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float); + const uint64_t dst_row_bytes = (uint64_t)(n_blocks * 32 + n_rot) * sizeof(float); + const uint64_t src_bytes = (uint64_t)n_rows * src_row_bytes; + const uint64_t dst_bytes = (uint64_t)n_rows * dst_row_bytes; + if (!srcbuf || ds4_gpu_tensor_bytes(x) < src_bytes) return 0; + if (!dstbuf || ds4_gpu_tensor_bytes(out) < dst_bytes) return 0; + + ds4_metal_args_turbo_dequant args = { + .n_blocks = (int32_t)n_blocks, + .n_rot = (int32_t)n_rot, + .n_rows = (int32_t)n_rows, + .src_stride = src_row_bytes, + .dst_stride = dst_row_bytes, + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_turbo3_dequant_f32_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:srcbuf offset:ds4_gpu_tensor_offset(x) atIndex:1]; + [enc setBuffer:dstbuf offset:ds4_gpu_tensor_offset(out) atIndex:2]; + [enc setThreadgroupMemoryLength:32u * sizeof(float) atIndex:0]; + + const NSUInteger total_blocks = (NSUInteger)n_rows * n_blocks; + [enc dispatchThreadgroups:MTLSizeMake(total_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Turbo3 dequant f32")) return 0; + } + return 1; +} + +int ds4_gpu_turbo4_dequant_f32_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *x, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_rows) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!x || !out || n_rows == 0 || n_blocks == 0) return 0; + + @autoreleasepool { + id srcbuf = ds4_gpu_tensor_buffer(x); + id dstbuf = ds4_gpu_tensor_buffer(out); + const uint64_t src_row_bytes = (uint64_t)n_blocks * DS4_TURBO4_BLOCK_BYTES + (uint64_t)n_rot * sizeof(float); + const uint64_t dst_row_bytes = (uint64_t)(n_blocks * 32 + n_rot) * sizeof(float); + const uint64_t src_bytes = (uint64_t)n_rows * src_row_bytes; + const uint64_t dst_bytes = (uint64_t)n_rows * dst_row_bytes; + if (!srcbuf || ds4_gpu_tensor_bytes(x) < src_bytes) return 0; + if (!dstbuf || ds4_gpu_tensor_bytes(out) < dst_bytes) return 0; + + ds4_metal_args_turbo_dequant args = { + .n_blocks = (int32_t)n_blocks, + .n_rot = (int32_t)n_rot, + .n_rows = (int32_t)n_rows, + .src_stride = src_row_bytes, + .dst_stride = dst_row_bytes, + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_turbo4_dequant_f32_pipeline]; + [enc 
setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:srcbuf offset:ds4_gpu_tensor_offset(x) atIndex:1]; + [enc setBuffer:dstbuf offset:ds4_gpu_tensor_offset(out) atIndex:2]; + [enc setThreadgroupMemoryLength:32u * sizeof(float) atIndex:0]; + + const NSUInteger total_blocks = (NSUInteger)n_rows * n_blocks; + [enc dispatchThreadgroups:MTLSizeMake(total_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "Turbo4 dequant f32")) return 0; + } + return 1; +} + +static int ds4_gpu_turbo_dequant_selected_tensor( + const char *kernel_name, + uint32_t block_bytes, + uint32_t dst_value_bytes, + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!out || !identity_topk || !x || !topk || + n_blocks == 0 || top_k == 0 || n_tokens == 0 || n_comp < top_k) { + return 0; + } + + @autoreleasepool { + id srcbuf = ds4_gpu_tensor_buffer(x); + id topkbuf = ds4_gpu_tensor_buffer(topk); + id dstbuf = ds4_gpu_tensor_buffer(out); + id idbuf = ds4_gpu_tensor_buffer(identity_topk); + const uint64_t src_row_bytes = (uint64_t)n_blocks * block_bytes + (uint64_t)n_rot * sizeof(float); + const uint64_t dst_row_bytes = (uint64_t)(n_blocks * 32u + n_rot) * dst_value_bytes; + const uint64_t src_bytes = (uint64_t)n_comp * src_row_bytes; + const uint64_t dst_bytes = (uint64_t)n_tokens * top_k * dst_row_bytes; + const uint64_t topk_bytes = (uint64_t)n_tokens * top_k * sizeof(int32_t); + if (!srcbuf || !topkbuf || !dstbuf || !idbuf || + ds4_gpu_tensor_bytes(x) < src_bytes || + ds4_gpu_tensor_bytes(topk) < topk_bytes || + ds4_gpu_tensor_bytes(out) < dst_bytes || + ds4_gpu_tensor_bytes(identity_topk) < topk_bytes) { + fprintf(stderr, "ds4: Metal turbo selected dequant received undersized buffers\n"); + return 0; + } + + id pipeline = ds4_gpu_get_pipeline(kernel_name); + if (!pipeline) return 0; + + ds4_metal_args_turbo_select_dequant args = { + .n_blocks = (int32_t)n_blocks, + .n_rot = (int32_t)n_rot, + .n_comp = (int32_t)n_comp, + .top_k = (int32_t)top_k, + .n_tokens = (int32_t)n_tokens, + .src_stride = src_row_bytes, + .dst_stride = dst_row_bytes, + .topk_token_stride = (uint64_t)top_k * sizeof(int32_t), + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:srcbuf offset:ds4_gpu_tensor_offset(x) atIndex:1]; + [enc setBuffer:topkbuf offset:ds4_gpu_tensor_offset(topk) atIndex:2]; + [enc setBuffer:dstbuf offset:ds4_gpu_tensor_offset(out) atIndex:3]; + [enc setBuffer:idbuf offset:ds4_gpu_tensor_offset(identity_topk) atIndex:4]; + [enc setThreadgroupMemoryLength:32u * sizeof(float) atIndex:0]; + + const NSUInteger total_blocks = (NSUInteger)n_tokens * top_k * n_blocks; + [enc dispatchThreadgroups:MTLSizeMake(total_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, kernel_name)) return 0; + } + return 1; +} + +int ds4_gpu_turbo3_dequant_selected_f32_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + 
uint32_t top_k, + uint32_t n_tokens) { + return ds4_gpu_turbo_dequant_selected_tensor("kernel_turbo3_dequant_selected_f32", + DS4_TURBO3_BLOCK_BYTES, + sizeof(float), + out, + identity_topk, + x, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens); +} + +int ds4_gpu_turbo4_dequant_selected_f32_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens) { + return ds4_gpu_turbo_dequant_selected_tensor("kernel_turbo4_dequant_selected_f32", + DS4_TURBO4_BLOCK_BYTES, + sizeof(float), + out, + identity_topk, + x, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens); +} + +int ds4_gpu_turbo3_dequant_selected_f16_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens) { + return ds4_gpu_turbo_dequant_selected_tensor("kernel_turbo3_dequant_selected_f16", + DS4_TURBO3_BLOCK_BYTES, + sizeof(uint16_t), + out, + identity_topk, + x, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens); +} + +int ds4_gpu_turbo4_dequant_selected_f16_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *identity_topk, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *topk, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_comp, + uint32_t top_k, + uint32_t n_tokens) { + return ds4_gpu_turbo_dequant_selected_tensor("kernel_turbo4_dequant_selected_f16", + DS4_TURBO4_BLOCK_BYTES, + sizeof(uint16_t), + out, + identity_topk, + x, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens); +} + static void ds4_gpu_set_rows_thread_shape( uint32_t width, NSUInteger *nth_out, @@ -10758,7 +11274,7 @@ int ds4_gpu_attention_decode_mixed_batch_heads_tensor( return 1; } -int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( +static int ds4_gpu_attention_indexed_mixed_batch_heads_impl( ds4_gpu_tensor *heads, const void *model_map, uint64_t model_size, @@ -10777,7 +11293,8 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + bool comp_f16) { if (!g_initialized && !ds4_gpu_init()) return 0; if (!heads || !model_map || !q || !raw_kv || !comp_kv || !topk || n_tokens == 0 || n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || @@ -10793,9 +11310,12 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( } const uint64_t row_bytes = (uint64_t)head_dim * sizeof(float); + const uint64_t comp_row_bytes = comp_f16 + ? (uint64_t)head_dim * sizeof(uint16_t) + : row_bytes; const uint64_t q_bytes = (uint64_t)n_tokens * n_head * row_bytes; const uint64_t raw_bytes = (uint64_t)raw_cap * row_bytes; - const uint64_t comp_bytes = (uint64_t)n_comp * row_bytes; + const uint64_t comp_bytes = (uint64_t)n_comp * comp_row_bytes; const uint64_t topk_bytes = (uint64_t)top_k * n_tokens * sizeof(int32_t); id qbuf = ds4_gpu_tensor_buffer(q); id rawbuf = ds4_gpu_tensor_buffer(raw_kv); @@ -10823,12 +11343,18 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( ds4_gpu_hot_pipeline(g_dsv4_sort_i32_rows_asc_pipeline, "kernel_dsv4_sort_i32_rows_asc"); const bool decode_one_token = n_tokens == 1u; - id attn_pipeline = - decode_one_token ? 
- ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb4_pipeline, - "kernel_dsv4_indexed_mixed_attention_heads8_rb4") : - ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_pipeline, - "kernel_dsv4_indexed_mixed_attention_heads8"); + id attn_pipeline = nil; + if (comp_f16) { + attn_pipeline = ds4_gpu_get_pipeline(decode_one_token + ? "kernel_dsv4_indexed_mixed_attention_heads8_rb4_comp_f16" + : "kernel_dsv4_indexed_mixed_attention_heads8_comp_f16"); + } else { + attn_pipeline = decode_one_token ? + ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_rb4_pipeline, + "kernel_dsv4_indexed_mixed_attention_heads8_rb4") : + ds4_gpu_hot_pipeline(g_dsv4_indexed_attention_heads8_pipeline, + "kernel_dsv4_indexed_mixed_attention_heads8"); + } if (!sort_pipeline || !attn_pipeline) return 0; if ((NSUInteger)top_k > sort_pipeline.maxTotalThreadsPerThreadgroup) { fprintf(stderr, "ds4: Metal indexed attention top-k exceeds sort threadgroup limit\n"); @@ -10872,7 +11398,7 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( .q_token_stride = (uint64_t)n_head * row_bytes, .q_head_stride = row_bytes, .raw_row_stride = row_bytes, - .comp_row_stride = row_bytes, + .comp_row_stride = comp_row_bytes, .topk_token_stride = (uint64_t)top_k * sizeof(int32_t), .dst_token_stride = (uint64_t)n_head * row_bytes, .dst_head_stride = row_bytes, @@ -10913,7 +11439,269 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; ds4_gpu_end_compute_encoder(cb, enc); - if (!ds4_gpu_finish_command_buffer(cb, owned, "graph indexed mixed attention heads")) return 0; + if (!ds4_gpu_finish_command_buffer(cb, owned, comp_f16 + ? "graph indexed mixed comp-f16 attention heads" + : "graph indexed mixed attention heads")) { + return 0; + } + } + + return 1; +} + +int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + return ds4_gpu_attention_indexed_mixed_batch_heads_impl(heads, + model_map, + model_size, + sinks_offset, + q, + raw_kv, + comp_kv, + topk, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + top_k, + window, + ratio, + n_head, + head_dim, + false); +} + +int ds4_gpu_attention_indexed_mixed_comp_f16_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + return ds4_gpu_attention_indexed_mixed_batch_heads_impl(heads, + model_map, + model_size, + sinks_offset, + q, + raw_kv, + comp_kv, + topk, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + top_k, + window, + ratio, + n_head, + head_dim, + true); +} + +int ds4_gpu_attention_indexed_mixed_turbo_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + 
const ds4_gpu_tensor *comp_turbo_kv, + const ds4_gpu_tensor *topk, + uint32_t kv_quant_type, + uint32_t n_blocks, + uint32_t n_rot, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!heads || !model_map || !q || !raw_kv || !comp_turbo_kv || !topk || + n_tokens == 0 || n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || + n_comp == 0 || top_k == 0 || top_k > n_comp || (top_k & (top_k - 1u)) != 0 || + ratio == 0 || n_head == 0 || head_dim != 512 || n_rot >= head_dim || + n_blocks == 0 || n_blocks * 32u + n_rot != head_dim || + (kv_quant_type != DS4_KV_QUANT_TURBO3 && kv_quant_type != DS4_KV_QUANT_TURBO4)) { + return 0; + } + + @autoreleasepool { + if (sinks_offset > model_size || (uint64_t)n_head * sizeof(float) > model_size - sinks_offset) { + fprintf(stderr, "ds4: Metal indexed turbo attention sinks range is outside the mapped model\n"); + return 0; + } + + const uint64_t row_bytes = (uint64_t)head_dim * sizeof(float); + const uint64_t turbo_block_bytes = kv_quant_type == DS4_KV_QUANT_TURBO3 + ? DS4_TURBO3_BLOCK_BYTES + : DS4_TURBO4_BLOCK_BYTES; + const uint64_t turbo_row_bytes = + (uint64_t)n_blocks * turbo_block_bytes + (uint64_t)n_rot * sizeof(float); + const uint64_t q_bytes = (uint64_t)n_tokens * n_head * row_bytes; + const uint64_t raw_bytes = (uint64_t)raw_cap * row_bytes; + const uint64_t comp_bytes = (uint64_t)n_comp * turbo_row_bytes; + const uint64_t topk_bytes = (uint64_t)top_k * n_tokens * sizeof(int32_t); + id qbuf = ds4_gpu_tensor_buffer(q); + id rawbuf = ds4_gpu_tensor_buffer(raw_kv); + id compbuf = ds4_gpu_tensor_buffer(comp_turbo_kv); + id topkbuf = ds4_gpu_tensor_buffer(topk); + id headsbuf = ds4_gpu_tensor_buffer(heads); + if (!qbuf || !rawbuf || !compbuf || !topkbuf || !headsbuf || + ds4_gpu_tensor_bytes(q) < q_bytes || + ds4_gpu_tensor_bytes(raw_kv) < raw_bytes || + ds4_gpu_tensor_bytes(comp_turbo_kv) < comp_bytes || + ds4_gpu_tensor_bytes(topk) < topk_bytes || + ds4_gpu_tensor_bytes(heads) < q_bytes) { + fprintf(stderr, "ds4: Metal indexed mixed turbo attention received undersized buffers\n"); + return 0; + } + + uint64_t sinks_inner = 0; + id sinks_buf = ds4_gpu_wrap_model_range(model_map, model_size, + sinks_offset, + (uint64_t)n_head * sizeof(float), + &sinks_inner); + if (!sinks_buf) return 0; + + id sort_pipeline = + ds4_gpu_hot_pipeline(g_dsv4_sort_i32_rows_asc_pipeline, + "kernel_dsv4_sort_i32_rows_asc"); + const bool decode_one_token = n_tokens == 1u; + const bool turbo4 = kv_quant_type == DS4_KV_QUANT_TURBO4; + const char *attn_name = decode_one_token + ? (turbo4 + ? "kernel_dsv4_indexed_mixed_attention_heads8_rb4_turbo4" + : "kernel_dsv4_indexed_mixed_attention_heads8_rb4_turbo3") + : (turbo4 + ? 
"kernel_dsv4_indexed_mixed_attention_heads8_turbo4" + : "kernel_dsv4_indexed_mixed_attention_heads8_turbo3"); + id attn_pipeline = ds4_gpu_get_pipeline(attn_name); + if (!sort_pipeline || !attn_pipeline) return 0; + if ((NSUInteger)top_k > sort_pipeline.maxTotalThreadsPerThreadgroup) { + fprintf(stderr, "ds4: Metal indexed turbo attention top-k exceeds sort threadgroup limit\n"); + return 0; + } + + const bool skip_decode_sort = !g_quality_mode && decode_one_token; + if (!skip_decode_sort && + !ds4_gpu_ensure_scratch_buffer(&g_indexed_topk_buffer, + &g_indexed_topk_bytes, + (NSUInteger)topk_bytes, + "ds4_indexed_topk_sorted")) { + return 0; + } + + ds4_gpu_dsv4_topk_mask_args sort_args = { + .ne00 = (int64_t)top_k, + .ne01 = (int64_t)n_tokens, + .nb00 = sizeof(int32_t), + .nb01 = (uint64_t)top_k * sizeof(int32_t), + .ne0 = (int64_t)top_k, + .ne1 = (int64_t)n_tokens, + .nb0 = sizeof(int32_t), + .nb1 = (uint64_t)top_k * sizeof(int32_t), + }; + ds4_gpu_dsv4_indexed_attention_args attn_args = { + .n_tokens = n_tokens, + .n_head = n_head, + .n_raw = n_raw, + .raw_cap = raw_cap, + .raw_start = raw_start, + .n_comp = n_comp, + .top_k = top_k, + .pos0 = pos0, + .window = window, + .ratio = ratio, + .q_token_stride = (uint64_t)n_head * row_bytes, + .q_head_stride = row_bytes, + .raw_row_stride = row_bytes, + .comp_row_stride = turbo_row_bytes, + .topk_token_stride = (uint64_t)top_k * sizeof(int32_t), + .dst_token_stride = (uint64_t)n_head * row_bytes, + .dst_head_stride = row_bytes, + .scale = 1.0f / sqrtf((float)head_dim), + .turbo_n_blocks = n_blocks, + .turbo_n_rot = n_rot, + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = nil; + if (!skip_decode_sort) { + enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:sort_pipeline]; + [enc setBytes:&sort_args length:sizeof(sort_args) atIndex:0]; + [enc setBuffer:topkbuf offset:ds4_gpu_tensor_offset(topk) atIndex:1]; + [enc setBuffer:g_indexed_topk_buffer offset:0 atIndex:2]; + [enc setThreadgroupMemoryLength:(NSUInteger)top_k * sizeof(int32_t) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(n_tokens, 1, 1) + threadsPerThreadgroup:MTLSizeMake(top_k, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + } + + const NSUInteger staged_rows = decode_one_token ? 4u : 1u; + const NSUInteger shared_bytes = + staged_rows * 128u * 4u * sizeof(float) + 8u * 32u * sizeof(float); + enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:attn_pipeline]; + [enc setBytes:&attn_args length:sizeof(attn_args) atIndex:0]; + [enc setBuffer:qbuf offset:ds4_gpu_tensor_offset(q) atIndex:1]; + [enc setBuffer:rawbuf offset:ds4_gpu_tensor_offset(raw_kv) atIndex:2]; + [enc setBuffer:compbuf offset:ds4_gpu_tensor_offset(comp_turbo_kv) atIndex:3]; + [enc setBuffer:skip_decode_sort ? topkbuf : g_indexed_topk_buffer + offset:skip_decode_sort ? 
ds4_gpu_tensor_offset(topk) : 0 + atIndex:4]; + [enc setBuffer:sinks_buf offset:(NSUInteger)sinks_inner atIndex:5]; + [enc setBuffer:headsbuf offset:ds4_gpu_tensor_offset(heads) atIndex:6]; + [enc setThreadgroupMemoryLength:shared_bytes atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)n_tokens, ((NSUInteger)n_head + 7u) / 8u, 1) + threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "graph indexed mixed turbo attention heads")) return 0; } return 1; diff --git a/metal/dsv4_misc.metal b/metal/dsv4_misc.metal index b06d29d3..575f0aef 100644 --- a/metal/dsv4_misc.metal +++ b/metal/dsv4_misc.metal @@ -62,6 +62,8 @@ struct ds4_metal_args_dsv4_indexed_attention { uint64_t dst_token_stride; uint64_t dst_head_stride; float scale; + uint32_t turbo_n_blocks; + uint32_t turbo_n_rot; }; struct ds4_metal_args_dsv4_indexer_scores_fused { diff --git a/metal/turbo_quant.metal b/metal/turbo_quant.metal new file mode 100644 index 00000000..5496dfe2 --- /dev/null +++ b/metal/turbo_quant.metal @@ -0,0 +1,1194 @@ +// TurboQuant Metal shaders for DS4 +// KV cache compression via PolarQuant + Walsh-Hadamard rotation +// Based on: arXiv 2504.19874 (ICLR 2026) +// Reference: github.com/TheTom/llama-cpp-turboquant +// +// 3-bit (turbo3_0): 8 centroids, 14 bytes per 32-element block +// 4-bit (turbo4_0): 16 centroids, 18 bytes per 32-element block +// Per-block FWHT rotation (32-element butterfly, 160 ops vs 1024 for dense) +// +// NOTE: This file is concatenated into the single Metal library by ds4_metal.m. +// All content must be self-contained (no #include of other metal headers). + +// --- Block type definitions --- + +#define QK_TURBO 32 + +struct block_turbo3_0 { + // 2 low bits per element: 32 * 2 = 64 bits = 8 bytes + uchar qs[QK_TURBO / 4]; + // 1 high bit per element: 32 * 1 = 32 bits = 4 bytes + uchar signs[QK_TURBO / 8]; + // per-block norm (fp16) + half norm; +}; + +struct block_turbo4_0 { + // 4 bits (nibble) per element: 32 * 4 = 128 bits = 16 bytes + uchar qs[QK_TURBO / 2]; + // per-block norm (fp16) + half norm; +}; + +// --- FWHT-32 sign array --- +constant float turbo_signs_32[32] = { + +1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +}; + +// --- 3-bit centroid tables --- +constant float turbo_centroids_3bit[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f +}; + +constant float turbo_mid_3bit[7] = { + -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f +}; + +// --- 4-bit centroid tables --- +constant float turbo_centroids_4bit[16] = { + -0.173926f, -0.117195f, -0.089527f, -0.068756f, + -0.051262f, -0.035597f, -0.020989f, -0.006938f, + 0.006938f, 0.020989f, 0.035597f, 0.051262f, + 0.068756f, 0.089527f, 0.117195f, 0.173926f +}; + +constant float turbo_mid_4bit[15] = { + -0.145560f, -0.103361f, -0.079142f, -0.060009f, + -0.043430f, -0.028293f, -0.013963f, 0.000000f, + 0.013963f, 0.028293f, 0.043430f, 0.060009f, + 0.079142f, 0.103361f, 0.145560f +}; + +// --- Argument structs --- + +struct ds4_metal_args_turbo_quantize { + int32_t ne00; // head_dim + int32_t ne01; // n_tok + int32_t ne02; // n_head + int32_t ne03; + int32_t n_rot; // RoPE dimension + uint64_t nb00; // src element stride + uint64_t nb01; // src row stride + uint64_t nb02; // src head 
stride + uint64_t nb03; + uint64_t nb0; // dst element stride (turbo block stride) + uint64_t nb1; // dst row stride + uint64_t nb2; // dst head stride + uint64_t nb3; +}; + +struct ds4_metal_args_turbo_dequant { + int32_t n_blocks; // number of turbo blocks + int32_t n_rot; // RoPE tail elements + int32_t n_rows; // total rows + uint64_t src_stride; // bytes per row in source (turbo blocks + rope tail) + uint64_t dst_stride; // bytes per row in destination (half values) +}; + +struct ds4_metal_args_turbo_select_dequant { + int32_t n_blocks; + int32_t n_rot; + int32_t n_comp; + int32_t top_k; + int32_t n_tokens; + uint64_t src_stride; + uint64_t dst_stride; + uint64_t topk_token_stride; +}; + +// --- FWHT-32: in-place butterfly + normalize (thread-safe, parallelized) --- +// Each thread handles its element for sign application, then cooperates on butterfly +// stages with barriers between each stage to avoid race conditions on threadgroup memory. +static void turbo_fwht_32(threadgroup float * x, uint tid) { + const float inv_sqrt_32 = 0.17677669529663688f; + + // signs1 (per-element, no race) + x[tid] *= turbo_signs_32[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // butterfly stages with barriers + for (int h = 1; h < 32; h *= 2) { + int i = (int)tid; + int block = i / (h * 2); + int offset = i % (h * 2); + if (offset < h) { + int idx_a = block * h * 2 + offset; + int idx_b = idx_a + h; + float a = x[idx_a]; + float b = x[idx_b]; + x[idx_a] = a + b; + x[idx_b] = a - b; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // normalize + signs2 (sign array is self-inverse since +/-1) + x[tid] *= inv_sqrt_32 * turbo_signs_32[tid]; +} + +// --- Nearest centroid helpers --- +static int nearest_centroid_3bit(float val) { + if (val < turbo_mid_3bit[0]) return 0; + if (val < turbo_mid_3bit[1]) return 1; + if (val < turbo_mid_3bit[2]) return 2; + if (val < turbo_mid_3bit[3]) return 3; + if (val < turbo_mid_3bit[4]) return 4; + if (val < turbo_mid_3bit[5]) return 5; + if (val < turbo_mid_3bit[6]) return 6; + return 7; +} + +static int nearest_centroid_4bit(float val) { + if (val < turbo_mid_4bit[ 0]) return 0; + if (val < turbo_mid_4bit[ 1]) return 1; + if (val < turbo_mid_4bit[ 2]) return 2; + if (val < turbo_mid_4bit[ 3]) return 3; + if (val < turbo_mid_4bit[ 4]) return 4; + if (val < turbo_mid_4bit[ 5]) return 5; + if (val < turbo_mid_4bit[ 6]) return 6; + if (val < turbo_mid_4bit[ 7]) return 7; + if (val < turbo_mid_4bit[ 8]) return 8; + if (val < turbo_mid_4bit[ 9]) return 9; + if (val < turbo_mid_4bit[10]) return 10; + if (val < turbo_mid_4bit[11]) return 11; + if (val < turbo_mid_4bit[12]) return 12; + if (val < turbo_mid_4bit[13]) return 13; + if (val < turbo_mid_4bit[14]) return 14; + return 15; +} + +// ============================================================================ +// 3-bit quantize kernel +// ============================================================================ +// Grid: (n_tok * n_head, 1, 1), Threadgroup: (32, 1, 1) +// Each threadgroup processes one 32-element block. +// One row (head_dim=512, n_rot=64) = 14 blocks. RoPE tail stays in source. 
+kernel void kernel_turbo3_quantize_f32( + constant ds4_metal_args_turbo_quantize & args, + device const float * src, + device block_turbo3_0 * dst, + threadgroup float * scratch [[threadgroup(0)]], + uint block_row [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + + const int32_t head_dim = args.ne00; + const int32_t n_rot = args.n_rot; + const int32_t n_nope = head_dim - n_rot; + const int32_t n_blocks = n_nope / QK_TURBO; + + const int64_t n_rows = args.ne01 * args.ne02 * args.ne03; + const int64_t row = block_row / n_blocks; + const int64_t blk = block_row % n_blocks; + + if (row >= n_rows || blk >= n_blocks) return; + + const int64_t i1 = (row / args.ne02) % args.ne01; + const int64_t i2 = row % args.ne02; + const int64_t i3 = row / (args.ne01 * args.ne02); + + device const float * src_row = (device const float *)((device const char *)src + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device block_turbo3_0 * dst_blk = (device block_turbo3_0 *)((device char *)dst + row*args.nb1 + blk*sizeof(block_turbo3_0)); + + // Step 1: Load 32 elements into scratch, compute L2 norm + float v = src_row[blk * QK_TURBO + tid]; + scratch[tid] = v; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float norm_sq = 0.0f; + for (int i = 0; i < QK_TURBO; i++) { + norm_sq += scratch[i] * scratch[i]; + } + float blk_norm = sqrt(norm_sq); + float inv_norm = (blk_norm > 1e-10f) ? 1.0f / blk_norm : 0.0f; + + // Step 2: Normalize + scratch[tid] *= inv_norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 3: FWHT-32 rotation + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 4: Quantize, then pack with a single writer per output byte. + float v_rot = scratch[tid]; + int idx = nearest_centroid_3bit(v_rot); + scratch[tid] = (float)idx; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < QK_TURBO / 4) { + uchar packed = 0; + for (uint j = 0; j < 4; j++) { + uint elem = tid * 4 + j; + uint qidx = (uint)scratch[elem]; + packed |= (uchar)((qidx & 0x3u) << (j * 2)); + } + dst_blk->qs[tid] = packed; + } + if (tid < QK_TURBO / 8) { + uchar packed = 0; + for (uint j = 0; j < 8; j++) { + uint elem = tid * 8 + j; + uint qidx = (uint)scratch[elem]; + packed |= (uchar)(((qidx >> 2) & 0x1u) << j); + } + dst_blk->signs[tid] = packed; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate recon norm squared + scratch[tid] = turbo_centroids_3bit[idx] * turbo_centroids_3bit[idx]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction for block recon norm + for (uint s = 16; s > 0; s >>= 1) { + if (tid < s) scratch[tid] += scratch[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Step 5: Store corrected norm + if (tid == 0) { + float recon_norm = sqrt(scratch[0]); + float corrected = (recon_norm > 1e-10f) ? 
blk_norm / recon_norm : blk_norm; + dst_blk->norm = (half)corrected; + } + + if (blk == 0) { + device float * rope_dst = (device float *)((device char *)dst + row * args.nb1 + + n_blocks * sizeof(block_turbo3_0)); + for (int i = tid; i < n_rot; i += QK_TURBO) { + rope_dst[i] = src_row[n_nope + i]; + } + } +} + +// ============================================================================ +// 4-bit quantize kernel +// ============================================================================ +kernel void kernel_turbo4_quantize_f32( + constant ds4_metal_args_turbo_quantize & args, + device const float * src, + device block_turbo4_0 * dst, + threadgroup float * scratch [[threadgroup(0)]], + uint block_row [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + + const int32_t head_dim = args.ne00; + const int32_t n_rot = args.n_rot; + const int32_t n_nope = head_dim - n_rot; + const int32_t n_blocks = n_nope / QK_TURBO; + + const int64_t n_rows = args.ne01 * args.ne02 * args.ne03; + const int64_t row = block_row / n_blocks; + const int64_t blk = block_row % n_blocks; + + if (row >= n_rows || blk >= n_blocks) return; + + const int64_t i1 = (row / args.ne02) % args.ne01; + const int64_t i2 = row % args.ne02; + const int64_t i3 = row / (args.ne01 * args.ne02); + + device const float * src_row = (device const float *)((device const char *)src + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device block_turbo4_0 * dst_blk = (device block_turbo4_0 *)((device char *)dst + row*args.nb1 + blk*sizeof(block_turbo4_0)); + + // Step 1: Load, compute norm + float v = src_row[blk * QK_TURBO + tid]; + scratch[tid] = v; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float norm_sq = 0.0f; + for (int i = 0; i < QK_TURBO; i++) norm_sq += scratch[i] * scratch[i]; + float blk_norm = sqrt(norm_sq); + float inv_norm = (blk_norm > 1e-10f) ? 1.0f / blk_norm : 0.0f; + + // Step 2: Normalize + scratch[tid] *= inv_norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 3: FWHT-32 rotation + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Step 4: 4-bit quantize and nibble pack + float v_rot = scratch[tid]; + int idx = nearest_centroid_4bit(v_rot); + + scratch[tid] = (float)idx; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < QK_TURBO / 2) { + uint lo = (uint)scratch[tid * 2 + 0]; + uint hi = (uint)scratch[tid * 2 + 1]; + dst_blk->qs[tid] = (uchar)((lo & 0xFu) | ((hi & 0xFu) << 4)); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + scratch[tid] = turbo_centroids_4bit[idx] * turbo_centroids_4bit[idx]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = 16; s > 0; s >>= 1) { + if (tid < s) scratch[tid] += scratch[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Step 5: Store corrected norm + if (tid == 0) { + float recon_norm = sqrt(scratch[0]); + float corrected = (recon_norm > 1e-10f) ? 
blk_norm / recon_norm : blk_norm; + dst_blk->norm = (half)corrected; + } + + if (blk == 0) { + device float * rope_dst = (device float *)((device char *)dst + row * args.nb1 + + n_blocks * sizeof(block_turbo4_0)); + for (int i = tid; i < n_rot; i += QK_TURBO) { + rope_dst[i] = src_row[n_nope + i]; + } + } +} + +// ============================================================================ +// Dequantize-to-f32 kernels (for feeding existing flash attention) +// ============================================================================ +// Output format is float32 in the original KV basis to match the existing +// attention dispatch. Grid: (n_rows * n_blocks, 1, 1), Threadgroup: (32, 1, 1). + +kernel void kernel_turbo3_dequant_f32( + constant ds4_metal_args_turbo_dequant & args, + device const block_turbo3_0 * src, + device float * dst, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + + const uint row = gid / (uint)args.n_blocks; + const uint blk = gid % (uint)args.n_blocks; + if ((int)row >= args.n_rows) return; + + device const block_turbo3_0 * src_row = (device const block_turbo3_0 *)((device const char *)src + row * args.src_stride); + device float * dst_row = (device float *)((device char *)dst + row * args.dst_stride); + + device const block_turbo3_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar q_byte = blk_ptr->qs[tid / 4]; + uchar s_byte = blk_ptr->signs[tid / 8]; + uchar low2 = (q_byte >> ((tid % 4) * 2)) & 0x3; + uchar hi1 = (s_byte >> (tid % 8)) & 0x1; + uchar idx = low2 | (hi1 << 2); + scratch[tid] = turbo_centroids_3bit[idx] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = rope_src[i]; + } + } +} + +kernel void kernel_turbo4_dequant_f32( + constant ds4_metal_args_turbo_dequant & args, + device const block_turbo4_0 * src, + device float * dst, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + + const uint row = gid / (uint)args.n_blocks; + const uint blk = gid % (uint)args.n_blocks; + if ((int)row >= args.n_rows) return; + + device const block_turbo4_0 * src_row = (device const block_turbo4_0 *)((device const char *)src + row * args.src_stride); + device float * dst_row = (device float *)((device char *)dst + row * args.dst_stride); + + device const block_turbo4_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar nibble = (blk_ptr->qs[tid / 2] >> ((tid % 2) * 4)) & 0xF; + scratch[tid] = turbo_centroids_4bit[nibble] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = rope_src[i]; + } + } +} + +kernel void kernel_turbo3_dequant_selected_f32( + constant ds4_metal_args_turbo_select_dequant & args, + device const block_turbo3_0 * src, + device const int32_t * topk, + device 
float * dst, + device int32_t * identity_topk, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const uint block_count = (uint)args.n_blocks; + const uint sel = (gid / block_count) % (uint)args.top_k; + const uint token = gid / (block_count * (uint)args.top_k); + const uint blk = gid % block_count; + if ((int)token >= args.n_tokens) return; + + device const int32_t *row_topk = (device const int32_t *)((device const char *)topk + + (uint64_t)token * args.topk_token_stride); + const int32_t src_idx = row_topk[sel]; + if (src_idx < 0 || src_idx >= args.n_comp) return; + if (blk == 0 && tid == 0) { + identity_topk[(uint64_t)token * (uint)args.top_k + sel] = (int32_t)sel; + } + + device const block_turbo3_0 * src_row = + (device const block_turbo3_0 *)((device const char *)src + (uint64_t)(uint)src_idx * args.src_stride); + device float * dst_row = (device float *)((device char *)dst + + ((uint64_t)token * (uint)args.top_k + sel) * args.dst_stride); + + device const block_turbo3_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar q_byte = blk_ptr->qs[tid / 4]; + uchar s_byte = blk_ptr->signs[tid / 8]; + uchar low2 = (q_byte >> ((tid % 4) * 2)) & 0x3; + uchar hi1 = (s_byte >> (tid % 8)) & 0x1; + uchar idx = low2 | (hi1 << 2); + scratch[tid] = turbo_centroids_3bit[idx] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = rope_src[i]; + } + } +} + +kernel void kernel_turbo4_dequant_selected_f32( + constant ds4_metal_args_turbo_select_dequant & args, + device const block_turbo4_0 * src, + device const int32_t * topk, + device float * dst, + device int32_t * identity_topk, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const uint block_count = (uint)args.n_blocks; + const uint sel = (gid / block_count) % (uint)args.top_k; + const uint token = gid / (block_count * (uint)args.top_k); + const uint blk = gid % block_count; + if ((int)token >= args.n_tokens) return; + + device const int32_t *row_topk = (device const int32_t *)((device const char *)topk + + (uint64_t)token * args.topk_token_stride); + const int32_t src_idx = row_topk[sel]; + if (src_idx < 0 || src_idx >= args.n_comp) return; + if (blk == 0 && tid == 0) { + identity_topk[(uint64_t)token * (uint)args.top_k + sel] = (int32_t)sel; + } + + device const block_turbo4_0 * src_row = + (device const block_turbo4_0 *)((device const char *)src + (uint64_t)(uint)src_idx * args.src_stride); + device float * dst_row = (device float *)((device char *)dst + + ((uint64_t)token * (uint)args.top_k + sel) * args.dst_stride); + + device const block_turbo4_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar nibble = (blk_ptr->qs[tid / 2] >> ((tid % 2) * 4)) & 0xF; + scratch[tid] = turbo_centroids_4bit[nibble] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + 
args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = rope_src[i]; + } + } +} + +kernel void kernel_turbo3_dequant_selected_f16( + constant ds4_metal_args_turbo_select_dequant & args, + device const block_turbo3_0 * src, + device const int32_t * topk, + device half * dst, + device int32_t * identity_topk, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const uint block_count = (uint)args.n_blocks; + const uint sel = (gid / block_count) % (uint)args.top_k; + const uint token = gid / (block_count * (uint)args.top_k); + const uint blk = gid % block_count; + if ((int)token >= args.n_tokens) return; + + device const int32_t *row_topk = (device const int32_t *)((device const char *)topk + + (uint64_t)token * args.topk_token_stride); + const int32_t src_idx = row_topk[sel]; + if (src_idx < 0 || src_idx >= args.n_comp) return; + if (blk == 0 && tid == 0) { + identity_topk[(uint64_t)token * (uint)args.top_k + sel] = (int32_t)sel; + } + + device const block_turbo3_0 * src_row = + (device const block_turbo3_0 *)((device const char *)src + (uint64_t)(uint)src_idx * args.src_stride); + device half * dst_row = (device half *)((device char *)dst + + ((uint64_t)token * (uint)args.top_k + sel) * args.dst_stride); + + device const block_turbo3_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar q_byte = blk_ptr->qs[tid / 4]; + uchar s_byte = blk_ptr->signs[tid / 8]; + uchar low2 = (q_byte >> ((tid % 4) * 2)) & 0x3; + uchar hi1 = (s_byte >> (tid % 8)) & 0x1; + uchar idx = low2 | (hi1 << 2); + scratch[tid] = turbo_centroids_3bit[idx] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = (half)scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = (half)rope_src[i]; + } + } +} + +kernel void kernel_turbo4_dequant_selected_f16( + constant ds4_metal_args_turbo_select_dequant & args, + device const block_turbo4_0 * src, + device const int32_t * topk, + device half * dst, + device int32_t * identity_topk, + threadgroup float * scratch [[threadgroup(0)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const uint block_count = (uint)args.n_blocks; + const uint sel = (gid / block_count) % (uint)args.top_k; + const uint token = gid / (block_count * (uint)args.top_k); + const uint blk = gid % block_count; + if ((int)token >= args.n_tokens) return; + + device const int32_t *row_topk = (device const int32_t *)((device const char *)topk + + (uint64_t)token * args.topk_token_stride); + const int32_t src_idx = row_topk[sel]; + if (src_idx < 0 || src_idx >= args.n_comp) return; + if (blk == 0 && tid == 0) { + identity_topk[(uint64_t)token * (uint)args.top_k + sel] = (int32_t)sel; + } + + device const block_turbo4_0 * src_row = + (device const block_turbo4_0 *)((device const char *)src + (uint64_t)(uint)src_idx * args.src_stride); + device half * dst_row = (device half *)((device char *)dst + + ((uint64_t)token * (uint)args.top_k + sel) * args.dst_stride); + + device const block_turbo4_0 * blk_ptr = &src_row[blk]; + float norm = (float)blk_ptr->norm; + + uchar nibble = (blk_ptr->qs[tid / 2] >> ((tid % 2) * 4)) & 0xF; + 
scratch[tid] = turbo_centroids_4bit[nibble] * norm; + threadgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32(scratch, tid); + threadgroup_barrier(mem_flags::mem_threadgroup); + dst_row[blk * QK_TURBO + tid] = (half)scratch[tid]; + + if (blk == 0) { + device const float * rope_src = (device const float *)(src_row + args.n_blocks); + for (int i = tid; i < args.n_rot; i += QK_TURBO) { + dst_row[args.n_blocks * QK_TURBO + i] = (half)rope_src[i]; + } + } +} + +// ============================================================================ +// Fused indexed attention row staging +// ============================================================================ +// These kernels keep the existing PolarQuant/WHT cache format but avoid the +// decode-time full-cache dequantization pass. Only the selected top-k rows are +// expanded into the indexed attention threadgroup tile. + +static inline void dsv4_indexed_mixed_attention_comp_f16_impl( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared, + uint2 tgpig, + ushort tid, + ushort lane, + ushort sg, + bool rb4) { + const uint token = tgpig.x; + const uint head = tgpig.y * 8u + (uint)sg; + if (token >= args.n_tokens || head >= args.n_head) { + return; + } + + device const float4 *q4 = (device const float4 *)(q + + (uint64_t)token * args.q_token_stride + + (uint64_t)head * args.q_head_stride); + const half4 q0 = (half4)q4[lane + 0]; + const half4 q1 = (half4)q4[lane + 32]; + const half4 q2 = (half4)q4[lane + 64]; + const half4 q3 = (half4)q4[lane + 96]; + + float M = -FLT_MAX/2.0f; + float S = 0.0f; + float4 o0 = 0.0f; + float4 o1 = 0.0f; + float4 o2 = 0.0f; + float4 o3 = 0.0f; + + const uint qpos = args.pos0 + token; + const uint last_pos = args.pos0 + args.n_tokens - 1u; + const uint first_raw_pos = last_pos + 1u - args.n_raw; + const uint raw_last_pos = first_raw_pos + args.n_raw - 1u; + const uint window_first = (args.window != 0u && qpos + 1u > args.window) ? 
+ qpos + 1u - args.window : 0u; + uint first = max(first_raw_pos, window_first); + uint last = min(qpos, raw_last_pos); + + if (first <= last) { + if (rb4) { + for (uint pos0 = first; pos0 <= last; pos0 += 4u) { + const uint n_rows = min(4u, last - pos0 + 1u); + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + const uint logical = pos0 + r - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } else { + for (uint pos = first; pos <= last; pos++) { + const uint logical = pos - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + if (tid < 128) kv_shared[tid] = src[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + dsv4_attend_shared_f32_row_as_f16(kv_shared, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + } + + uint visible = (qpos + 1u) / args.ratio; + visible = min(visible, args.n_comp); + device const int32_t *row_topk = (device const int32_t *)(topk + + (uint64_t)token * args.topk_token_stride); + if (rb4) { + bool stop = false; + for (uint i = 0; i < args.top_k && !stop; i += 4u) { + uint rows[4]; + uint n_rows = 0; + for (uint j = 0; j < 4u && i + j < args.top_k; j++) { + const int32_t idx = row_topk[i + j]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + stop = true; + break; + } + rows[n_rows++] = (uint)idx; + } + if (n_rows == 0) { + continue; + } + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + device const half4 *src = (device const half4 *)(comp_kv + + (uint64_t)rows[r] * args.comp_row_stride); + kv_shared[off] = (float4)src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } else { + for (uint i = 0; i < args.top_k; i++) { + const int32_t idx = row_topk[i]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + break; + } + device const half4 *src = (device const half4 *)(comp_kv + + (uint64_t)(uint)idx * args.comp_row_stride); + if (tid < 128) kv_shared[tid] = (float4)src[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + dsv4_attend_shared_f32_row_as_f16(kv_shared, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + dsv4_attend_sink(((device const float *)sinks)[head], M, S, o0, o1, o2, o3); + + const float inv_s = S == 0.0f ? 
0.0f : 1.0f/S; + device float4 *dst4 = (device float4 *)(dst + + (uint64_t)token * args.dst_token_stride + + (uint64_t)head * args.dst_head_stride); + dst4[lane + 0] = o0 * inv_s; + dst4[lane + 32] = o1 * inv_s; + dst4[lane + 64] = o2 * inv_s; + dst4[lane + 96] = o3 * inv_s; +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_comp_f16( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_comp_f16_impl(args, q, raw_kv, comp_kv, + topk, sinks, dst, kv_shared, + tgpig, tid, lane, sg, false); +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4_comp_f16( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_comp_f16_impl(args, q, raw_kv, comp_kv, + topk, sinks, dst, kv_shared, + tgpig, tid, lane, sg, true); +} + +static void turbo_fwht_32_simd(threadgroup float * x, uint tid) { + const float inv_sqrt_32 = 0.17677669529663688f; + + x[tid] *= turbo_signs_32[tid]; + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (int h = 1; h < 32; h *= 2) { + int i = (int)tid; + int block = i / (h * 2); + int offset = i % (h * 2); + if (offset < h) { + int idx_a = block * h * 2 + offset; + int idx_b = idx_a + h; + float a = x[idx_a]; + float b = x[idx_b]; + x[idx_a] = a + b; + x[idx_b] = a - b; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + } + + x[tid] *= inv_sqrt_32 * turbo_signs_32[tid]; +} + +static inline void dsv4_turbo3_dequant_shared_row( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *comp_kv, + uint row, + threadgroup float4 *dst4, + threadgroup float *scratch, + ushort lane, + ushort sg) { + const uint n_blocks = args.turbo_n_blocks; + const uint n_rot = args.turbo_n_rot; + const uint n_nope = n_blocks * QK_TURBO; + device const block_turbo3_0 *src_row = + (device const block_turbo3_0 *)(comp_kv + (uint64_t)row * args.comp_row_stride); + threadgroup float *dst = (threadgroup float *)dst4; + threadgroup float *s = scratch + (uint)sg * QK_TURBO; + + for (uint blk = (uint)sg; blk < n_blocks; blk += 8u) { + device const block_turbo3_0 *blk_ptr = &src_row[blk]; + const uint t = (uint)lane; + uchar q_byte = blk_ptr->qs[t / 4u]; + uchar s_byte = blk_ptr->signs[t / 8u]; + uchar low2 = (q_byte >> ((t % 4u) * 2u)) & 0x3; + uchar hi1 = (s_byte >> (t % 8u)) & 0x1; + uchar idx = low2 | (hi1 << 2); + s[t] = turbo_centroids_3bit[idx] * (float)blk_ptr->norm; + simdgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32_simd(s, t); + simdgroup_barrier(mem_flags::mem_threadgroup); + dst[blk * QK_TURBO + t] = s[t]; + } + + device const float *rope_src = (device const float *)(src_row + n_blocks); + for (uint i = (uint)sg * QK_TURBO + (uint)lane; i < n_rot; i += 8u * QK_TURBO) { + dst[n_nope + i] = 
rope_src[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +static inline void dsv4_turbo4_dequant_shared_row( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *comp_kv, + uint row, + threadgroup float4 *dst4, + threadgroup float *scratch, + ushort lane, + ushort sg) { + const uint n_blocks = args.turbo_n_blocks; + const uint n_rot = args.turbo_n_rot; + const uint n_nope = n_blocks * QK_TURBO; + device const block_turbo4_0 *src_row = + (device const block_turbo4_0 *)(comp_kv + (uint64_t)row * args.comp_row_stride); + threadgroup float *dst = (threadgroup float *)dst4; + threadgroup float *s = scratch + (uint)sg * QK_TURBO; + + for (uint blk = (uint)sg; blk < n_blocks; blk += 8u) { + device const block_turbo4_0 *blk_ptr = &src_row[blk]; + const uint t = (uint)lane; + uchar nibble = (blk_ptr->qs[t / 2u] >> ((t % 2u) * 4u)) & 0xF; + s[t] = turbo_centroids_4bit[nibble] * (float)blk_ptr->norm; + simdgroup_barrier(mem_flags::mem_threadgroup); + + turbo_fwht_32_simd(s, t); + simdgroup_barrier(mem_flags::mem_threadgroup); + dst[blk * QK_TURBO + t] = s[t]; + } + + device const float *rope_src = (device const float *)(src_row + n_blocks); + for (uint i = (uint)sg * QK_TURBO + (uint)lane; i < n_rot; i += 8u * QK_TURBO) { + dst[n_nope + i] = rope_src[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +static inline void dsv4_turbo_dequant_shared_row( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *comp_kv, + uint row, + threadgroup float4 *dst4, + threadgroup float *scratch, + ushort lane, + ushort sg, + bool turbo4) { + if (turbo4) { + dsv4_turbo4_dequant_shared_row(args, comp_kv, row, dst4, scratch, lane, sg); + } else { + dsv4_turbo3_dequant_shared_row(args, comp_kv, row, dst4, scratch, lane, sg); + } +} + +static inline void dsv4_indexed_mixed_attention_turbo_impl( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared, + uint2 tgpig, + ushort tid, + ushort lane, + ushort sg, + bool turbo4, + bool rb4) { + const uint token = tgpig.x; + const uint head = tgpig.y * 8u + (uint)sg; + if (token >= args.n_tokens || head >= args.n_head) { + return; + } + + const uint shared_rows = rb4 ? 4u : 1u; + threadgroup float *scratch = (threadgroup float *)(kv_shared + shared_rows * 128u); + + device const float4 *q4 = (device const float4 *)(q + + (uint64_t)token * args.q_token_stride + + (uint64_t)head * args.q_head_stride); + const half4 q0 = (half4)q4[lane + 0]; + const half4 q1 = (half4)q4[lane + 32]; + const half4 q2 = (half4)q4[lane + 64]; + const half4 q3 = (half4)q4[lane + 96]; + + float M = -FLT_MAX/2.0f; + float S = 0.0f; + float4 o0 = 0.0f; + float4 o1 = 0.0f; + float4 o2 = 0.0f; + float4 o3 = 0.0f; + + const uint qpos = args.pos0 + token; + const uint last_pos = args.pos0 + args.n_tokens - 1u; + const uint first_raw_pos = last_pos + 1u - args.n_raw; + const uint raw_last_pos = first_raw_pos + args.n_raw - 1u; + const uint window_first = (args.window != 0u && qpos + 1u > args.window) ? 
+ qpos + 1u - args.window : 0u; + uint first = max(first_raw_pos, window_first); + uint last = min(qpos, raw_last_pos); + + if (first <= last) { + if (rb4) { + for (uint pos0 = first; pos0 <= last; pos0 += 4u) { + const uint n_rows = min(4u, last - pos0 + 1u); + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + const uint logical = pos0 + r - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + kv_shared[off] = src[c]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } else { + for (uint pos = first; pos <= last; pos++) { + const uint logical = pos - first_raw_pos; + const uint row = (args.raw_start + logical) % args.raw_cap; + device const float4 *src = (device const float4 *)(raw_kv + + (uint64_t)row * args.raw_row_stride); + if (tid < 128) kv_shared[tid] = src[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + dsv4_attend_shared_f32_row_as_f16(kv_shared, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + } + + uint visible = (qpos + 1u) / args.ratio; + visible = min(visible, args.n_comp); + device const int32_t *row_topk = (device const int32_t *)(topk + + (uint64_t)token * args.topk_token_stride); + if (rb4) { + bool stop = false; + for (uint i = 0; i < args.top_k && !stop; i += 4u) { + uint rows[4]; + uint n_rows = 0; + for (uint j = 0; j < 4u && i + j < args.top_k; j++) { + const int32_t idx = row_topk[i + j]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + stop = true; + break; + } + rows[n_rows++] = (uint)idx; + } + if (n_rows == 0) { + continue; + } + for (uint r = 0; r < n_rows; r++) { + dsv4_turbo_dequant_shared_row(args, + comp_kv, + rows[r], + kv_shared + r * 128u, + scratch, + lane, + sg, + turbo4); + } + for (uint r = 0; r < n_rows; r++) { + dsv4_attend_shared_f32_row_as_f16_at(kv_shared, + r, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } else { + for (uint i = 0; i < args.top_k; i++) { + const int32_t idx = row_topk[i]; + if (idx < 0) { + continue; + } + if ((uint)idx >= visible) { + break; + } + dsv4_turbo_dequant_shared_row(args, + comp_kv, + (uint)idx, + kv_shared, + scratch, + lane, + sg, + turbo4); + dsv4_attend_shared_f32_row_as_f16(kv_shared, + q0, q1, q2, q3, + args.scale, + lane, + M, S, + o0, o1, o2, o3); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + dsv4_attend_sink(((device const float *)sinks)[head], M, S, o0, o1, o2, o3); + + const float inv_s = S == 0.0f ? 
0.0f : 1.0f/S; + device float4 *dst4 = (device float4 *)(dst + + (uint64_t)token * args.dst_token_stride + + (uint64_t)head * args.dst_head_stride); + dst4[lane + 0] = o0 * inv_s; + dst4[lane + 32] = o1 * inv_s; + dst4[lane + 64] = o2 * inv_s; + dst4[lane + 96] = o3 * inv_s; +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_turbo3( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_turbo_impl(args, q, raw_kv, comp_kv, topk, + sinks, dst, kv_shared, tgpig, + tid, lane, sg, false, false); +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_turbo4( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_turbo_impl(args, q, raw_kv, comp_kv, topk, + sinks, dst, kv_shared, tgpig, + tid, lane, sg, true, false); +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4_turbo3( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_turbo_impl(args, q, raw_kv, comp_kv, topk, + sinks, dst, kv_shared, tgpig, + tid, lane, sg, false, true); +} + +kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb4_turbo4( + constant ds4_metal_args_dsv4_indexed_attention & args, + device const char *q, + device const char *raw_kv, + device const char *comp_kv, + device const char *topk, + device const char *sinks, + device char *dst, + threadgroup float4 *kv_shared [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort tid [[thread_index_in_threadgroup]], + ushort lane [[thread_index_in_simdgroup]], + ushort sg [[simdgroup_index_in_threadgroup]]) { + dsv4_indexed_mixed_attention_turbo_impl(args, q, raw_kv, comp_kv, topk, + sinks, dst, kv_shared, tgpig, + tid, lane, sg, true, true); +} diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 959367c2..d6e8dff2 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -150,6 +150,314 @@ static void test_metal_f16_matvec_fast_nr0_4(void) { free(weights_raw); } +static void test_metal_turbo_quant_roundtrip_one(bool turbo4) { + const uint32_t n_tok = 3; + const uint32_t n_head = 2; + const uint32_t n_rows = n_tok * n_head; + const uint32_t head_dim = 96; + const uint32_t n_rot = 32; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_blocks = n_nope / 32; + const uint64_t block_bytes = turbo4 ? 
DS4_TURBO4_BLOCK_BYTES : DS4_TURBO3_BLOCK_BYTES; + const uint64_t row_bytes = (uint64_t)n_blocks * block_bytes + (uint64_t)n_rot * sizeof(float); + const uint64_t src_bytes = (uint64_t)n_rows * head_dim * sizeof(float); + const uint64_t q_bytes = (uint64_t)n_rows * row_bytes; + + ds4_gpu_tensor *src = ds4_gpu_tensor_alloc(src_bytes); + ds4_gpu_tensor *q = ds4_gpu_tensor_alloc(q_bytes); + ds4_gpu_tensor *out = ds4_gpu_tensor_alloc(src_bytes); + TEST_ASSERT(src != NULL); + TEST_ASSERT(q != NULL); + TEST_ASSERT(out != NULL); + if (!src || !q || !out) { + ds4_gpu_tensor_free(out); + ds4_gpu_tensor_free(q); + ds4_gpu_tensor_free(src); + return; + } + + float *src_host = malloc((size_t)src_bytes); + float *out_host = malloc((size_t)src_bytes); + TEST_ASSERT(src_host != NULL); + TEST_ASSERT(out_host != NULL); + if (!src_host || !out_host) { + free(out_host); + free(src_host); + ds4_gpu_tensor_free(out); + ds4_gpu_tensor_free(q); + ds4_gpu_tensor_free(src); + return; + } + + for (uint32_t r = 0; r < n_rows; r++) { + for (uint32_t i = 0; i < head_dim; i++) { + float v = (float)((int)((r * 17u + i * 13u) % 41u) - 20) / 21.0f; + if (i >= n_nope) v = 1000.0f + (float)(r * 100u + i); + src_host[(uint64_t)r * head_dim + i] = v; + } + } + + TEST_ASSERT(ds4_gpu_tensor_write(src, 0, src_host, src_bytes) != 0); + if (turbo4) { + TEST_ASSERT(ds4_gpu_turbo4_kv_quantize_tensor(q, src, n_tok, n_head, head_dim, n_rot) != 0); + TEST_ASSERT(ds4_gpu_turbo4_dequant_f32_tensor(out, q, n_blocks, n_rot, n_rows) != 0); + } else { + TEST_ASSERT(ds4_gpu_turbo3_kv_quantize_tensor(q, src, n_tok, n_head, head_dim, n_rot) != 0); + TEST_ASSERT(ds4_gpu_turbo3_dequant_f32_tensor(out, q, n_blocks, n_rot, n_rows) != 0); + } + TEST_ASSERT(ds4_gpu_tensor_read(out, 0, out_host, src_bytes) != 0); + + double prefix_ss = 0.0; + double err_ss = 0.0; + float max_tail = 0.0f; + for (uint32_t r = 0; r < n_rows; r++) { + for (uint32_t i = 0; i < head_dim; i++) { + const uint64_t off = (uint64_t)r * head_dim + i; + TEST_ASSERT(isfinite(out_host[off])); + if (i < n_nope) { + const double d = (double)out_host[off] - (double)src_host[off]; + err_ss += d * d; + prefix_ss += (double)src_host[off] * (double)src_host[off]; + } else { + const float err = fabsf(out_host[off] - src_host[off]); + if (err > max_tail) max_tail = err; + } + } + } + const double rel_rms = sqrt(err_ss / prefix_ss); + TEST_ASSERT(max_tail == 0.0f); + TEST_ASSERT(rel_rms < (turbo4 ? 0.55 : 0.75)); + + free(out_host); + free(src_host); + ds4_gpu_tensor_free(out); + ds4_gpu_tensor_free(q); + ds4_gpu_tensor_free(src); +} + +static void test_metal_turbo_quant_roundtrip(void) { + test_metal_turbo_quant_roundtrip_one(false); + test_metal_turbo_quant_roundtrip_one(true); +} + +static void test_metal_turbo_indexed_attention_one(bool turbo4) { + const uint32_t n_tokens = 1; + const uint32_t n_head = 8; + const uint32_t head_dim = 512; + const uint32_t n_rot = 64; + const uint32_t n_blocks = (head_dim - n_rot) / 32; + const uint32_t n_raw = 4; + const uint32_t raw_cap = 4; + const uint32_t n_comp = 8; + const uint32_t top_k = 4; + const uint64_t row_bytes = (uint64_t)head_dim * sizeof(float); + const uint64_t q_bytes = (uint64_t)n_tokens * n_head * row_bytes; + const uint64_t raw_bytes = (uint64_t)raw_cap * row_bytes; + const uint64_t comp_bytes = (uint64_t)n_comp * row_bytes; + const uint64_t turbo_block_bytes = turbo4 ? 
DS4_TURBO4_BLOCK_BYTES : DS4_TURBO3_BLOCK_BYTES; + const uint64_t turbo_row_bytes = (uint64_t)n_blocks * turbo_block_bytes + (uint64_t)n_rot * sizeof(float); + const uint64_t turbo_bytes = (uint64_t)n_comp * turbo_row_bytes; + const uint64_t selected_f16_bytes = (uint64_t)top_k * head_dim * sizeof(uint16_t); + const uint64_t heads_bytes = q_bytes; + const uint64_t sinks_alloc = test_round_up_u64((uint64_t)n_head * sizeof(float), + (uint64_t)getpagesize()); + + ds4_gpu_tensor *q = ds4_gpu_tensor_alloc(q_bytes); + ds4_gpu_tensor *raw = ds4_gpu_tensor_alloc(raw_bytes); + ds4_gpu_tensor *comp_src = ds4_gpu_tensor_alloc(comp_bytes); + ds4_gpu_tensor *comp_deq = ds4_gpu_tensor_alloc(comp_bytes); + ds4_gpu_tensor *comp_turbo = ds4_gpu_tensor_alloc(turbo_bytes); + ds4_gpu_tensor *comp_selected_deq = ds4_gpu_tensor_alloc((uint64_t)top_k * row_bytes); + ds4_gpu_tensor *comp_selected_deq_f16 = ds4_gpu_tensor_alloc(selected_f16_bytes); + ds4_gpu_tensor *topk = ds4_gpu_tensor_alloc((uint64_t)top_k * sizeof(int32_t)); + ds4_gpu_tensor *identity_topk = ds4_gpu_tensor_alloc((uint64_t)top_k * sizeof(int32_t)); + ds4_gpu_tensor *heads_ref = ds4_gpu_tensor_alloc(heads_bytes); + ds4_gpu_tensor *heads_turbo = ds4_gpu_tensor_alloc(heads_bytes); + ds4_gpu_tensor *heads_selected = ds4_gpu_tensor_alloc(heads_bytes); + ds4_gpu_tensor *heads_selected_f16 = ds4_gpu_tensor_alloc(heads_bytes); + TEST_ASSERT(q && raw && comp_src && comp_deq && comp_turbo && comp_selected_deq && + comp_selected_deq_f16 && topk && identity_topk && heads_ref && + heads_turbo && heads_selected && heads_selected_f16); + + float *q_host = malloc((size_t)q_bytes); + float *raw_host = malloc((size_t)raw_bytes); + float *comp_host = malloc((size_t)comp_bytes); + float *ref_host = malloc((size_t)heads_bytes); + float *turbo_host = malloc((size_t)heads_bytes); + void *sinks_raw = NULL; + TEST_ASSERT(q_host && raw_host && comp_host && ref_host && turbo_host); + TEST_ASSERT(posix_memalign(&sinks_raw, (size_t)getpagesize(), (size_t)sinks_alloc) == 0); + TEST_ASSERT(sinks_raw != NULL); + + if (q && raw && comp_src && comp_deq && comp_turbo && comp_selected_deq && + comp_selected_deq_f16 && topk && identity_topk && heads_ref && + heads_turbo && heads_selected && heads_selected_f16 && + q_host && raw_host && comp_host && ref_host && turbo_host && sinks_raw) { + for (uint32_t h = 0; h < n_head; h++) { + for (uint32_t i = 0; i < head_dim; i++) { + q_host[(uint64_t)h * head_dim + i] = + 0.25f * sinf((float)((h + 1u) * (i + 3u)) * 0.013f); + } + } + for (uint32_t r = 0; r < raw_cap; r++) { + for (uint32_t i = 0; i < head_dim; i++) { + raw_host[(uint64_t)r * head_dim + i] = + 0.20f * cosf((float)((r + 2u) * (i + 5u)) * 0.011f); + } + } + for (uint32_t r = 0; r < n_comp; r++) { + for (uint32_t i = 0; i < head_dim; i++) { + comp_host[(uint64_t)r * head_dim + i] = + 0.18f * sinf((float)((r + 4u) * (i + 7u)) * 0.009f); + } + } + float *sinks = (float *)sinks_raw; + memset(sinks, 0, (size_t)sinks_alloc); + for (uint32_t h = 0; h < n_head; h++) { + sinks[h] = -0.35f + 0.01f * (float)h; + } + const int32_t topk_host[4] = { 7, 3, 1, 5 }; + + TEST_ASSERT(ds4_gpu_tensor_write(q, 0, q_host, q_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_write(raw, 0, raw_host, raw_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_write(comp_src, 0, comp_host, comp_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_write(topk, 0, topk_host, sizeof(topk_host)) != 0); + TEST_ASSERT(ds4_gpu_set_model_map(sinks_raw, sinks_alloc) != 0); + ds4_gpu_set_quality(false); + + if (turbo4) { + 
TEST_ASSERT(ds4_gpu_turbo4_kv_quantize_tensor(comp_turbo, comp_src, + n_comp, 1, head_dim, n_rot) != 0); + TEST_ASSERT(ds4_gpu_turbo4_dequant_f32_tensor(comp_deq, comp_turbo, + n_blocks, n_rot, n_comp) != 0); + TEST_ASSERT(ds4_gpu_turbo4_dequant_selected_f32_tensor(comp_selected_deq, + identity_topk, + comp_turbo, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens) != 0); + TEST_ASSERT(ds4_gpu_turbo4_dequant_selected_f16_tensor(comp_selected_deq_f16, + identity_topk, + comp_turbo, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens) != 0); + } else { + TEST_ASSERT(ds4_gpu_turbo3_kv_quantize_tensor(comp_turbo, comp_src, + n_comp, 1, head_dim, n_rot) != 0); + TEST_ASSERT(ds4_gpu_turbo3_dequant_f32_tensor(comp_deq, comp_turbo, + n_blocks, n_rot, n_comp) != 0); + TEST_ASSERT(ds4_gpu_turbo3_dequant_selected_f32_tensor(comp_selected_deq, + identity_topk, + comp_turbo, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens) != 0); + TEST_ASSERT(ds4_gpu_turbo3_dequant_selected_f16_tensor(comp_selected_deq_f16, + identity_topk, + comp_turbo, + topk, + n_blocks, + n_rot, + n_comp, + top_k, + n_tokens) != 0); + } + + TEST_ASSERT(ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + heads_ref, sinks_raw, sinks_alloc, 0, + q, raw, comp_deq, topk, + n_tokens, 31, n_raw, raw_cap, 0, n_comp, top_k, 128, 4, + n_head, head_dim) != 0); + TEST_ASSERT(ds4_gpu_attention_indexed_mixed_turbo_batch_heads_tensor( + heads_turbo, sinks_raw, sinks_alloc, 0, + q, raw, comp_turbo, topk, + turbo4 ? DS4_KV_QUANT_TURBO4 : DS4_KV_QUANT_TURBO3, + n_blocks, n_rot, + n_tokens, 31, n_raw, raw_cap, 0, n_comp, top_k, 128, 4, + n_head, head_dim) != 0); + TEST_ASSERT(ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + heads_selected, sinks_raw, sinks_alloc, 0, + q, raw, comp_selected_deq, identity_topk, + n_tokens, 31, n_raw, raw_cap, 0, top_k, top_k, 128, 4, + n_head, head_dim) != 0); + TEST_ASSERT(ds4_gpu_attention_indexed_mixed_comp_f16_batch_heads_tensor( + heads_selected_f16, sinks_raw, sinks_alloc, 0, + q, raw, comp_selected_deq_f16, identity_topk, + n_tokens, 31, n_raw, raw_cap, 0, top_k, top_k, 128, 4, + n_head, head_dim) != 0); + TEST_ASSERT(ds4_gpu_tensor_read(heads_ref, 0, ref_host, heads_bytes) != 0); + TEST_ASSERT(ds4_gpu_tensor_read(heads_turbo, 0, turbo_host, heads_bytes) != 0); + + float max_abs = 0.0f; + for (uint64_t i = 0; i < heads_bytes / sizeof(float); i++) { + TEST_ASSERT(isfinite(ref_host[i])); + TEST_ASSERT(isfinite(turbo_host[i])); + const float err = fabsf(ref_host[i] - turbo_host[i]); + if (err > max_abs) max_abs = err; + } + TEST_ASSERT(max_abs < 2.0e-4f); + + TEST_ASSERT(ds4_gpu_tensor_read(heads_selected, 0, turbo_host, heads_bytes) != 0); + max_abs = 0.0f; + for (uint64_t i = 0; i < heads_bytes / sizeof(float); i++) { + TEST_ASSERT(isfinite(turbo_host[i])); + const float err = fabsf(ref_host[i] - turbo_host[i]); + if (err > max_abs) max_abs = err; + } + TEST_ASSERT(max_abs < 2.0e-4f); + + TEST_ASSERT(ds4_gpu_tensor_read(heads_selected_f16, 0, turbo_host, heads_bytes) != 0); + max_abs = 0.0f; + for (uint64_t i = 0; i < heads_bytes / sizeof(float); i++) { + TEST_ASSERT(isfinite(turbo_host[i])); + const float err = fabsf(ref_host[i] - turbo_host[i]); + if (err > max_abs) max_abs = err; + } + TEST_ASSERT(max_abs < 2.0e-4f); + } + + free(sinks_raw); + free(turbo_host); + free(ref_host); + free(comp_host); + free(raw_host); + free(q_host); + ds4_gpu_tensor_free(heads_turbo); + ds4_gpu_tensor_free(heads_selected); + ds4_gpu_tensor_free(heads_selected_f16); + 
ds4_gpu_tensor_free(heads_ref); + ds4_gpu_tensor_free(identity_topk); + ds4_gpu_tensor_free(topk); + ds4_gpu_tensor_free(comp_selected_deq_f16); + ds4_gpu_tensor_free(comp_selected_deq); + ds4_gpu_tensor_free(comp_turbo); + ds4_gpu_tensor_free(comp_deq); + ds4_gpu_tensor_free(comp_src); + ds4_gpu_tensor_free(raw); + ds4_gpu_tensor_free(q); +} + +static void test_metal_turbo_indexed_attention(void) { + test_metal_turbo_indexed_attention_one(false); + test_metal_turbo_indexed_attention_one(true); +} + +static void test_metal_kernels(void) { + test_metal_f16_matvec_fast_nr0_4(); + test_metal_turbo_quant_roundtrip(); + test_metal_turbo_indexed_attention(); +} + static char *test_read_file(const char *path) { FILE *fp = fopen(path, "rb"); if (!fp) return NULL; @@ -650,7 +958,7 @@ static const ds4_test_entry test_entries[] = { {"--long-context", "long-context", "long-context story fact-recall regression", test_long_story_fact_recall}, {"--tool-call-quality", "tool-call-quality", "model emits valid DSML tool calls", test_tool_call_quality}, {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, - {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_f16_matvec_fast_nr0_4}, + {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_kernels}, #endif {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, }; diff --git a/tests/verify_turbo_server.sh b/tests/verify_turbo_server.sh new file mode 100755 index 00000000..ce799f14 --- /dev/null +++ b/tests/verify_turbo_server.sh @@ -0,0 +1,639 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'USAGE' +Usage: tests/verify_turbo_server.sh [options] + +Runs live ds4-server verification for TurboQuant KV modes: + - /v1/models smoke check + - non-streaming and streaming /v1/chat/completions + - memory-token cache reuse + - memory-text cache reuse across a tokenizer-boundary prompt extension + - cold/continued/evict KV disk checkpoint writes + - restart and disk-text KV checkpoint load + - optional deterministic output comparison across FP8, turbo3, turbo4 + - optional long-prompt no-cache performance benchmark across FP8, turbo3, turbo4 + +Options: + --server FILE Server binary. Default: ./ds4-server + --model FILE GGUF model. Default: ds4flash.gguf + --ctx N Server context. Default: 8192 + --port N Preferred port. Default: 1233 + --host HOST Bind host. Default: 127.0.0.1 + --kv-dir DIR Base KV cache directory. Default: /tmp/ds4-kv-turbo4 + --out-dir DIR Log/trace output directory. Default: /tmp/ds4-turbo-verify.$$ + --primary MODE Cache validation mode: 4, 3, fp8, or 0. Default: 4 + --no-cache Skip primary cache validation phase. + --compare MODES Comma-separated comparison modes. Default: fp8,3,4 + --no-compare Skip comparison phase. + --bench MODES Comma-separated benchmark modes. Default: fp8,3,4 + --no-bench Skip performance benchmark phase. + --bench-repeats N Benchmark prompt paragraph repeats. Default: 240 + --bench-tokens N Benchmark generated tokens. Default: 64 + --long-repeats N Long prompt paragraph repeats. Default: 80 + --max-tokens N Generated tokens for decode checks. Default: 16 + --start-timeout N Seconds to wait for server startup. Default: 600 + --keep-server Leave the last server running on failure. + -h, --help Show this help. 
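+
+For a typical local run the option flags mirror the environment overrides, for
+example (the paths and sizes shown are only the illustrative defaults):
+  tests/verify_turbo_server.sh --model ds4flash.gguf --ctx 8192 --kv-dir /tmp/ds4-kv-turbo4 --primary 4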
+ +Environment overrides use the DS4_VERIFY_* prefix, for example: + DS4_VERIFY_CTX=384000 DS4_VERIFY_KV_DIR=/tmp/ds4-kv-turbo4 tests/verify_turbo_server.sh +USAGE +} + +SERVER="${DS4_VERIFY_SERVER:-./ds4-server}" +MODEL="${DS4_VERIFY_MODEL:-ds4flash.gguf}" +HOST="${DS4_VERIFY_HOST:-127.0.0.1}" +PORT="${DS4_VERIFY_PORT:-1233}" +CTX="${DS4_VERIFY_CTX:-8192}" +KV_BASE="${DS4_VERIFY_KV_DIR:-/tmp/ds4-kv-turbo4}" +OUT_DIR="${DS4_VERIFY_OUT_DIR:-/tmp/ds4-turbo-verify.$$}" +PRIMARY_MODE="${DS4_VERIFY_PRIMARY_MODE:-4}" +CACHE="${DS4_VERIFY_CACHE:-1}" +COMPARE_MODES="${DS4_VERIFY_COMPARE_MODES:-fp8,3,4}" +COMPARE="${DS4_VERIFY_COMPARE:-1}" +BENCH_MODES="${DS4_VERIFY_BENCH_MODES:-fp8,3,4}" +BENCH="${DS4_VERIFY_BENCH:-1}" +BENCH_REPEATS="${DS4_VERIFY_BENCH_REPEATS:-240}" +BENCH_TOKENS="${DS4_VERIFY_BENCH_TOKENS:-64}" +LONG_REPEATS="${DS4_VERIFY_LONG_REPEATS:-80}" +MAX_TOKENS="${DS4_VERIFY_MAX_TOKENS:-16}" +START_TIMEOUT="${DS4_VERIFY_START_TIMEOUT:-600}" +KV_SPACE_MB="${DS4_VERIFY_KV_SPACE_MB:-4096}" +KV_MIN_TOKENS="${DS4_VERIFY_KV_MIN_TOKENS:-64}" +KV_COLD_MAX_TOKENS="${DS4_VERIFY_KV_COLD_MAX_TOKENS:-4096}" +KV_CONTINUED_TOKENS="${DS4_VERIFY_KV_CONTINUED_TOKENS:-512}" +KV_TRIM_TOKENS="${DS4_VERIFY_KV_TRIM_TOKENS:-8}" +KV_ALIGN_TOKENS="${DS4_VERIFY_KV_ALIGN_TOKENS:-64}" +KEEP_SERVER=0 + +while [ "$#" -gt 0 ]; do + case "$1" in + --server) SERVER="$2"; shift 2 ;; + --model) MODEL="$2"; shift 2 ;; + --host) HOST="$2"; shift 2 ;; + --port) PORT="$2"; shift 2 ;; + --ctx) CTX="$2"; shift 2 ;; + --kv-dir) KV_BASE="$2"; shift 2 ;; + --out-dir) OUT_DIR="$2"; shift 2 ;; + --primary) PRIMARY_MODE="$2"; shift 2 ;; + --no-cache) CACHE=0; shift ;; + --compare) COMPARE_MODES="$2"; COMPARE=1; shift 2 ;; + --no-compare) COMPARE=0; shift ;; + --bench) BENCH_MODES="$2"; BENCH=1; shift 2 ;; + --no-bench) BENCH=0; shift ;; + --bench-repeats) BENCH_REPEATS="$2"; shift 2 ;; + --bench-tokens) BENCH_TOKENS="$2"; shift 2 ;; + --long-repeats) LONG_REPEATS="$2"; shift 2 ;; + --max-tokens) MAX_TOKENS="$2"; shift 2 ;; + --start-timeout) START_TIMEOUT="$2"; shift 2 ;; + --keep-server) KEEP_SERVER=1; shift ;; + -h|--help) usage; exit 0 ;; + *) echo "verify_turbo_server: unknown option: $1" >&2; usage >&2; exit 2 ;; + esac +done + +require_file() { + if [ ! 
-e "$1" ]; then + echo "verify_turbo_server: missing $2: $1" >&2 + exit 2 + fi +} + +require_file "$SERVER" "server binary" +require_file "$MODEL" "model" +command -v curl >/dev/null 2>&1 || { echo "verify_turbo_server: curl is required" >&2; exit 2; } +command -v python3 >/dev/null 2>&1 || { echo "verify_turbo_server: python3 is required" >&2; exit 2; } + +mkdir -p "$OUT_DIR" + +SERVER_PID="" +SERVER_LOG="" +SERVER_TRACE="" +SERVER_PORT="" + +cleanup() { + if [ -n "${SERVER_PID:-}" ] && kill -0 "$SERVER_PID" >/dev/null 2>&1; then + if [ "$KEEP_SERVER" = 1 ]; then + echo "verify_turbo_server: leaving server pid $SERVER_PID running at http://$HOST:$SERVER_PORT" >&2 + return + fi + kill -TERM "$SERVER_PID" >/dev/null 2>&1 || true + wait "$SERVER_PID" >/dev/null 2>&1 || true + fi +} +trap cleanup EXIT + +choose_port() { + python3 - "$HOST" "$PORT" <<'PY' +import socket +import sys + +host = sys.argv[1] +preferred = int(sys.argv[2]) + +def can_bind(port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return True + except OSError: + return False + finally: + s.close() + +if can_bind(preferred): + print(preferred) +else: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind((host, 0)) + print(s.getsockname()[1]) + s.close() +PY +} + +mode_name() { + case "$1" in + ""|0|fp8|FP8) echo "fp8" ;; + 3|turbo3) echo "turbo3" ;; + 4|turbo4) echo "turbo4" ;; + *) echo "$1" ;; + esac +} + +mode_env_prefix() { + case "$1" in + ""|0|fp8|FP8) echo "env -u DS4_KV_TURBO" ;; + 3|turbo3) echo "env DS4_KV_TURBO=3" ;; + 4|turbo4) echo "env DS4_KV_TURBO=4" ;; + *) echo "env DS4_KV_TURBO=$1" ;; + esac +} + +start_server() { + local mode="$1" + local phase="$2" + local name + name="$(mode_name "$mode")" + SERVER_PORT="$(choose_port)" + if [ "$SERVER_PORT" != "$PORT" ]; then + echo "verify_turbo_server: preferred port $PORT is busy; using $SERVER_PORT" >&2 + fi + + local kv_dir="$KV_BASE.$name" + mkdir -p "$kv_dir" + SERVER_LOG="$OUT_DIR/$name.$phase.server.log" + SERVER_TRACE="$OUT_DIR/$name.$phase.trace" + echo "verify_turbo_server: starting $name server ($phase) on http://$HOST:$SERVER_PORT" >&2 + + local env_prefix + env_prefix="$(mode_env_prefix "$mode")" + # shellcheck disable=SC2086 + $env_prefix "$SERVER" \ + --model "$MODEL" \ + --ctx "$CTX" \ + --tokens "$MAX_TOKENS" \ + --host "$HOST" \ + --port "$SERVER_PORT" \ + --trace "$SERVER_TRACE" \ + --kv-disk-dir "$kv_dir" \ + --kv-disk-space-mb "$KV_SPACE_MB" \ + --kv-cache-min-tokens "$KV_MIN_TOKENS" \ + --kv-cache-cold-max-tokens "$KV_COLD_MAX_TOKENS" \ + --kv-cache-continued-interval-tokens "$KV_CONTINUED_TOKENS" \ + --kv-cache-boundary-trim-tokens "$KV_TRIM_TOKENS" \ + --kv-cache-boundary-align-tokens "$KV_ALIGN_TOKENS" \ + >"$SERVER_LOG" 2>&1 & + SERVER_PID="$!" 
+ wait_ready "$SERVER_PID" "http://$HOST:$SERVER_PORT" +} + +start_server_bench() { + local mode="$1" + local phase="$2" + local name + name="$(mode_name "$mode")" + SERVER_PORT="$(choose_port)" + if [ "$SERVER_PORT" != "$PORT" ]; then + echo "verify_turbo_server: preferred port $PORT is busy; using $SERVER_PORT" >&2 + fi + + SERVER_LOG="$OUT_DIR/$name.$phase.server.log" + SERVER_TRACE="$OUT_DIR/$name.$phase.trace" + echo "verify_turbo_server: starting $name server ($phase, no kv disk cache) on http://$HOST:$SERVER_PORT" >&2 + + local env_prefix + env_prefix="$(mode_env_prefix "$mode")" + # shellcheck disable=SC2086 + $env_prefix "$SERVER" \ + --model "$MODEL" \ + --ctx "$CTX" \ + --tokens "$BENCH_TOKENS" \ + --host "$HOST" \ + --port "$SERVER_PORT" \ + --trace "$SERVER_TRACE" \ + >"$SERVER_LOG" 2>&1 & + SERVER_PID="$!" + wait_ready "$SERVER_PID" "http://$HOST:$SERVER_PORT" +} + +stop_server() { + if [ -z "${SERVER_PID:-}" ]; then + return + fi + if kill -0 "$SERVER_PID" >/dev/null 2>&1; then + kill -TERM "$SERVER_PID" >/dev/null 2>&1 || true + wait "$SERVER_PID" >/dev/null 2>&1 || true + fi + SERVER_PID="" +} + +wait_ready() { + local pid="$1" + local base_url="$2" + local deadline=$((SECONDS + START_TIMEOUT)) + while [ "$SECONDS" -lt "$deadline" ]; do + if ! kill -0 "$pid" >/dev/null 2>&1; then + echo "verify_turbo_server: server exited during startup; log: $SERVER_LOG" >&2 + tail -80 "$SERVER_LOG" >&2 || true + exit 1 + fi + if curl -fsS --max-time 2 "$base_url/v1/models" >"$OUT_DIR/.ready.json" 2>/dev/null; then + return + fi + sleep 1 + done + echo "verify_turbo_server: timed out waiting for server; log: $SERVER_LOG" >&2 + tail -80 "$SERVER_LOG" >&2 || true + exit 1 +} + +client_json() { + local base_url="$1" + local action="$2" + local out="$3" + python3 - "$base_url" "$action" "$out" "$LONG_REPEATS" "$MAX_TOKENS" "$BENCH_REPEATS" "$BENCH_TOKENS" <<'PY' +import json +import sys +import urllib.error +import urllib.request + +base_url, action, out_path = sys.argv[1], sys.argv[2], sys.argv[3] +long_repeats, max_tokens = int(sys.argv[4]), int(sys.argv[5]) +bench_repeats, bench_tokens = int(sys.argv[6]), int(sys.argv[7]) + +def request_json(path, payload=None, stream=False): + data = None if payload is None else json.dumps(payload).encode("utf-8") + req = urllib.request.Request(base_url + path, data=data) + if payload is not None: + req.add_header("Content-Type", "application/json") + try: + with urllib.request.urlopen(req, timeout=1800) as resp: + if stream: + chunks = [] + saw_done = False + for raw in resp: + line = raw.decode("utf-8", "replace").rstrip("\n") + chunks.append(line) + if line.strip() == "data: [DONE]": + saw_done = True + text = "\n".join(chunks) + "\n" + if not saw_done: + raise RuntimeError("stream ended without [DONE]") + return text + body = resp.read().decode("utf-8") + return body + except urllib.error.HTTPError as e: + body = e.read().decode("utf-8", "replace") + raise RuntimeError(f"HTTP {e.code}: {body}") from e + +def chat_payload(content, tokens=max_tokens, stream=False): + return { + "model": "deepseek-v4-flash", + "messages": [{"role": "user", "content": content}], + "think": False, + "temperature": 0, + "top_p": 1, + "seed": 20260513, + "max_tokens": tokens, + "stream": stream, + "stream_options": {"include_usage": True}, + } + +def long_prompt(): + para = ( + "Cache verification record {i}: the prefix must remain byte-stable, " + "the model should continue after a restored KV checkpoint, and every " + "line includes deterministic filler for token 
volume.\n" + ) + return "".join(para.format(i=i) for i in range(long_repeats)) + +def bench_prompt(): + para = ( + "Performance benchmark record {i}: this stable long-context prefix " + "forces the attention path to read compressed KV rows during decode, " + "so FP8 and TurboQuant cache formats can be compared directly.\n" + ) + return "".join(para.format(i=i) for i in range(bench_repeats)) + +if action == "models": + body = request_json("/v1/models") + obj = json.loads(body) + assert obj.get("object") == "list" +elif action == "short": + body = request_json("/v1/chat/completions", chat_payload("Define a KV cache hit in one short sentence.")) + obj = json.loads(body) + choice = obj["choices"][0] + assert choice["message"].get("content", "") is not None +elif action == "stream": + body = request_json("/v1/chat/completions", chat_payload("Reply with exactly three words about local inference.", stream=True), stream=True) + assert "data: " in body +elif action == "memory0": + body = request_json("/v1/chat/completions", chat_payload("Keep this prompt resident for a zero-token cache probe.", tokens=0)) + obj = json.loads(body) + assert obj["choices"][0]["finish_reason"] in ("length", "stop") +elif action == "memory_text": + prompt = "Write one common English word related to caching." + first = request_json("/v1/chat/completions", chat_payload(prompt, tokens=1)) + first_obj = json.loads(first) + piece = first_obj["choices"][0]["message"].get("content", "") + if not piece: + raise RuntimeError("memory-text probe generated no assistant bytes") + payload = { + "model": "deepseek-v4-flash", + "messages": [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": piece + "s"}, + ], + "think": False, + "temperature": 0, + "top_p": 1, + "seed": 20260513, + "max_tokens": 0, + "stream": False, + } + body = request_json("/v1/chat/completions", payload) + obj = json.loads(body) + assert obj["choices"][0]["finish_reason"] in ("length", "stop") +elif action == "long": + body = request_json("/v1/chat/completions", chat_payload(long_prompt())) + obj = json.loads(body) + msg = obj["choices"][0]["message"] + assert "content" in msg +elif action == "evict": + body = request_json("/v1/chat/completions", chat_payload("Unrelated short prompt used to evict the previous live KV session.", tokens=0)) + obj = json.loads(body) + assert obj["choices"] +elif action == "compare": + body = request_json("/v1/chat/completions", chat_payload("In one sentence, describe why disk KV cache reuse helps long chats.", tokens=24)) + obj = json.loads(body) + content = obj["choices"][0]["message"].get("content", "") + if not content.strip(): + raise RuntimeError("empty comparison content") + words = content.split() + run = 1 + for prev, cur in zip(words, words[1:]): + run = run + 1 if cur == prev else 1 + if run >= 8: + raise RuntimeError("comparison content appears repetitively collapsed") +elif action == "bench": + body = request_json("/v1/chat/completions", chat_payload(bench_prompt(), tokens=bench_tokens)) + obj = json.loads(body) + content = obj["choices"][0]["message"].get("content", "") + if not content.strip(): + raise RuntimeError("empty benchmark content") +else: + raise RuntimeError(f"unknown action: {action}") + +with open(out_path, "w", encoding="utf-8") as f: + f.write(body) +PY +} + +require_trace_pattern() { + local pattern="$1" + local file="$2" + local label="$3" + if ! 
grep -q "$pattern" "$file"; then + echo "verify_turbo_server: missing $label in $file" >&2 + tail -120 "$file" >&2 || true + exit 1 + fi +} + +require_log_pattern() { + local pattern="$1" + local file="$2" + local label="$3" + if ! grep -q "$pattern" "$file"; then + echo "verify_turbo_server: missing $label in $file" >&2 + tail -120 "$file" >&2 || true + exit 1 + fi +} + +count_pattern() { + local pattern="$1" + local file="$2" + if [ -f "$file" ]; then + grep -c "$pattern" "$file" || true + else + echo 0 + fi +} + +summarize_phase() { + local mode="$1" + local phase="$2" + local log="$OUT_DIR/$(mode_name "$mode").$phase.server.log" + local trace="$OUT_DIR/$(mode_name "$mode").$phase.trace" + { + echo "mode=$(mode_name "$mode") phase=$phase" + echo "log=$log" + echo "trace=$trace" + echo "memory-token=$(count_pattern 'cache_source: memory-token' "$trace")" + echo "memory-text=$(count_pattern 'cache_source: memory-text' "$trace")" + echo "disk-text=$(count_pattern 'cache_source: disk-text' "$trace")" + echo "cold-store=$(count_pattern 'reason=cold' "$log")" + echo "continued-store=$(count_pattern 'reason=continued' "$log")" + echo "evict-store=$(count_pattern 'reason=evict' "$log")" + grep -E 'context buffers|prefill chunk|prompt done|decoding chunk|kv cache (stored|hit|load failed|evicted)' "$log" || true + } >"$OUT_DIR/$(mode_name "$mode").$phase.summary" +} + +run_primary_cache_validation() { + local mode="$1" + local name + name="$(mode_name "$mode")" + local base_url + rm -rf "$KV_BASE.$name" + mkdir -p "$KV_BASE.$name" + + start_server "$mode" "cache1" + base_url="http://$HOST:$SERVER_PORT" + client_json "$base_url" models "$OUT_DIR/$name.models.json" + client_json "$base_url" short "$OUT_DIR/$name.short.json" + client_json "$base_url" stream "$OUT_DIR/$name.stream.sse" + client_json "$base_url" memory0 "$OUT_DIR/$name.memory0.a.json" + client_json "$base_url" memory0 "$OUT_DIR/$name.memory0.b.json" + client_json "$base_url" memory_text "$OUT_DIR/$name.memory-text.json" + client_json "$base_url" long "$OUT_DIR/$name.long.first.json" + client_json "$base_url" evict "$OUT_DIR/$name.evict.json" + stop_server + + summarize_phase "$mode" "cache1" + require_trace_pattern 'cache_source: memory-token' "$OUT_DIR/$name.cache1.trace" "memory-token cache source" + require_trace_pattern 'cache_source: memory-text' "$OUT_DIR/$name.cache1.trace" "memory-text cache source" + require_log_pattern 'reason=cold' "$OUT_DIR/$name.cache1.server.log" "cold KV store" + require_log_pattern 'reason=continued' "$OUT_DIR/$name.cache1.server.log" "continued KV store" + require_log_pattern 'reason=evict' "$OUT_DIR/$name.cache1.server.log" "evict KV store" + + local kv_count + kv_count="$(find "$KV_BASE.$name" -name '*.kv' -type f | wc -l | tr -d ' ')" + if [ "$kv_count" = "0" ]; then + echo "verify_turbo_server: no KV checkpoint files were created in $KV_BASE.$name" >&2 + exit 1 + fi + + start_server "$mode" "cache2" + base_url="http://$HOST:$SERVER_PORT" + client_json "$base_url" long "$OUT_DIR/$name.long.second.json" + stop_server + + summarize_phase "$mode" "cache2" + require_trace_pattern 'cache_source: disk-text' "$OUT_DIR/$name.cache2.trace" "disk-text cache source after restart" + require_log_pattern 'kv cache hit text' "$OUT_DIR/$name.cache2.server.log" "KV disk cache hit" +} + +run_compare_mode() { + local mode="$1" + local name + name="$(mode_name "$mode")" + start_server "$mode" "compare" + client_json "http://$HOST:$SERVER_PORT" compare "$OUT_DIR/$name.compare.json" + stop_server + summarize_phase 
"$mode" "compare" +} + +run_comparison() { + local modes_csv="$1" + local modes + modes="${modes_csv//,/ }" + for mode in $modes; do + run_compare_mode "$mode" + done + python3 - "$OUT_DIR" $modes <<'PY' +import json +import os +import sys + +out_dir = sys.argv[1] +modes = sys.argv[2:] +for mode in modes: + name = {"0": "fp8", "fp8": "fp8", "FP8": "fp8", "3": "turbo3", "turbo3": "turbo3", "4": "turbo4", "turbo4": "turbo4"}.get(mode, mode) + path = os.path.join(out_dir, f"{name}.compare.json") + with open(path, encoding="utf-8") as f: + obj = json.load(f) + content = obj["choices"][0]["message"].get("content", "").strip().replace("\n", "\\n") + print(f"{name}: {content}") +PY +} + +run_bench_mode() { + local mode="$1" + local name + name="$(mode_name "$mode")" + start_server_bench "$mode" "bench" + client_json "http://$HOST:$SERVER_PORT" bench "$OUT_DIR/$name.bench.json" + stop_server + summarize_phase "$mode" "bench" +} + +run_benchmark() { + local modes_csv="$1" + local modes + modes="${modes_csv//,/ }" + for mode in $modes; do + run_bench_mode "$mode" + done + python3 - "$OUT_DIR" $modes <<'PY' +import os +import re +import sys + +out_dir = sys.argv[1] +modes = sys.argv[2:] + +def mode_name(mode): + return {"0": "fp8", "fp8": "fp8", "FP8": "fp8", "3": "turbo3", "turbo3": "turbo3", "4": "turbo4", "turbo4": "turbo4"}.get(mode, mode) + +print("mode\tcontext_mib\tprompt_tokens\tprefill_sec\tprefill_tps\tdecode_tokens\tdecode_tps") +for mode in modes: + name = mode_name(mode) + path = os.path.join(out_dir, f"{name}.bench.server.log") + text = open(path, encoding="utf-8", errors="replace").read() + + mem = "" + m = re.search(r"context buffers ([0-9.]+) MiB", text) + if m: + mem = m.group(1) + + prompt_tokens = "" + prefill_sec = "" + prefill_tps = "" + prompt_matches = list(re.finditer(r"ctx=([0-9]+)\.\.([0-9]+):([0-9]+) prompt done ([0-9.]+)s", text)) + if prompt_matches: + m = prompt_matches[-1] + prompt_tokens = m.group(3) + prefill_sec = m.group(4) + try: + prefill_tps = f"{float(prompt_tokens) / float(prefill_sec):.2f}" + except ZeroDivisionError: + prefill_tps = "inf" + + decode_tokens = "" + decode_tps = "" + decode_matches = list(re.finditer(r"gen=([0-9]+) decoding chunk=[0-9.]+ t/s avg=([0-9.]+) t/s", text)) + if decode_matches: + m = decode_matches[-1] + decode_tokens = m.group(1) + decode_tps = m.group(2) + + print(f"{name}\t{mem}\t{prompt_tokens}\t{prefill_sec}\t{prefill_tps}\t{decode_tokens}\t{decode_tps}") +PY +} + +echo "verify_turbo_server: logs: $OUT_DIR" >&2 + +if [ "$CACHE" = 1 ]; then + run_primary_cache_validation "$PRIMARY_MODE" +fi + +if [ "$COMPARE" = 1 ]; then + run_comparison "$COMPARE_MODES" >"$OUT_DIR/compare.outputs" +fi + +if [ "$BENCH" = 1 ]; then + run_benchmark "$BENCH_MODES" >"$OUT_DIR/bench.tsv" +fi + +{ + echo "verify_turbo_server: OK" + echo "output_dir=$OUT_DIR" + echo "kv_base=$KV_BASE" + echo "primary=$(mode_name "$PRIMARY_MODE")" + if [ "$CACHE" = 1 ]; then + echo "cache_summaries:" + ls "$OUT_DIR"/*.summary 2>/dev/null || true + fi + if [ -f "$OUT_DIR/compare.outputs" ]; then + echo "compare_outputs:" + cat "$OUT_DIR/compare.outputs" + fi + if [ -f "$OUT_DIR/bench.tsv" ]; then + echo "benchmark:" + cat "$OUT_DIR/bench.tsv" + fi +} | tee "$OUT_DIR/summary.txt"