28 changes: 28 additions & 0 deletions README.md
@@ -680,6 +680,34 @@ support the CPU backend for reference/debug use and share the same KV session
and snapshot format as Metal and CUDA, but normal inference should use Metal or
CUDA.

### Metal TurboQuant KV cache

The Metal graph can store compressed attention KV rows in the TurboQuant
PolarQuant/WHT (Walsh-Hadamard transform) formats instead of the default FP8
rows:

```sh
DS4_KV_TURBO=4 ./ds4-server --ctx 384000 ...
DS4_KV_TURBO=3 ./ds4-server --ctx 384000 ...
```

`DS4_KV_TURBO=4` is the preferred quality/speed mode. `DS4_KV_TURBO=3`
uses a smaller 3-bit cache and is mainly useful when memory pressure is the
primary constraint. Unset `DS4_KV_TURBO`, or set it to `0`, to keep the FP8
path.
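
As a rough illustration, the documented mapping from `DS4_KV_TURBO` to the
`kv_quant_type` constants declared in `ds4_gpu.h` could look like the sketch
below. The helper name and the parsing details are hypothetical, not the
server's actual code:

```c
#include <stdlib.h>
#include "ds4_gpu.h"

/* Hypothetical helper: map DS4_KV_TURBO to a kv_quant_type constant. */
static int kv_quant_from_env(void) {
    const char *v = getenv("DS4_KV_TURBO");
    if (!v) return DS4_KV_QUANT_FP8;      /* unset: keep the FP8 path */
    switch (atoi(v)) {
    case 4:  return DS4_KV_QUANT_TURBO4;  /* preferred quality/speed mode */
    case 3:  return DS4_KV_QUANT_TURBO3;  /* smaller 3-bit cache */
    default: return DS4_KV_QUANT_FP8;     /* 0 or anything else */
    }
}
```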

For ratio-4 indexed decode, the default compressed-row top-k remains 512.
Runs that favor speed over quality can lower it with
`DS4_METAL_DECODE_INDEXER_TOP_K`; values are capped at 512 and rounded down to
a power of two:

```sh
DS4_KV_TURBO=4 DS4_METAL_DECODE_INDEXER_TOP_K=128 ./ds4-server --ctx 384000 ...
```
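
The cap-and-round rule amounts to the following; this is a minimal sketch of
the documented behavior, and the function name is illustrative:

```c
#include <stdint.h>

/* Illustrative clamp: cap at 512, round down to a power of two. */
static uint32_t clamp_indexer_top_k(uint32_t v) {
    if (v == 0 || v >= 512) return 512;  /* unset/out-of-range: default */
    uint32_t p = 1;
    while (p * 2 <= v) p *= 2;           /* largest power of two <= v */
    return p;                            /* e.g. 200 -> 128 */
}
```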

`--quality` keeps the 512-row path. The diagnostic switches
`DS4_METAL_DISABLE_TURBO_DIRECT_ATTN=1` and
`DS4_METAL_DISABLE_TURBO_SELECTED_F16=1` restore the older materialized
attention paths for comparisons.

## Steering

This project supports steering with single-vector activation directions; see the
584 changes: 469 additions & 115 deletions ds4.c

Large diffs are not rendered by default.

143 changes: 143 additions & 0 deletions ds4_gpu.h
@@ -248,6 +248,104 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor(
uint32_t head_dim,
uint32_t n_rot);

/* =========================================================================
* TurboQuant KV Cache Compression.
* =========================================================================
*
* TurboQuant (arXiv 2504.19874) compresses KV cache rows with PolarQuant +
* Walsh-Hadamard rotation. 3-bit (turbo3) and 4-bit (turbo4) formats are
* supported. The Metal integration stores compressed rows in per-layer turbo
* caches and dequantizes them back to float32 scratch before passing them to
* the existing attention kernels.
*/
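
For intuition, the Walsh-Hadamard rotation step can be computed with the
standard in-place fast transform. The sketch below illustrates the technique
on one float32 row; it is not the project's Metal kernel:

```c
#include <math.h>
#include <stdint.h>

/* In-place fast Walsh-Hadamard transform over one row of length n
 * (n must be a power of two). The 1/sqrt(n) scaling keeps the
 * transform orthonormal, so row norms survive the rotation. */
static void fwht_f32(float *x, uint32_t n) {
    for (uint32_t h = 1; h < n; h <<= 1) {
        for (uint32_t i = 0; i < n; i += h << 1) {
            for (uint32_t j = i; j < i + h; j++) {
                float a = x[j], b = x[j + h];
                x[j]     = a + b;
                x[j + h] = a - b;
            }
        }
    }
    float s = 1.0f / sqrtf((float)n);
    for (uint32_t i = 0; i < n; i++) x[i] *= s;
}
```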

/* TurboQuant types for kv_quant_type configuration. */
#define DS4_KV_QUANT_FP8 0
#define DS4_KV_QUANT_TURBO3 1
#define DS4_KV_QUANT_TURBO4 2

/* Block sizes (bytes) for buffer allocation. */
#define DS4_TURBO3_BLOCK_BYTES 14 /* sizeof(block_turbo3_0): qs[8] + signs[4] + norm(2) */
#define DS4_TURBO4_BLOCK_BYTES 18 /* sizeof(block_turbo4_0): qs[16] + norm(2) */
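
Reading those sizes back, plausible layouts are sketched below. The type and
field names come from the comments above; the 32-elements-per-block figure is
an inference from the byte counts, not confirmed by this header:

```c
#include <stdint.h>

/* Hypothetical layouts matching the documented byte counts,
 * assuming 32 elements per block. */
typedef struct {
    uint8_t  qs[8];    /* 2-bit magnitudes for 32 elements */
    uint8_t  signs[4]; /* one sign bit per element */
    uint16_t norm;     /* fp16 block norm */
} block_turbo3_0;      /* 14 bytes */

typedef struct {
    uint8_t  qs[16];   /* 4-bit codes for 32 elements */
    uint16_t norm;     /* fp16 block norm */
} block_turbo4_0;      /* 18 bytes */

_Static_assert(sizeof(block_turbo3_0) == 14, "turbo3 block size");
_Static_assert(sizeof(block_turbo4_0) == 18, "turbo4 block size");
```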

/* Quantize a float32 KV tensor to TurboQuant blocks. The input tensor is laid
* out as [n_tok, n_head, head_dim]. The output receives packed blocks for the
* non-RoPE prefix, followed by the RoPE tail as raw float32. */
int ds4_gpu_turbo3_kv_quantize_tensor(
ds4_gpu_tensor *out,
const ds4_gpu_tensor *x,
uint32_t n_tok,
uint32_t n_head,
uint32_t head_dim,
uint32_t n_rot);

int ds4_gpu_turbo4_kv_quantize_tensor(
ds4_gpu_tensor *out,
const ds4_gpu_tensor *x,
uint32_t n_tok,
uint32_t n_head,
uint32_t head_dim,
uint32_t n_rot);
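
Under the same 32-element assumption, the quantized layout described above
implies the following sizing; this helper is hypothetical and the real
allocator may differ:

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical: bytes for one turbo4-quantized KV tensor, assuming
 * 32 elements per block and (head_dim - n_rot) divisible by 32. */
static size_t turbo4_kv_bytes(uint32_t n_tok, uint32_t n_head,
                              uint32_t head_dim, uint32_t n_rot) {
    uint32_t n_blocks = (head_dim - n_rot) / 32;      /* non-RoPE prefix */
    size_t per_head = (size_t)n_blocks * DS4_TURBO4_BLOCK_BYTES
                    + (size_t)n_rot * sizeof(float);  /* raw fp32 tail */
    return (size_t)n_tok * n_head * per_head;
}
```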

/* Dequantize TurboQuant KV blocks to float32 for flash attention. */
int ds4_gpu_turbo3_dequant_f32_tensor(
ds4_gpu_tensor *out,
const ds4_gpu_tensor *x,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_rows);

int ds4_gpu_turbo4_dequant_f32_tensor(
ds4_gpu_tensor *out,
const ds4_gpu_tensor *x,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_rows);
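
A typical call shape for the full-cache dequantizers might look like the
snippet below; the 0-on-success return convention and the surrounding
variables (`kv_turbo`, `n_blocks`, `n_rot`, `n_rows`) are assumptions, as this
header does not spell them out:

```c
/* Sketch: expand a turbo4 cache back to float32 scratch before the
 * existing attention kernels consume it. Error handling elided. */
ds4_gpu_tensor kv_f32;
if (ds4_gpu_turbo4_dequant_f32_tensor(&kv_f32, &kv_turbo,
                                      n_blocks, n_rot, n_rows) != 0) {
    /* fall back to the FP8 path, or fail the request */
}
```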

int ds4_gpu_turbo3_dequant_selected_f32_tensor(
ds4_gpu_tensor *out,
ds4_gpu_tensor *identity_topk,
const ds4_gpu_tensor *x,
const ds4_gpu_tensor *topk,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_comp,
uint32_t top_k,
uint32_t n_tokens);

int ds4_gpu_turbo4_dequant_selected_f32_tensor(
ds4_gpu_tensor *out,
ds4_gpu_tensor *identity_topk,
const ds4_gpu_tensor *x,
const ds4_gpu_tensor *topk,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_comp,
uint32_t top_k,
uint32_t n_tokens);

int ds4_gpu_turbo3_dequant_selected_f16_tensor(
ds4_gpu_tensor *out,
ds4_gpu_tensor *identity_topk,
const ds4_gpu_tensor *x,
const ds4_gpu_tensor *topk,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_comp,
uint32_t top_k,
uint32_t n_tokens);

int ds4_gpu_turbo4_dequant_selected_f16_tensor(
ds4_gpu_tensor *out,
ds4_gpu_tensor *identity_topk,
const ds4_gpu_tensor *x,
const ds4_gpu_tensor *topk,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_comp,
uint32_t top_k,
uint32_t n_tokens);

int ds4_gpu_rope_tail_tensor(
ds4_gpu_tensor *x,
uint32_t n_tok,
@@ -494,6 +592,51 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor(
uint32_t n_head,
uint32_t head_dim);

int ds4_gpu_attention_indexed_mixed_comp_f16_batch_heads_tensor(
ds4_gpu_tensor *heads,
const void *model_map,
uint64_t model_size,
uint64_t sinks_offset,
const ds4_gpu_tensor *q,
const ds4_gpu_tensor *raw_kv,
const ds4_gpu_tensor *comp_kv,
const ds4_gpu_tensor *topk,
uint32_t n_tokens,
uint32_t pos0,
uint32_t n_raw,
uint32_t raw_cap,
uint32_t raw_start,
uint32_t n_comp,
uint32_t top_k,
uint32_t window,
uint32_t ratio,
uint32_t n_head,
uint32_t head_dim);

int ds4_gpu_attention_indexed_mixed_turbo_batch_heads_tensor(
ds4_gpu_tensor *heads,
const void *model_map,
uint64_t model_size,
uint64_t sinks_offset,
const ds4_gpu_tensor *q,
const ds4_gpu_tensor *raw_kv,
const ds4_gpu_tensor *comp_turbo_kv,
const ds4_gpu_tensor *topk,
uint32_t kv_quant_type,
uint32_t n_blocks,
uint32_t n_rot,
uint32_t n_tokens,
uint32_t pos0,
uint32_t n_raw,
uint32_t raw_cap,
uint32_t raw_start,
uint32_t n_comp,
uint32_t top_k,
uint32_t window,
uint32_t ratio,
uint32_t n_head,
uint32_t head_dim);

int ds4_gpu_attention_prefill_static_mixed_heads_tensor(
ds4_gpu_tensor *heads,
const void *model_map,