diff --git a/README.md b/README.md index ea02bf39..5f7a9232 100644 --- a/README.md +++ b/README.md @@ -694,6 +694,32 @@ make cuda CUDA_ARCH=sm_120 make cuda CUDA_ARCH=native ``` +### CUDA direct-model partial weight cache + +For GPUs that cannot hold the full GGUF weight image in VRAM, use direct-model +mode with the partial weight cache: + +```sh +DS4_CUDA_DIRECT_MODEL=1 \ +DS4_CUDA_PARTIAL_WEIGHT_CACHE=1 \ +DS4_CUDA_WEIGHT_CACHE_LIMIT_GB=10 \ +./ds4 --cuda -p "Hello" +``` + +The partial cache selects high-benefit DS4 weights at startup and keeps +uncached weights on the direct-model path. This is intended for cards such as +24 GB RTX 4090-class GPUs where full model copy/cache would exceed VRAM. + +Useful controls: + +* `DS4_CUDA_WEIGHT_CACHE_LIMIT_GB` or `DS4_CUDA_WEIGHT_CACHE_LIMIT_MB` +* `DS4_CUDA_WEIGHT_CACHE_RESERVE_MB` +* `DS4_CUDA_WEIGHT_CACHE_VERBOSE=1` +* `DS4_CUDA_STRICT_WEIGHT_CACHE=1` + +Without `DS4_CUDA_PARTIAL_WEIGHT_CACHE=1`, existing CUDA full-cache/direct-model +behavior is unchanged. + There is also a CPU reference/debug path: ```sh diff --git a/ds4.c b/ds4.c index 8825c257..762ff3ab 100644 --- a/ds4.c +++ b/ds4.c @@ -927,6 +927,66 @@ typedef struct { ds4_tensor *tensors; } ds4_model; +typedef struct { + ds4_tensor *hc_attn_fn; + ds4_tensor *hc_attn_scale; + ds4_tensor *hc_attn_base; + ds4_tensor *attn_norm; + ds4_tensor *attn_q_a; + ds4_tensor *attn_q_a_norm; + ds4_tensor *attn_q_b; + ds4_tensor *attn_kv; + ds4_tensor *attn_kv_a_norm; + ds4_tensor *attn_sinks; + ds4_tensor *attn_output_a; + ds4_tensor *attn_output_b; + ds4_tensor *attn_compressor_ape; + ds4_tensor *attn_compressor_kv; + ds4_tensor *attn_compressor_gate; + ds4_tensor *attn_compressor_norm; + ds4_tensor *indexer_attn_q_b; + ds4_tensor *indexer_proj; + ds4_tensor *indexer_compressor_ape; + ds4_tensor *indexer_compressor_kv; + ds4_tensor *indexer_compressor_gate; + ds4_tensor *indexer_compressor_norm; + ds4_tensor *hc_ffn_fn; + ds4_tensor *hc_ffn_scale; + ds4_tensor *hc_ffn_base; + ds4_tensor *ffn_norm; + ds4_tensor *ffn_gate_tid2eid; + ds4_tensor *ffn_gate_inp; + ds4_tensor *ffn_exp_probs_b; + ds4_tensor *ffn_gate_exps; + ds4_tensor *ffn_up_exps; + ds4_tensor *ffn_down_exps; + ds4_tensor *ffn_gate_shexp; + ds4_tensor *ffn_up_shexp; + ds4_tensor *ffn_down_shexp; +} ds4_layer_weights; + +typedef struct { + ds4_tensor *token_embd; + ds4_tensor *output_hc_base; + ds4_tensor *output_hc_fn; + ds4_tensor *output_hc_scale; + ds4_tensor *output_norm; + ds4_tensor *output; + ds4_layer_weights layer[DS4_N_LAYER]; +} ds4_weights; + +typedef struct { + ds4_tensor *e_proj; + ds4_tensor *h_proj; + ds4_tensor *enorm; + ds4_tensor *hnorm; + ds4_tensor *norm; + ds4_tensor *hc_head_base; + ds4_tensor *hc_head_fn; + ds4_tensor *hc_head_scale; + ds4_layer_weights block; +} ds4_mtp_weights; + static uint64_t scalar_value_size(uint32_t type) { switch (type) { case GGUF_VALUE_UINT8: @@ -1372,6 +1432,289 @@ static uint64_t accelerator_cuda_preload_span_bytes(void) { return mb * 1048576ull; } +typedef struct { + uint64_t off; + uint64_t end; + uint64_t bytes; + uint32_t priority; + uint32_t layer; + uint32_t group; +} accelerator_weight_cache_candidate; + +enum { + ACCELERATOR_WEIGHT_CACHE_GLOBAL_STATE = 0, + ACCELERATOR_WEIGHT_CACHE_LAYER_STATE = 5, + ACCELERATOR_WEIGHT_CACHE_ATTENTION = 10, + ACCELERATOR_WEIGHT_CACHE_COMPRESSOR = 15, + ACCELERATOR_WEIGHT_CACHE_FFN_SHARED = 20, + ACCELERATOR_WEIGHT_CACHE_OUTPUT = 25, + ACCELERATOR_WEIGHT_CACHE_TOKEN_EMBD = 30, + ACCELERATOR_WEIGHT_CACHE_ROUTED_EXPERTS = 40, +}; + +/* Lower 
priorities are cached first: cover small global/layer state and dense + * per-token paths before spending the partial VRAM budget on large embeddings + * and routed expert matrices. */ + +static bool accelerator_cuda_env_enabled(const char *name) { + const char *env = getenv(name); + return env && env[0] && !(env[0] == '0' && env[1] == '\0'); +} + +static bool accelerator_cuda_partial_weight_cache_enabled(void) { + return accelerator_cuda_env_enabled("DS4_CUDA_PARTIAL_WEIGHT_CACHE"); +} + +static int accelerator_weight_cache_candidate_cmp(const void *a, const void *b) { + const accelerator_weight_cache_candidate *ca = a; + const accelerator_weight_cache_candidate *cb = b; + if (ca->priority < cb->priority) return -1; + if (ca->priority > cb->priority) return 1; + if (ca->group < cb->group) return -1; + if (ca->group > cb->group) return 1; + if (ca->layer < cb->layer) return -1; + if (ca->layer > cb->layer) return 1; + if (ca->off < cb->off) return -1; + if (ca->off > cb->off) return 1; + if (ca->bytes > cb->bytes) return -1; + if (ca->bytes < cb->bytes) return 1; + return 0; +} + +static int accelerator_weight_cache_candidate_off_cmp(const void *a, const void *b) { + const accelerator_weight_cache_candidate *ca = a; + const accelerator_weight_cache_candidate *cb = b; + if (ca->off < cb->off) return -1; + if (ca->off > cb->off) return 1; + if (ca->end < cb->end) return -1; + if (ca->end > cb->end) return 1; + if (ca->bytes > cb->bytes) return -1; + if (ca->bytes < cb->bytes) return 1; + return 0; +} + +static bool accelerator_partial_cache_add( + accelerator_weight_cache_candidate *cands, + uint32_t *count, + uint32_t cap, + const ds4_model *m, + const ds4_tensor *t, + uint32_t priority, + uint32_t layer, + uint32_t group) { + if (!t || t->bytes == 0) return true; + if (t->abs_offset > m->size || t->bytes > m->size - t->abs_offset) { + fprintf(stderr, "ds4: invalid CUDA cache candidate range for %.*s\n", + (int)t->name.len, t->name.ptr); + return false; + } + if (*count >= cap) { + fprintf(stderr, "ds4: too many CUDA partial weight cache candidates\n"); + return false; + } + accelerator_weight_cache_candidate *c = &cands[*count]; + c->off = t->abs_offset; + c->bytes = t->bytes; + c->end = t->abs_offset + t->bytes; + c->priority = priority; + c->layer = layer; + c->group = group; + (*count)++; + return true; +} + +static bool accelerator_partial_cache_collect( + accelerator_weight_cache_candidate *cands, + uint32_t *count, + uint32_t cap, + const ds4_model *m, + const ds4_weights *w) { +#define ADD_GLOBAL(t_, p_) \ + do { \ + if (!accelerator_partial_cache_add(cands, count, cap, m, (t_), (p_), UINT32_MAX, 0)) return false; \ + } while (0) +#define ADD_LAYER(t_, p_) \ + do { \ + if (!accelerator_partial_cache_add(cands, count, cap, m, (t_), (p_), il, 0)) return false; \ + } while (0) +#define ADD_LAYER_GROUP(t_, p_, group_) \ + do { \ + if (!accelerator_partial_cache_add(cands, count, cap, m, (t_), (p_), il, (group_))) return false; \ + } while (0) + + ADD_GLOBAL(w->output_hc_base, ACCELERATOR_WEIGHT_CACHE_GLOBAL_STATE); + ADD_GLOBAL(w->output_hc_scale, ACCELERATOR_WEIGHT_CACHE_GLOBAL_STATE); + ADD_GLOBAL(w->output_norm, ACCELERATOR_WEIGHT_CACHE_GLOBAL_STATE); + ADD_GLOBAL(w->output_hc_fn, ACCELERATOR_WEIGHT_CACHE_GLOBAL_STATE); + ADD_GLOBAL(w->output, ACCELERATOR_WEIGHT_CACHE_OUTPUT); + ADD_GLOBAL(w->token_embd, ACCELERATOR_WEIGHT_CACHE_TOKEN_EMBD); + + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + const ds4_layer_weights *l = &w->layer[il]; + ADD_LAYER(l->hc_attn_scale, 
ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->hc_attn_base, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->attn_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->attn_q_a_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->attn_kv_a_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->attn_sinks, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->attn_compressor_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->indexer_compressor_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->hc_ffn_scale, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->hc_ffn_base, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->ffn_norm, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->ffn_exp_probs_b, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + ADD_LAYER(l->ffn_gate_tid2eid, ACCELERATOR_WEIGHT_CACHE_LAYER_STATE); + + ADD_LAYER(l->hc_attn_fn, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + ADD_LAYER(l->attn_q_a, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + ADD_LAYER(l->attn_q_b, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + ADD_LAYER(l->attn_kv, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + ADD_LAYER(l->attn_output_a, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + ADD_LAYER(l->attn_output_b, ACCELERATOR_WEIGHT_CACHE_ATTENTION); + + ADD_LAYER(l->attn_compressor_ape, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->attn_compressor_kv, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->attn_compressor_gate, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->indexer_attn_q_b, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->indexer_proj, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->indexer_compressor_ape, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->indexer_compressor_kv, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + ADD_LAYER(l->indexer_compressor_gate, ACCELERATOR_WEIGHT_CACHE_COMPRESSOR); + + ADD_LAYER(l->hc_ffn_fn, ACCELERATOR_WEIGHT_CACHE_FFN_SHARED); + ADD_LAYER(l->ffn_gate_inp, ACCELERATOR_WEIGHT_CACHE_FFN_SHARED); + ADD_LAYER(l->ffn_gate_shexp, ACCELERATOR_WEIGHT_CACHE_FFN_SHARED); + ADD_LAYER(l->ffn_up_shexp, ACCELERATOR_WEIGHT_CACHE_FFN_SHARED); + ADD_LAYER(l->ffn_down_shexp, ACCELERATOR_WEIGHT_CACHE_FFN_SHARED); + + const uint32_t routed_group = 1000u + il; + ADD_LAYER_GROUP(l->ffn_gate_exps, ACCELERATOR_WEIGHT_CACHE_ROUTED_EXPERTS, routed_group); + ADD_LAYER_GROUP(l->ffn_up_exps, ACCELERATOR_WEIGHT_CACHE_ROUTED_EXPERTS, routed_group); + ADD_LAYER_GROUP(l->ffn_down_exps, ACCELERATOR_WEIGHT_CACHE_ROUTED_EXPERTS, routed_group); + } + +#undef ADD_GLOBAL +#undef ADD_LAYER +#undef ADD_LAYER_GROUP + return true; +} + +static uint32_t accelerator_partial_cache_dedup( + accelerator_weight_cache_candidate *cands, + uint32_t count) { + qsort(cands, count, sizeof(cands[0]), accelerator_weight_cache_candidate_cmp); + uint32_t out = 0; + for (uint32_t i = 0; i < count; i++) { + bool seen = false; + for (uint32_t j = 0; j < out; j++) { + if (cands[j].off == cands[i].off && cands[j].end == cands[i].end) { + seen = true; + break; + } + } + if (!seen) cands[out++] = cands[i]; + } + return out; +} + +static bool accelerator_cache_model_partial( + const ds4_model *m, + const ds4_weights *w, + uint64_t *cached_out, + uint32_t *ranges_out) { + enum { CAND_CAP = 2048 }; + accelerator_weight_cache_candidate *cands = xmalloc((size_t)CAND_CAP * sizeof(cands[0])); + uint32_t count = 0; + uint64_t cached = 0; + uint32_t ranges = 0; + + if (!accelerator_partial_cache_collect(cands, &count, CAND_CAP, m, w)) { + free(cands); + return false; + } + count = 
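+    /* Sort by priority and drop candidates that cover an identical (offset, end) range. */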
accelerator_partial_cache_dedup(cands, count); + if (count == 0) { + free(cands); + if (cached_out) *cached_out = 0; + if (ranges_out) *ranges_out = 0; + return true; + } + + const double t0 = now_sec(); + const uint64_t max_span = accelerator_cuda_preload_span_bytes(); + bool ok = true; + bool stopped = false; + bool fallback_printed = false; + uint32_t span_id = 0; + + for (uint32_t i = 0; i < count && ok && !stopped;) { + uint32_t j = i + 1; + while (j < count && cands[j].priority == cands[i].priority) j++; + qsort(cands + i, (size_t)(j - i), sizeof(cands[0]), accelerator_weight_cache_candidate_off_cmp); + + for (uint32_t k = i; k < j && ok && !stopped;) { + uint64_t off = cands[k].off; + uint64_t end = cands[k].end; + k++; + while (k < j && + cands[k].off <= end + 65536u && + cands[k].end >= off && + cands[k].end - off <= max_span) { + if (cands[k].end > end) end = cands[k].end; + k++; + } + while (off < end && ok && !stopped) { + uint64_t chunk_end = end; + if (chunk_end - off > max_span) chunk_end = off + max_span; + char label[96]; + snprintf(label, sizeof(label), "partial:p%u:span%u", cands[i].priority, span_id); + const uint64_t bytes = chunk_end - off; + if (ds4_gpu_cache_model_range(m->map, m->size, off, bytes, label) == 0) { + if (accelerator_cuda_env_enabled("DS4_CUDA_STRICT_WEIGHT_CACHE")) { + fprintf(stderr, + "ds4: CUDA partial weight cache failed for %s at offset %" PRIu64 + " bytes %" PRIu64 "\n", + label, off, bytes); + ok = false; + } else { + if (!fallback_printed) { + fprintf(stderr, + "ds4: CUDA partial weight cache stopped at %s " + "(offset %" PRIu64 ", %.2f MiB); remaining weights use direct fallback\n", + label, off, (double)bytes / 1048576.0); + fallback_printed = true; + } + stopped = true; + } + } else { + cached += bytes; + ranges++; + } + span_id++; + off = chunk_end; + } + } + i = j; + } + + if (ok && cached != 0) { + const double t1 = now_sec(); + if (ds4_log_is_tty(stderr)) fputc('\n', stderr); + fprintf(stderr, + "ds4: CUDA partial weight cache prepared %.2f GiB in %u ranges " + "from %u candidates in %.3fs\n", + (double)cached / 1073741824.0, + ranges, + count, + t1 - t0); + } + + free(cands); + if (cached_out) *cached_out = cached; + if (ranges_out) *ranges_out = ranges; + return ok; +} + static bool accelerator_cache_model_tensor_spans(const ds4_model *m, uint64_t *cached_out) { accelerator_tensor_span *spans = xmalloc((size_t)m->n_tensors * sizeof(spans[0])); uint64_t nspan = 0; @@ -1423,10 +1766,15 @@ static bool accelerator_cache_model_tensor_spans(const ds4_model *m, uint64_t *c return true; } -static bool accelerator_cache_model_tensors(ds4_backend backend, const ds4_model *m) { +static bool accelerator_cache_model_tensors(ds4_backend backend, const ds4_model *m, const ds4_weights *w) { if (backend != DS4_BACKEND_CUDA) return true; - if (!m || !m->map || m->size == 0) return false; - if (getenv("DS4_CUDA_DIRECT_MODEL") != NULL) { + if (!m || !m->map || m->size == 0 || !w) return false; + if (accelerator_cuda_partial_weight_cache_enabled()) { + uint64_t cached = 0; + uint32_t ranges = 0; + return accelerator_cache_model_partial(m, w, &cached, &ranges); + } + if (accelerator_cuda_env_enabled("DS4_CUDA_DIRECT_MODEL")) { return true; } @@ -1460,9 +1808,10 @@ static bool accelerator_cache_model_tensors(ds4_backend backend, const ds4_model return true; } #else -static bool accelerator_cache_model_tensors(ds4_backend backend, const ds4_model *m) { +static bool accelerator_cache_model_tensors(ds4_backend backend, const ds4_model *m, const ds4_weights *w) 
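+/* Stub for builds without the CUDA backend: nothing to cache, always succeeds. */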
{ (void)backend; (void)m; + (void)w; return true; } #endif @@ -1984,66 +2333,6 @@ static void ds4_vec_dot_iq2_xxs_pair_q8_K( #endif } -typedef struct { - ds4_tensor *hc_attn_fn; - ds4_tensor *hc_attn_scale; - ds4_tensor *hc_attn_base; - ds4_tensor *attn_norm; - ds4_tensor *attn_q_a; - ds4_tensor *attn_q_a_norm; - ds4_tensor *attn_q_b; - ds4_tensor *attn_kv; - ds4_tensor *attn_kv_a_norm; - ds4_tensor *attn_sinks; - ds4_tensor *attn_output_a; - ds4_tensor *attn_output_b; - ds4_tensor *attn_compressor_ape; - ds4_tensor *attn_compressor_kv; - ds4_tensor *attn_compressor_gate; - ds4_tensor *attn_compressor_norm; - ds4_tensor *indexer_attn_q_b; - ds4_tensor *indexer_proj; - ds4_tensor *indexer_compressor_ape; - ds4_tensor *indexer_compressor_kv; - ds4_tensor *indexer_compressor_gate; - ds4_tensor *indexer_compressor_norm; - ds4_tensor *hc_ffn_fn; - ds4_tensor *hc_ffn_scale; - ds4_tensor *hc_ffn_base; - ds4_tensor *ffn_norm; - ds4_tensor *ffn_gate_tid2eid; - ds4_tensor *ffn_gate_inp; - ds4_tensor *ffn_exp_probs_b; - ds4_tensor *ffn_gate_exps; - ds4_tensor *ffn_up_exps; - ds4_tensor *ffn_down_exps; - ds4_tensor *ffn_gate_shexp; - ds4_tensor *ffn_up_shexp; - ds4_tensor *ffn_down_shexp; -} ds4_layer_weights; - -typedef struct { - ds4_tensor *token_embd; - ds4_tensor *output_hc_base; - ds4_tensor *output_hc_fn; - ds4_tensor *output_hc_scale; - ds4_tensor *output_norm; - ds4_tensor *output; - ds4_layer_weights layer[DS4_N_LAYER]; -} ds4_weights; - -typedef struct { - ds4_tensor *e_proj; - ds4_tensor *h_proj; - ds4_tensor *enorm; - ds4_tensor *hnorm; - ds4_tensor *norm; - ds4_tensor *hc_head_base; - ds4_tensor *hc_head_fn; - ds4_tensor *hc_head_scale; - ds4_layer_weights block; -} ds4_mtp_weights; - /* ========================================================================= * Fixed Weight Binding and Model Validation. 
* ========================================================================= @@ -17118,7 +17407,19 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { *out = NULL; return 1; } - if (!e->mtp_ready && !accelerator_cache_model_tensors(e->backend, &e->model)) { + const char *partial_cache_env = getenv("DS4_CUDA_PARTIAL_WEIGHT_CACHE"); + const bool partial_cuda_cache = + e->backend == DS4_BACKEND_CUDA && + partial_cache_env && + partial_cache_env[0] && + !(partial_cache_env[0] == '0' && partial_cache_env[1] == '\0'); + if (e->mtp_ready && partial_cuda_cache) { + fprintf(stderr, + "ds4: CUDA partial weight cache applies to the base model; " + "MTP weights remain on the direct path\n"); + } + if ((!e->mtp_ready || partial_cuda_cache) && + !accelerator_cache_model_tensors(e->backend, &e->model, &e->weights)) { fprintf(stderr, "ds4: %s failed to prepare startup model cache\n", ds4_backend_name(e->backend)); ds4_engine_close(e); diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 1eee21de..9357c199 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -131,6 +131,8 @@ static std::unordered_map g_q8_f16_by_offset; static std::vector g_q8_f32_ranges; static std::unordered_map g_q8_f32_by_offset; static uint64_t g_model_range_bytes; +static uint64_t g_model_auto_cache_limit_bytes; +static int g_model_auto_cache_limit_ready; static uint64_t g_q8_f16_bytes; static uint64_t g_q8_f32_bytes; static int g_q8_f16_disabled_after_oom; @@ -151,7 +153,8 @@ static const char *cuda_model_range_ptr_from_fd( const void *model_map, uint64_t offset, uint64_t bytes, - const char *what); + const char *what, + int allow_direct_fallback); __global__ static void dequant_q8_0_to_f16_kernel( __half *out, const unsigned char *w, @@ -195,27 +198,67 @@ static const char *cuda_model_ptr(const void *model_map, uint64_t offset) { return (const char *)model_map + offset; } -static const char *cuda_model_range_ptr(const void *model_map, uint64_t offset, uint64_t bytes, const char *what) { - if (bytes == 0) return cuda_model_ptr(model_map, offset); - if (g_model_device_owned || g_model_registered) return cuda_model_ptr(model_map, offset); - if (g_model_hmm_direct && - getenv("DS4_CUDA_WEIGHT_CACHE") == NULL && - getenv("DS4_CUDA_WEIGHT_PRELOAD") == NULL) { +static int cuda_env_enabled(const char *name) { + const char *env = getenv(name); + return env && env[0] && !(env[0] == '0' && env[1] == '\0'); +} + +static int cuda_partial_weight_cache_enabled(void) { + return cuda_env_enabled("DS4_CUDA_PARTIAL_WEIGHT_CACHE"); +} + +static int cuda_direct_model_enabled(void) { + return cuda_env_enabled("DS4_CUDA_DIRECT_MODEL"); +} + +static const char *cuda_model_range_lookup_device_cached( + const void *model_map, + uint64_t offset, + uint64_t bytes) { + if (!model_map || bytes == 0) return NULL; + if (model_map == g_model_host_base && g_model_device_owned) { return cuda_model_ptr(model_map, offset); } - const char *direct_env = getenv("DS4_CUDA_DIRECT_MODEL"); - if (direct_env && direct_env[0]) return cuda_model_ptr(model_map, offset); const uint64_t end = offset + bytes; + if (end < offset) return NULL; + auto exact = g_model_range_by_offset.find(offset); if (exact != g_model_range_by_offset.end()) { const cuda_model_range &r = g_model_ranges[exact->second]; - if (r.host_base == model_map && end >= offset && bytes <= r.bytes) return r.device_ptr; + const uint64_t rend = r.offset + r.bytes; + if (r.host_base == model_map && !r.host_registered && + rend >= r.offset && end <= rend) return r.device_ptr; } + for (const cuda_model_range &r : 
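+    /* Exact-offset lookup missed: fall back to a linear scan for a device-resident cached range that fully covers [offset, end). */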
g_model_ranges) { - if (r.host_base == model_map && offset >= r.offset && end >= offset && end <= r.offset + r.bytes) { + const uint64_t rend = r.offset + r.bytes; + if (r.host_base == model_map && !r.host_registered && rend >= r.offset && + offset >= r.offset && end <= rend) { return r.device_ptr + (offset - r.offset); } + } + return NULL; +} + +static const char *cuda_model_range_lookup_cached( + const void *model_map, + uint64_t offset, + uint64_t bytes) { + const char *cached = cuda_model_range_lookup_device_cached(model_map, offset, bytes); + if (cached) return cached; + if (!model_map || bytes == 0) return NULL; + + if (model_map == g_model_host_base && g_model_registered) { + return cuda_model_ptr(model_map, offset); + } + + const uint64_t end = offset + bytes; + if (end < offset) return NULL; + + auto exact = g_model_range_by_offset.find(offset); + if (exact != g_model_range_by_offset.end()) { + const cuda_model_range &r = g_model_ranges[exact->second]; if (r.host_base == model_map && r.host_registered && r.registered_base && r.registered_device_base) { const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); const uintptr_t h1 = h0 + bytes; @@ -225,8 +268,52 @@ static const char *cuda_model_range_ptr(const void *model_map, uint64_t offset, } } + for (const cuda_model_range &r : g_model_ranges) { + if (r.host_base == model_map && r.host_registered && r.registered_base && r.registered_device_base) { + const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); + const uintptr_t h1 = h0 + bytes; + const uintptr_t r0 = (uintptr_t)r.registered_base; + const uintptr_t r1 = r0 + r.registered_bytes; + if (h1 >= h0 && h0 >= r0 && h1 <= r1) return r.registered_device_base + (h0 - r0); + } + } + return NULL; +} + +static int cuda_model_range_is_device_cached(const void *model_map, uint64_t offset, uint64_t bytes) { + if (bytes == 0) return 1; + return cuda_model_range_lookup_device_cached(model_map, offset, bytes) != NULL; +} + +static int cuda_partial_weight_cache_miss( + const void *model_map, + uint64_t offset, + uint64_t bytes) { + return cuda_partial_weight_cache_enabled() && + !cuda_model_range_is_device_cached(model_map, offset, bytes); +} + +static const char *cuda_model_range_ptr(const void *model_map, uint64_t offset, uint64_t bytes, const char *what) { + if (bytes == 0) return cuda_model_ptr(model_map, offset); + + const char *cached = cuda_model_range_lookup_cached(model_map, offset, bytes); + if (cached) return cached; + + if (cuda_partial_weight_cache_enabled() && + (cuda_direct_model_enabled() || g_model_hmm_direct)) { + return cuda_model_ptr(model_map, offset); + } + + if (g_model_hmm_direct && + getenv("DS4_CUDA_WEIGHT_CACHE") == NULL && + getenv("DS4_CUDA_WEIGHT_PRELOAD") == NULL) { + return cuda_model_ptr(model_map, offset); + } + + if (cuda_direct_model_enabled()) return cuda_model_ptr(model_map, offset); + if (getenv("DS4_CUDA_NO_FD_CACHE") == NULL) { - const char *fd_ptr = cuda_model_range_ptr_from_fd(model_map, offset, bytes, what); + const char *fd_ptr = cuda_model_range_ptr_from_fd(model_map, offset, bytes, what, 1); if (fd_ptr) return fd_ptr; } @@ -302,32 +389,6 @@ static const char *cuda_model_range_ptr(const void *model_map, uint64_t offset, return (const char *)dev; } -static int cuda_model_range_is_cached(const void *model_map, uint64_t offset, uint64_t bytes) { - if (bytes == 0) return 1; - if (g_model_device_owned || g_model_registered) return 1; - - const uint64_t end = offset + bytes; - if (end < offset) return 0; - for (const 
cuda_model_range &r : g_model_ranges) { - if (r.host_base == model_map && - offset >= r.offset && - end <= r.offset + r.bytes) { - return 1; - } - if (r.host_base == model_map && - r.host_registered && - r.registered_base && - r.registered_device_base) { - const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); - const uintptr_t h1 = h0 + bytes; - const uintptr_t r0 = (uintptr_t)r.registered_base; - const uintptr_t r1 = r0 + r.registered_bytes; - if (h1 >= h0 && h0 >= r0 && h1 <= r1) return 1; - } - } - return 0; -} - static void cuda_q8_f16_cache_release_all(void) { for (const cuda_q8_f16_range &r : g_q8_f16_ranges) { (void)cudaFree(r.device_ptr); @@ -539,6 +600,9 @@ static const __half *cuda_q8_f16_ptr( } } if (!cuda_q8_f16_cache_allowed(label, in_dim, out_dim)) return NULL; + if (cuda_partial_weight_cache_miss(model_map, offset, weight_bytes)) { + return NULL; + } const char *q8 = cuda_model_range_ptr(model_map, offset, weight_bytes, "q8_0"); if (!q8) return NULL; @@ -594,6 +658,9 @@ static float *cuda_q8_f32_ptr( } } if (!cuda_q8_f32_cache_allowed(label, in_dim, out_dim)) return NULL; + if (cuda_partial_weight_cache_miss(model_map, offset, weight_bytes)) { + return NULL; + } const char *q8 = cuda_model_range_ptr(model_map, offset, weight_bytes, label ? label : "q8_0"); if (!q8) return NULL; @@ -925,15 +992,71 @@ static int cuda_model_stage_read(void *stage, uint64_t stage_bytes, } static uint64_t cuda_model_cache_limit_bytes(void) { - uint64_t gb = 0; + int mb_present = 0; + const uint64_t mb_limit = cuda_parse_mib_env("DS4_CUDA_WEIGHT_CACHE_LIMIT_MB", &mb_present); + if (mb_present) return mb_limit; + const char *env = getenv("DS4_CUDA_WEIGHT_CACHE_LIMIT_GB"); if (env && env[0]) { char *end = NULL; unsigned long long v = strtoull(env, &end, 10); - if (end != env) gb = (uint64_t)v; + if (end != env) { + if (v == 0) return UINT64_MAX; + if (v > UINT64_MAX / 1073741824ull) return UINT64_MAX; + return (uint64_t)v * 1073741824ull; + } + } + + if (cuda_partial_weight_cache_enabled()) { + if (!g_model_auto_cache_limit_ready) { + size_t free_b = 0; + size_t total_b = 0; + cudaError_t err = cudaMemGetInfo(&free_b, &total_b); + if (err != cudaSuccess) { + fprintf(stderr, "ds4: CUDA model cache memory query failed: %s\n", + cudaGetErrorString(err)); + (void)cudaGetLastError(); + g_model_auto_cache_limit_bytes = 0; + } else { + const uint64_t total = (uint64_t)total_b; + int reserve_present = 0; + uint64_t reserve = cuda_parse_mib_env("DS4_CUDA_WEIGHT_CACHE_RESERVE_MB", &reserve_present); + if (!reserve_present) { + if (total <= 32ull * 1073741824ull) { + reserve = 6144ull * 1048576ull; + } else { + const uint64_t pct = total / 10u; + reserve = pct > 4096ull * 1048576ull ? pct : 4096ull * 1048576ull; + } + } + const uint64_t free_bytes = (uint64_t)free_b; + g_model_auto_cache_limit_bytes = free_bytes > reserve ? 
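+        /* Auto limit: whatever VRAM is currently free minus the reserve, clamped to zero. */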
free_bytes - reserve : 0; + if (getenv("DS4_CUDA_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, + "ds4: CUDA partial model cache auto limit %.2f GiB " + "(free %.2f GiB reserve %.2f GiB total %.2f GiB)\n", + (double)g_model_auto_cache_limit_bytes / 1073741824.0, + (double)free_bytes / 1073741824.0, + (double)reserve / 1073741824.0, + (double)total / 1073741824.0); + } + } + g_model_auto_cache_limit_ready = 1; + } + return g_model_auto_cache_limit_bytes; } - if (gb == 0) return UINT64_MAX; - return gb * 1073741824ull; + + return UINT64_MAX; +} + +static uint64_t cuda_model_cache_reserve_bytes(uint64_t total_bytes) { + int reserve_present = 0; + const uint64_t reserve = cuda_parse_mib_env("DS4_CUDA_WEIGHT_CACHE_RESERVE_MB", &reserve_present); + if (reserve_present) return reserve; + if (total_bytes <= 32ull * 1073741824ull) return 6144ull * 1048576ull; + const uint64_t pct = total_bytes / 10u; + const uint64_t min_reserve = 4096ull * 1048576ull; + return pct > min_reserve ? pct : min_reserve; } static uint64_t cuda_model_arena_chunk_bytes(uint64_t need) { @@ -958,8 +1081,14 @@ static char *cuda_model_arena_alloc(uint64_t bytes, const char *what) { if (bytes == 0) return NULL; if (g_model_cache_full) return NULL; const uint64_t align = 256u; + if (bytes > UINT64_MAX - (align - 1u)) return NULL; const uint64_t aligned = (bytes + align - 1u) & ~(align - 1u); + const uint64_t limit = cuda_model_cache_limit_bytes(); + if (g_model_range_bytes > limit) return NULL; + uint64_t remaining = limit - g_model_range_bytes; + if (aligned > remaining) return NULL; + for (cuda_model_arena &a : g_model_arenas) { const uint64_t used = (a.used + align - 1u) & ~(align - 1u); if (used <= a.bytes && aligned <= a.bytes - used) { @@ -969,10 +1098,36 @@ static char *cuda_model_arena_alloc(uint64_t bytes, const char *what) { } } - const uint64_t limit = cuda_model_cache_limit_bytes(); - if (g_model_range_bytes > limit || aligned > limit - g_model_range_bytes) return NULL; - - const uint64_t chunk = cuda_model_arena_chunk_bytes(aligned); + uint64_t chunk = cuda_model_arena_chunk_bytes(aligned); + if (limit != UINT64_MAX && chunk > remaining) chunk = remaining; + if (chunk < aligned) return NULL; + if (cuda_partial_weight_cache_enabled()) { + size_t free_b = 0; + size_t total_b = 0; + cudaError_t mem_err = cudaMemGetInfo(&free_b, &total_b); + if (mem_err != cudaSuccess) { + fprintf(stderr, "ds4: CUDA model cache memory query failed: %s\n", + cudaGetErrorString(mem_err)); + (void)cudaGetLastError(); + g_model_cache_full = 1; + return NULL; + } + const uint64_t free_bytes = (uint64_t)free_b; + const uint64_t reserve = cuda_model_cache_reserve_bytes((uint64_t)total_b); + if (free_bytes <= reserve || chunk > free_bytes - reserve) { + g_model_cache_full = 1; + if (getenv("DS4_CUDA_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, + "ds4: CUDA partial model cache stopped before %s " + "(request %.2f MiB free %.2f GiB reserve %.2f GiB)\n", + what ? 
what : "weights", + (double)chunk / 1048576.0, + (double)free_bytes / 1073741824.0, + (double)reserve / 1073741824.0); + } + return NULL; + } + } void *dev = NULL; cudaError_t err = cudaMalloc(&dev, (size_t)chunk); if (err != cudaSuccess) { @@ -999,7 +1154,8 @@ static const char *cuda_model_range_ptr_from_fd( const void *model_map, uint64_t offset, uint64_t bytes, - const char *what) { + const char *what, + int allow_direct_fallback) { if (g_model_fd < 0 || bytes == 0) return NULL; if (g_model_fd_host_base != NULL && model_map != g_model_fd_host_base) return NULL; const uint64_t limit = cuda_model_cache_limit_bytes(); @@ -1010,12 +1166,13 @@ static const char *cuda_model_range_ptr_from_fd( (double)bytes / 1048576.0, (double)limit / 1073741824.0); } + if (!allow_direct_fallback || getenv("DS4_CUDA_STRICT_WEIGHT_CACHE") != NULL) return NULL; return cuda_model_ptr(model_map, offset); } char *dev = cuda_model_arena_alloc(bytes, what); if (!dev) { - if (getenv("DS4_CUDA_STRICT_WEIGHT_CACHE") != NULL) return NULL; + if (!allow_direct_fallback || getenv("DS4_CUDA_STRICT_WEIGHT_CACHE") != NULL) return NULL; return cuda_model_ptr(model_map, offset); } cudaError_t err = cudaSuccess; @@ -1091,6 +1248,63 @@ static const char *cuda_model_range_ptr_from_fd( return (const char *)dev; } +static const char *cuda_model_range_cache_device( + const void *model_map, + uint64_t model_size, + uint64_t offset, + uint64_t bytes, + const char *what) { + if (!model_map || bytes == 0) return NULL; + if (offset > model_size || bytes > model_size - offset) return NULL; + if (model_map == g_model_host_base && g_model_device_owned) { + return cuda_model_ptr(model_map, offset); + } + + const char *cached = cuda_model_range_lookup_device_cached(model_map, offset, bytes); + if (cached) return cached; + + if (getenv("DS4_CUDA_NO_FD_CACHE") == NULL) { + const char *fd_ptr = cuda_model_range_ptr_from_fd(model_map, offset, bytes, what, 0); + if (fd_ptr) { + cached = cuda_model_range_lookup_device_cached(model_map, offset, bytes); + if (cached) return cached; + } + } + + char *dev = cuda_model_arena_alloc(bytes, what); + if (!dev) return NULL; + + const char *src = (const char *)model_map + offset; + const uint64_t chunk = cuda_model_copy_chunk_bytes(); + for (uint64_t done = 0; done < bytes; done += chunk) { + const uint64_t n = bytes - done < chunk ? bytes - done : chunk; + cudaError_t err = cudaMemcpy(dev + done, src + done, (size_t)n, cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + fprintf(stderr, "ds4: CUDA model forced cache copy failed for %s at %.2f/%.2f MiB: %s\n", + what ? what : "weights", + (double)done / 1048576.0, + (double)bytes / 1048576.0, + cudaGetErrorString(err)); + (void)cudaGetLastError(); + g_model_cache_full = 1; + return NULL; + } + } + + g_model_ranges.push_back({model_map, offset, bytes, dev, NULL, NULL, 0, 0, 1}); + g_model_range_by_offset[offset] = g_model_ranges.size() - 1u; + g_model_range_bytes += bytes; + cuda_model_load_progress_note(g_model_range_bytes); + cuda_model_discard_source_pages(model_map, model_size, offset, bytes); + if (getenv("DS4_CUDA_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: CUDA cached %s %.2f MiB (total %.2f GiB)\n", + what ? 
what : "weights", + (double)bytes / 1048576.0, + (double)g_model_range_bytes / 1073741824.0); + } + return dev; +} + static int cuda_model_copy_chunked(const void *model_map, uint64_t model_size, uint64_t map_offset, uint64_t map_size) { if (!model_map || model_size == 0 || map_offset > model_size || map_size > model_size - map_offset) return 0; if (getenv("DS4_CUDA_NO_MODEL_COPY") != NULL || @@ -1099,7 +1313,7 @@ static int cuda_model_copy_chunked(const void *model_map, uint64_t model_size, u getenv("DS4_CUDA_WEIGHT_PRELOAD") != NULL) { return 0; } - if (g_model_device_owned || g_model_registered) return 1; + if (model_map == g_model_host_base && (g_model_device_owned || g_model_registered)) return 1; void *dev = NULL; const double t0 = cuda_wall_sec(); @@ -1193,6 +1407,8 @@ static void cuda_model_range_release_all(void) { g_model_ranges.clear(); g_model_range_by_offset.clear(); g_model_range_bytes = 0; + g_model_auto_cache_limit_bytes = 0; + g_model_auto_cache_limit_ready = 0; cuda_model_load_progress_reset(); } @@ -1281,6 +1497,8 @@ extern "C" void ds4_gpu_cleanup(void) { g_model_direct_align = 1; g_model_file_size = 0; g_model_cache_full = 0; + g_model_auto_cache_limit_bytes = 0; + g_model_auto_cache_limit_ready = 0; if (g_model_prefetch_stream) { (void)cudaStreamDestroy(g_model_prefetch_stream); g_model_prefetch_stream = NULL; @@ -1539,8 +1757,17 @@ extern "C" int ds4_gpu_set_model_fd(int fd) { extern "C" int ds4_gpu_cache_model_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, const char *label) { if (!model_map || bytes == 0) return 1; if (offset > model_size || bytes > model_size - offset) return 0; - if (!cuda_model_range_ptr(model_map, offset, bytes, label ? label : "model_tensor")) return 0; - return cuda_model_range_is_cached(model_map, offset, bytes); + const char *cache_label = label ? label : "model_tensor"; + const char *ptr = NULL; + const int force_device_cache = cuda_partial_weight_cache_enabled() || cuda_direct_model_enabled(); + if (force_device_cache) { + ptr = cuda_model_range_cache_device(model_map, model_size, offset, bytes, cache_label); + } else { + ptr = cuda_model_range_ptr(model_map, offset, bytes, cache_label); + } + if (!ptr) return 0; + if (force_device_cache) return cuda_model_range_is_device_cached(model_map, offset, bytes); + return cuda_model_range_lookup_cached(model_map, offset, bytes) != NULL; } extern "C" int ds4_gpu_cache_q8_f16_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, uint64_t in_dim, uint64_t out_dim, const char *label) { @@ -1564,8 +1791,18 @@ extern "C" int ds4_gpu_cache_q8_f16_range(const void *model_map, uint64_t model_ extern "C" void ds4_gpu_print_memory_report(const char *label) { size_t free_b = 0, total_b = 0; (void)cudaMemGetInfo(&free_b, &total_b); - fprintf(stderr, "ds4: CUDA memory report %s: free %.2f MiB total %.2f MiB\n", - label ? label : "", (double)free_b / 1048576.0, (double)total_b / 1048576.0); + uint64_t arena_bytes = 0; + for (const cuda_model_arena &a : g_model_arenas) arena_bytes += a.bytes; + fprintf(stderr, + "ds4: CUDA memory report %s: free %.2f MiB total %.2f MiB " + "model-cache %.2f GiB arenas %.2f GiB q8-f16 %.2f GiB q8-f32 %.2f GiB\n", + label ? 
label : "", + (double)free_b / 1048576.0, + (double)total_b / 1048576.0, + (double)g_model_range_bytes / 1073741824.0, + (double)arena_bytes / 1073741824.0, + (double)g_q8_f16_bytes / 1073741824.0, + (double)g_q8_f32_bytes / 1073741824.0); } extern "C" void ds4_gpu_set_quality(bool quality) { diff --git a/tests/cuda_long_context_smoke.c b/tests/cuda_long_context_smoke.c index c9a8049d..58de4a53 100644 --- a/tests/cuda_long_context_smoke.c +++ b/tests/cuda_long_context_smoke.c @@ -1,10 +1,13 @@ #include "ds4_gpu.h" +#include #include #include #include #include +#define TEST_DS4_RMS_EPS 1.0e-6f + static double monotonic_seconds(void) { struct timespec ts; clock_gettime(CLOCK_MONOTONIC, &ts); @@ -148,10 +151,78 @@ static int check_decode_attention_overflow_path(void) { return rc; } +static int check_partial_direct_model_cache_precedence(void) { + if (setenv("DS4_CUDA_DIRECT_MODEL", "1", 1) != 0 || + setenv("DS4_CUDA_PARTIAL_WEIGHT_CACHE", "1", 1) != 0 || + setenv("DS4_CUDA_WEIGHT_CACHE_LIMIT_MB", "16", 1) != 0) { + return 1; + } + + const float rms_weight[4] = {2.0f, 3.0f, 4.0f, 5.0f}; + static float host_model[4]; + const uint64_t model_size = sizeof(host_model); + float out_host[4] = {0}; + for (uint32_t i = 0; i < 4; i++) host_model[i] = rms_weight[i]; + + ds4_gpu_tensor *x = ds4_gpu_tensor_alloc(sizeof(rms_weight)); + ds4_gpu_tensor *out = ds4_gpu_tensor_alloc(sizeof(rms_weight)); + int rc = 1; + const float x_host[4] = {1.0f, 1.0f, 1.0f, 1.0f}; + + if (x && out && + ds4_gpu_set_model_map(host_model, model_size) && + ds4_gpu_tensor_write(x, 0, x_host, sizeof(x_host)) && + ds4_gpu_cache_model_range(host_model, model_size, 0, model_size, "test_partial_rms")) { + for (uint32_t i = 0; i < 4; i++) host_model[i] = -100.0f - (float)i; + if (ds4_gpu_rms_norm_weight_tensor(out, x, host_model, model_size, 0, 4, TEST_DS4_RMS_EPS) && + ds4_gpu_synchronize() && + ds4_gpu_tensor_read(out, 0, out_host, sizeof(out_host))) { + rc = 0; + const float scale = 1.0f / sqrtf(1.0f + TEST_DS4_RMS_EPS); + for (uint32_t i = 0; i < 4; i++) { + const float want = rms_weight[i] * scale; + if (fabsf(out_host[i] - want) > 1.0e-3f) { + fprintf(stderr, + "partial direct cache rms mismatch index=%u got=%f expected=%f\n", + i, + (double)out_host[i], + (double)want); + rc = 1; + } + } + } + } + + ds4_gpu_tensor_free(out); + ds4_gpu_tensor_free(x); + return rc; +} + +static int check_partial_direct_model_cache_budget_miss(void) { + if (setenv("DS4_CUDA_DIRECT_MODEL", "1", 1) != 0 || + setenv("DS4_CUDA_PARTIAL_WEIGHT_CACHE", "1", 1) != 0 || + setenv("DS4_CUDA_WEIGHT_CACHE_LIMIT_MB", "1", 1) != 0) { + return 1; + } + + const uint64_t bytes = 2ull * 1024ull * 1024ull; + unsigned char *host_model = (unsigned char *)malloc((size_t)bytes); + if (!host_model) return 1; + const int cached = ds4_gpu_cache_model_range(host_model, bytes, 0, bytes, "test_partial_budget"); + free(host_model); + if (cached != 0) { + fprintf(stderr, "partial direct cache budget test unexpectedly cached 2 MiB under 1 MiB limit\n"); + return 1; + } + return 0; +} + int main(void) { if (!ds4_gpu_init()) return 1; int rc = check_large_topk(); if (check_decode_attention_overflow_path() != 0) rc = 1; + if (check_partial_direct_model_cache_precedence() != 0) rc = 1; + if (check_partial_direct_model_cache_budget_miss() != 0) rc = 1; ds4_gpu_cleanup(); if (rc == 0) puts("cuda long-context regression: OK"); return rc;