Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions ds4.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
#include <sys/file.h>
#include <sys/mman.h>
#include <sys/stat.h>
#ifdef __APPLE__
#include <sys/sysctl.h>
#endif
#include <stdarg.h>
#include <time.h>
#include <unistd.h>
Expand Down Expand Up @@ -6098,7 +6101,35 @@ static uint32_t ds4_default_raw_cap(uint32_t ctx_size) {
return raw_cap;
}

static uint32_t ds4_default_prefill_cap_for_prompt(int prompt_len) {
/* Report whether the host CPU identifies itself as an "Apple M5 Max".
 * The sysctl brand-string probe runs at most once; the answer is cached
 * in function-local statics for subsequent calls. Non-Apple builds
 * compile down to a constant false. */
static bool ds4_host_is_apple_m5_max(void) {
#if defined(__APPLE__)
    static int probed;
    static int matched;
    if (!probed) {
        char brand[128];
        size_t brand_len = sizeof(brand);
        const int rc = sysctlbyname("machdep.cpu.brand_string", brand, &brand_len, NULL, 0);
        if (rc == 0) {
            /* Defensive: guarantee NUL termination before scanning. */
            brand[sizeof(brand) - 1] = '\0';
            if (strstr(brand, "Apple M5 Max") != NULL) matched = 1;
        }
        probed = 1;
    }
    return matched != 0;
#else
    return false;
#endif
}

/* Default prefill ubatch (chunk) size for a backend, in tokens. Metal on
 * an Apple M5 Max host gets the larger 4096-token schedule; every other
 * backend/host combination keeps the 2048-token default. */
uint32_t ds4_backend_default_prefill_chunk(ds4_backend backend) {
    const bool m5_max_metal =
        backend == DS4_BACKEND_METAL && ds4_host_is_apple_m5_max();
    return m5_max_metal ? 4096u : 2048u;
}

/* KV checkpoint boundary alignment mirrors the backend's prefill chunk
 * size, so cold boundary saves land exactly on chunk edges. */
uint32_t ds4_backend_default_kv_boundary_align_tokens(ds4_backend backend) {
    const uint32_t chunk_tokens = ds4_backend_default_prefill_chunk(backend);
    return chunk_tokens;
}

static uint32_t ds4_default_prefill_cap_for_prompt(int prompt_len, ds4_backend backend) {
if (prompt_len <= 0) return 1;
uint32_t cap = (uint32_t)prompt_len;

Expand All @@ -6109,9 +6140,23 @@ static uint32_t ds4_default_prefill_cap_for_prompt(int prompt_len) {
if (endp != env) {
if (v <= 0) return cap;
cap = (uint32_t)v;
if (ds4_backend_default_prefill_chunk(backend) >= 4096u &&
cap > 4096u && getenv("DS4_METAL_ALLOW_UNSAFE_PREFILL_CHUNK") == NULL)
{
static int warned_large_prefill_chunk;
if (!warned_large_prefill_chunk) {
fprintf(stderr,
"ds4: DS4_METAL_PREFILL_CHUNK=%u exceeds the correctness-gated 4096-token limit; "
"clamping to 4096 (set DS4_METAL_ALLOW_UNSAFE_PREFILL_CHUNK=1 to experiment)\n",
cap);
warned_large_prefill_chunk = 1;
}
cap = 4096u;
}
}
} else if (prompt_len > 2048) {
cap = 2048u;
} else {
const uint32_t default_chunk = ds4_backend_default_prefill_chunk(backend);
if (prompt_len > (int)default_chunk) cap = default_chunk;
}

if (cap == 0) cap = 1;
Expand Down Expand Up @@ -11041,6 +11086,11 @@ static bool metal_graph_q_stage_profile_boundary(
return ds4_gpu_begin_commands() != 0;
}

static bool metal_graph_use_m5_large_prefill_schedule(const ds4_gpu_graph *g) {
return g && g->prefill_cap >= 4096u &&
ds4_backend_default_prefill_chunk(DS4_BACKEND_METAL) >= 4096u;
}

static bool metal_graph_encode_layer_attention_batch(
ds4_gpu_graph *g,
const ds4_model *model,
Expand Down Expand Up @@ -12057,7 +12107,16 @@ static bool metal_graph_encode_layer_attention_batch(
}

const bool topk_prefill_needed = ratio == 4 && n_comp > DS4_N_INDEXER_TOP_K;
if (ok && zero_prefix && topk_prefill_needed && n_comp != 0) {
/* The all-at-once zero-prefix indexed attention path selects one top-k
* set per token from the complete compressed prefix. For very large
* first chunks, high-scoring future compressed rows can crowd out older
* visible rows before the attention kernel applies its causal `visible`
* cutoff, which breaks long-memory prompts. Keep the fast batched path
* for the correctness-gated 2048-token chunk size; larger experimental
* chunks fall through to the per-token indexed path below. */
if (ok && zero_prefix && topk_prefill_needed && n_comp != 0 &&
(!metal_graph_use_m5_large_prefill_schedule(g) || n_tokens <= 2048u))
{
const float index_scale = 1.0f / sqrtf((float)(DS4_N_INDEXER_HEAD_DIM * DS4_N_INDEXER_HEAD));
double index_stage_t0 = 0.0;
if (index_stage_profile) {
Expand Down Expand Up @@ -13345,6 +13404,19 @@ static bool metal_graph_prefill_layer_major(
return ok;
}

/* Forward declaration (defined later in this file). Prefills the prompt
 * range [start, start + n_tokens) in chunk-sized steps; the raw-SWA path
 * below routes through it when the M5 large-prefill schedule is active. */
static bool metal_graph_prefill_chunked_range(
ds4_gpu_graph *g,
const ds4_model *model,
const ds4_weights *weights,
const token_vec *prompt,
uint32_t start,
uint32_t n_tokens,
float *logits,
bool show_progress,
ds4_session_progress_fn progress,
void *progress_ud,
ds4_imatrix_collector *imatrix);

static bool metal_graph_prefill_raw_swa(
ds4_gpu_graph *g,
const ds4_model *model,
Expand All @@ -13355,6 +13427,19 @@ static bool metal_graph_prefill_raw_swa(
bool show_progress) {
if (n_tokens <= 0 || n_tokens > prompt->len) return false;
if ((uint32_t)n_tokens > g->prefill_cap) return false;
if (metal_graph_use_m5_large_prefill_schedule(g) && n_tokens > 2048) {
return metal_graph_prefill_chunked_range(g,
model,
weights,
prompt,
0,
(uint32_t)n_tokens,
logits,
show_progress,
NULL,
NULL,
NULL);
}
return metal_graph_prefill_layer_major(g, model, weights, prompt, n_tokens, logits, show_progress, NULL);
}

Expand Down Expand Up @@ -13435,6 +13520,9 @@ static bool metal_graph_prefill_chunked_range(
for (uint32_t pos0 = start; pos0 < end; ) {
const uint32_t remaining = end - pos0;
uint32_t local_cap = chunk_cap;
if (metal_graph_use_m5_large_prefill_schedule(g) && pos0 == 0 && local_cap > 2048u) {
local_cap = 2048u;
}
if (start != 0 && g->prefill_cap != 0) {
const uint32_t mod = pos0 % g->prefill_cap;
if (mod != 0) {
Expand Down Expand Up @@ -13843,9 +13931,9 @@ static uint32_t metal_graph_raw_cap_for_context(int ctx_size, uint32_t prefill_c
}

/* Choose the prefill ubatch size. Whole-batch is fastest for normal prompts;
* long prompts default to 2048-token chunks. */
* long prompts default to a backend-tuned chunk size. */
static uint32_t metal_graph_prefill_cap_for_prompt(int prompt_len) {
return ds4_default_prefill_cap_for_prompt(prompt_len);
return ds4_default_prefill_cap_for_prompt(prompt_len, DS4_BACKEND_METAL);
}

/* When a server request shares a large prefix with the live checkpoint, extend
Expand All @@ -13869,7 +13957,7 @@ ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size
uint32_t ctx = ctx_size > 0 ? (uint32_t)ctx_size : 1u;

if (ds4_backend_uses_graph(backend)) {
m.prefill_cap = metal_graph_prefill_cap_for_prompt((int)ctx);
m.prefill_cap = ds4_default_prefill_cap_for_prompt((int)ctx, backend);
m.raw_cap = metal_graph_raw_cap_for_context((int)ctx, m.prefill_cap);

uint32_t min_ratio = UINT32_MAX;
Expand Down Expand Up @@ -17168,7 +17256,7 @@ int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size) {
ds4_session *s = xcalloc(1, sizeof(*s));
s->engine = e;
s->ctx_size = ctx_size;
s->prefill_cap = ds4_default_prefill_cap_for_prompt(ctx_size);
s->prefill_cap = ds4_default_prefill_cap_for_prompt(ctx_size, e->backend);
kv_cache_init(&s->cpu_cache, (uint32_t)ctx_size, 0);
cpu_decode_scratch_init(&s->cpu_scratch, (uint32_t)ctx_size);
s->logits = xmalloc((size_t)DS4_N_VOCAB * sizeof(s->logits[0]));
Expand Down
2 changes: 2 additions & 0 deletions ds4.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ const char *ds4_think_mode_name(ds4_think_mode mode);
const char *ds4_think_max_prefix(void);
uint32_t ds4_think_max_min_context(void);
ds4_think_mode ds4_think_mode_for_context(ds4_think_mode mode, int ctx_size);
uint32_t ds4_backend_default_prefill_chunk(ds4_backend backend);
uint32_t ds4_backend_default_kv_boundary_align_tokens(ds4_backend backend);
ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size);
bool ds4_log_is_tty(FILE *fp);
void ds4_log(FILE *fp, ds4_log_type type, const char *fmt, ...);
Expand Down
74 changes: 69 additions & 5 deletions ds4_metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,33 @@ static void ds4_gpu_print_device_summary(void) {
}
}

/* One-shot probe: does the active Metal device's name contain "M5"?
 * Matches any M5-family part (M5, M5 Pro, M5 Max). A nil device name is
 * treated as an empty string, i.e. not an M5 device. */
static int ds4_gpu_is_m5_device(void) {
    static int probed;
    static int matched;
    if (!probed) {
        const char *name = g_device.name ? [g_device.name UTF8String] : "";
        matched = (strstr(name, "M5") != NULL) ? 1 : 0;
        probed = 1;
    }
    return matched;
}

/* Private-storage scratch buffers are used on M5-class devices unless the
 * user opts out via DS4_METAL_DISABLE_M5_PRIVATE_SCRATCH. Decided once
 * and cached. */
static int ds4_gpu_use_m5_private_scratch(void) {
    static int decided;
    static int use_private;
    if (!decided) {
        const int opted_out =
            getenv("DS4_METAL_DISABLE_M5_PRIVATE_SCRATCH") != NULL;
        use_private = !opted_out && ds4_gpu_is_m5_device();
        decided = 1;
    }
    return use_private;
}

/* Whether a scratch buffer must remain CPU-visible even when private GPU
 * storage is preferred: any label containing "mask", plus the attention
 * output group-id table. Returns 0 for a NULL label. */
static int ds4_gpu_scratch_needs_cpu_access(const char *label) {
    if (label == NULL) return 0;
    if (strstr(label, "mask") != NULL) return 1;
    return strcmp(label, "ds4_attention_output_group_ids") == 0 ? 1 : 0;
}

#define DS4_METAL_MAX_MODEL_VIEWS 16
#define DS4_METAL_MODEL_MAX_TENSOR_BYTES 704643072ull

Expand Down Expand Up @@ -297,7 +324,20 @@ static int ds4_gpu_ensure_scratch_buffer(
if (bytes == 0) bytes = 1;
if (bytes > NSUIntegerMax) return 0;

*buffer = [g_device newBufferWithLength:bytes options:MTLResourceStorageModeShared];
MTLResourceOptions options = MTLResourceStorageModeShared;
if (ds4_gpu_use_m5_private_scratch() && !ds4_gpu_scratch_needs_cpu_access(label)) {
/*
* Keep Metal's default hazard tracking. These scratch buffers are
* reused by dependent kernels across many compute encoders, and the
* graph does not insert explicit fences for untracked resources.
*/
options = MTLResourceStorageModePrivate;
}

*buffer = [g_device newBufferWithLength:bytes options:options];
if (!*buffer && options != MTLResourceStorageModeShared) {
*buffer = [g_device newBufferWithLength:bytes options:MTLResourceStorageModeShared];
}
if (!*buffer) {
fprintf(stderr, "ds4: failed to allocate Metal scratch buffer %s (%llu bytes)\n",
label, (unsigned long long)bytes);
Expand Down Expand Up @@ -551,18 +591,25 @@ static int ds4_gpu_map_model_views(
return buffer;
}

static int ds4_gpu_use_m5_simdgroup_matrix(void);

static id<MTLComputePipelineState> ds4_gpu_get_mul_mm_pipeline(
const char *function_name,
bool bc_inp,
bool bc_out) {
NSString *key = [NSString stringWithFormat:@"%s_bci=%d_bco=%d",
function_name, bc_inp ? 1 : 0, bc_out ? 1 : 0];
bool m5_sgmatrix = ds4_gpu_use_m5_simdgroup_matrix() != 0;
NSString *key = [NSString stringWithFormat:@"%s_bci=%d_bco=%d_m5sg=%d",
function_name,
bc_inp ? 1 : 0,
bc_out ? 1 : 0,
m5_sgmatrix ? 1 : 0];
id<MTLComputePipelineState> cached = [g_pipeline_cache objectForKey:key];
if (cached) return cached;

MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init];
[constants setConstantValue:&bc_inp type:MTLDataTypeBool atIndex:700];
[constants setConstantValue:&bc_out type:MTLDataTypeBool atIndex:701];
[constants setConstantValue:&m5_sgmatrix type:MTLDataTypeBool atIndex:702];

NSError *error = nil;
NSString *name = [NSString stringWithUTF8String:function_name];
Expand Down Expand Up @@ -590,13 +637,17 @@ static int ds4_gpu_map_model_views(
static id<MTLComputePipelineState> ds4_gpu_get_mul_mm_id_pipeline(
const char *function_name,
bool bc_inp) {
NSString *key = [NSString stringWithFormat:@"%s_bci=%d",
function_name, bc_inp ? 1 : 0];
bool m5_sgmatrix = ds4_gpu_use_m5_simdgroup_matrix() != 0;
NSString *key = [NSString stringWithFormat:@"%s_bci=%d_m5sg=%d",
function_name,
bc_inp ? 1 : 0,
m5_sgmatrix ? 1 : 0];
id<MTLComputePipelineState> cached = [g_pipeline_cache objectForKey:key];
if (cached) return cached;

MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init];
[constants setConstantValue:&bc_inp type:MTLDataTypeBool atIndex:700];
[constants setConstantValue:&m5_sgmatrix type:MTLDataTypeBool atIndex:702];

NSError *error = nil;
NSString *name = [NSString stringWithUTF8String:function_name];
Expand Down Expand Up @@ -673,6 +724,18 @@ static int ds4_gpu_use_compressor_pair_nr4(void) {
return enabled;
}

/* Gate the simdgroup-matrix mul_mm variant. Precedence, highest first:
 *   DS4_METAL_DISABLE_M5_SIMDGROUP_MATRIX set -> off (wins over force)
 *   DS4_METAL_FORCE_M5_SIMDGROUP_MATRIX set   -> on
 *   otherwise                                 -> on only on M5-class devices
 * The decision is computed once and cached. */
static int ds4_gpu_use_m5_simdgroup_matrix(void) {
    static int decided;
    static int use_sgmatrix;
    if (!decided) {
        if (getenv("DS4_METAL_DISABLE_M5_SIMDGROUP_MATRIX") != NULL) {
            use_sgmatrix = 0;
        } else if (getenv("DS4_METAL_FORCE_M5_SIMDGROUP_MATRIX") != NULL) {
            use_sgmatrix = 1;
        } else {
            use_sgmatrix = ds4_gpu_is_m5_device();
        }
        decided = 1;
    }
    return use_sgmatrix;
}

static int ds4_gpu_warm_model_views(void) {
if (g_model_view_count == 0) return 1;

Expand Down Expand Up @@ -1165,6 +1228,7 @@ void ds4_gpu_set_quality(bool quality) {
"#define N_SG_Q8_0 4\n"
"#define FC_MUL_MV 600\n"
"#define FC_MUL_MM 700\n"
"#define FC_MUL_MM_M5_SGMATRIX 702\n"
"#define FC_BIN 1300\n"
"#define FOR_UNROLL(x) _Pragma(\"clang loop unroll(full)\") for (x)\n"
"#define M_PI_F 3.14159265358979323846f\n"
Expand Down
16 changes: 10 additions & 6 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -8079,11 +8079,10 @@ static void apply_anthropic_stream_tool_ids(tool_calls *calls,
/* Tokenizers may merge text across the prompt boundary. Trimming a small tail
* still improves the cheap token-prefix path, while text-prefix lookup handles
* the cases where canonical prompt tokenization spells the same bytes
* differently. The 2048 alignment also matches the Metal prefill chunk
* schedule, which keeps compressor row finalization identical to a cold full
* prompt. */
* differently. The alignment should match the backend prefill chunk schedule,
* which keeps compressor row finalization identical to a cold full prompt. */
#define KV_CACHE_DEFAULT_BOUNDARY_TRIM_TOKENS 32
#define KV_CACHE_DEFAULT_BOUNDARY_ALIGN_TOKENS 2048
#define KV_CACHE_FALLBACK_BOUNDARY_ALIGN_TOKENS 2048
#define KV_CACHE_DEFAULT_CONTINUED_INTERVAL_TOKENS 10000
#define KV_CACHE_DEFAULT_MB 4096
#define KV_EXT_TOOL_MAP (1u << 0)
Expand Down Expand Up @@ -8118,7 +8117,7 @@ static kv_cache_options kv_cache_default_options(void) {
.cold_max_tokens = KV_CACHE_DEFAULT_COLD_MAX_TOKENS,
.continued_interval_tokens = KV_CACHE_DEFAULT_CONTINUED_INTERVAL_TOKENS,
.boundary_trim_tokens = KV_CACHE_DEFAULT_BOUNDARY_TRIM_TOKENS,
.boundary_align_tokens = KV_CACHE_DEFAULT_BOUNDARY_ALIGN_TOKENS,
.boundary_align_tokens = KV_CACHE_FALLBACK_BOUNDARY_ALIGN_TOKENS,
};
}

Expand Down Expand Up @@ -11431,7 +11430,7 @@ static void usage(FILE *fp) {
" --kv-cache-boundary-trim-tokens N\n"
" Trim this many tail tokens before cold boundary saves to avoid tokenizer boundary merges. Default: 32\n"
" --kv-cache-boundary-align-tokens N\n"
" Align cold boundary saves down to this token multiple. 0 disables alignment. Default: 2048\n"
" Align cold boundary saves down to this token multiple. 0 disables alignment. Default: backend prefill chunk (4096 on M5 Max Metal, otherwise 2048)\n"
" --kv-cache-reject-different-quant\n"
" Refuse checkpoints written by the same model with a different routed-expert quantization.\n"
" --disable-exact-dsml-tool-replay\n"
Expand Down Expand Up @@ -11493,6 +11492,7 @@ static server_config parse_options(int argc, char **argv) {
c.kv_cache = kv_cache_default_options();

bool directional_steering_scale_set = false;
bool kv_boundary_align_set = false;
for (int i = 1; i < argc; i++) {
const char *arg = argv[i];
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
Expand Down Expand Up @@ -11532,6 +11532,7 @@ static server_config parse_options(int argc, char **argv) {
c.kv_cache.boundary_trim_tokens = parse_nonneg_int_arg(need_arg(&i, argc, argv, arg), arg);
} else if (!strcmp(arg, "--kv-cache-boundary-align-tokens")) {
c.kv_cache.boundary_align_tokens = parse_nonneg_int_arg(need_arg(&i, argc, argv, arg), arg);
kv_boundary_align_set = true;
} else if (!strcmp(arg, "--kv-cache-reject-different-quant")) {
c.kv_cache_reject_different_quant = true;
} else if (!strcmp(arg, "--disable-exact-dsml-tool-replay")) {
Expand Down Expand Up @@ -11564,6 +11565,9 @@ static server_config parse_options(int argc, char **argv) {
exit(2);
}
}
if (!kv_boundary_align_set) {
c.kv_cache.boundary_align_tokens = ds4_backend_default_kv_boundary_align_tokens(c.engine.backend);
}
if (c.kv_cache.cold_max_tokens > 0 &&
c.kv_cache.cold_max_tokens < c.kv_cache.min_tokens)
{
Expand Down
Loading