diff --git a/quant.h b/quant.h
index 15ad224..b0e6358 100644
--- a/quant.h
+++ b/quant.h
@@ -15674,18 +15674,36 @@ int tq_generate_continue(tq_model_t* model,
         return -1;
     }
 
-    /* Encode new prompt */
-    int new_tokens[4096];
+    /* Heap-allocated prompt token buffer (was a 4096-entry stack array, which
+     * silently truncated after ~10 turns of accumulating chat history).
+     * Cap at the model's max_seq_len so we never exceed KV bounds. */
+    int max_prompt = model->config.max_seq_len > 0
+        ? model->config.max_seq_len : 4096;
+    int* new_tokens = (int*)malloc((size_t)max_prompt * sizeof(int));
+    if (!new_tokens) return -1;
     int n_new = 0;
     if (tokenizer && prompt) {
         int add_bos = (model->config.model_type == 1) ? 1 : 0;
-        n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
+        n_new = tq_encode(tokenizer, prompt, new_tokens, max_prompt, add_bos);
     }
     if (n_new <= 0) {
         new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
         n_new = 1;
     }
 
+    /* Sliding window: drop oldest prompt tokens if the new prompt would
+     * leave no room for max_tokens of generation. Keeps the most recent
+     * tokens. Forces full reprefill since the prefix shifted. */
+    int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
+    int budget = max_prompt - reserve - 32;
+    if (budget < 64) budget = 64;
+    if (n_new > budget) {
+        int drop = n_new - budget;
+        memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int));
+        n_new = budget;
+        *n_cached_io = 0;
+    }
+
     /* Find longest common prefix with the cached tokens.
      * If the new prompt is just an extension of the cached one, we skip
      * everything up to the LCP and only prefill the suffix. */
@@ -15694,19 +15712,13 @@ int tq_generate_continue(tq_model_t* model,
     int n_cached = *n_cached_io;
     int* cached_tokens = *cached_tokens_io;
     int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);
-    /* If the cached tokens go beyond the LCP (i.e., the new prompt diverges
-     * from history mid-way, e.g., user edited a previous message), we have
-     * to invalidate the divergent suffix. The simplest correct option is to
-     * roll the state position back to lcp. The KV cache itself doesn't need
-     * to be cleared — positions >= lcp will just be overwritten when we
-     * prefill the new suffix. */
-    int pos_start = lcp;
-
-    /* Prefill the new suffix */
+    /* Prefill the new suffix [lcp, n_new) */
     for (int i = lcp; i < n_new; i++) {
         tq_forward(model, state, new_tokens[i], i);
     }
     int pos = n_new;
+    int prefill_tokens = n_new - lcp;
+    int prefix_hit = lcp;
 
     /* Save the n_new prompt into the cache buffer (will append generated
      * tokens below). Grow the buffer if needed. */
@@ -15714,7 +15726,7 @@ int tq_generate_continue(tq_model_t* model,
     if (*cached_capacity_io < needed_cap) {
         int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
         int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
-        if (!nb) return -1;
+        if (!nb) { free(new_tokens); return -1; }
         *cached_tokens_io = nb;
         *cached_capacity_io = new_cap;
         cached_tokens = nb;
@@ -15825,6 +15837,14 @@ int tq_generate_continue(tq_model_t* model,
     if (output && output_size > 0) {
         output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
     }
+
+    if (getenv("TQ_CHAT_DEBUG")) {
+        fprintf(stderr,
+                "[chat] prefix_hit=%d prefill=%d generated=%d cached=%d\n",
+                prefix_hit, prefill_tokens, generated, *n_cached_io);
+    }
+
+    free(new_tokens);
     return generated;
 }
 
diff --git a/src/engine/tq_generate.c b/src/engine/tq_generate.c
index 64f8515..fc5ea3b 100644
--- a/src/engine/tq_generate.c
+++ b/src/engine/tq_generate.c
@@ -630,18 +630,40 @@ int tq_generate_continue(tq_model_t* model,
         return -1;
     }
 
-    /* Encode new prompt */
-    int new_tokens[4096];
+    /* Encode new prompt — use a heap buffer that grows on demand instead
+     * of a fixed stack array. The previous int new_tokens[4096] silently
+     * truncated long contexts (10+ turns of accumulated chat history).
+     * Cap at the model's max_seq_len so we never exceed KV cache bounds. */
+    int max_prompt = model->config.max_seq_len > 0
+        ? model->config.max_seq_len : 4096;
+    int* new_tokens = (int*)malloc((size_t)max_prompt * sizeof(int));
+    if (!new_tokens) return -1;
     int n_new = 0;
     if (tokenizer && prompt) {
         int add_bos = (model->config.model_type == 1) ? 1 : 0;
-        n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
+        n_new = tq_encode(tokenizer, prompt, new_tokens, max_prompt, add_bos);
     }
     if (n_new <= 0) {
         new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
         n_new = 1;
     }
 
+    /* Sliding window: if the new prompt + reserved generation room would
+     * exceed max_seq_len, drop the oldest tokens from the front of the
+     * prompt. We keep the most recent (max_seq_len - max_tokens - 32) tokens.
+     * Note: this discards conversation history; ideally callers send
+     * pre-trimmed prompts, but this prevents catastrophic failure. */
+    int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
+    int budget = max_prompt - reserve - 32;
+    if (budget < 64) budget = 64;
+    if (n_new > budget) {
+        int drop = n_new - budget;
+        memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int));
+        n_new = budget;
+        /* Force full reprefill since the prefix shifted */
+        *n_cached_io = 0;
+    }
+
     int n_cached = *n_cached_io;
     int* cached_tokens = *cached_tokens_io;
     int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);
@@ -652,12 +674,16 @@ int tq_generate_continue(tq_model_t* model,
     }
     int pos = n_new;
 
+    /* Track prefill metrics for observability */
+    int prefill_tokens = n_new - lcp;
+    int prefix_hit = lcp;
+
     /* Grow cache buffer if needed */
     int needed_cap = n_new + config->max_tokens + 16;
     if (*cached_capacity_io < needed_cap) {
         int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
         int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
-        if (!nb) return -1;
+        if (!nb) { free(new_tokens); return -1; }
         *cached_tokens_io = nb;
         *cached_capacity_io = new_cap;
         cached_tokens = nb;
@@ -764,5 +790,15 @@ int tq_generate_continue(tq_model_t* model,
     if (output && output_size > 0) {
         output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
     }
+
+    /* Log cache metrics: prefix_hit / prefill_tokens / generated.
+     * Useful for tuning chat clients that want to maximize KV reuse. */
+    if (getenv("TQ_CHAT_DEBUG")) {
+        fprintf(stderr,
+                "[chat] prefix_hit=%d prefill=%d generated=%d cached=%d\n",
+                prefix_hit, prefill_tokens, generated, *n_cached_io);
+    }
+
+    free(new_tokens);
     return generated;
 }
diff --git a/src/server/tq_server.c b/src/server/tq_server.c
index 73b811c..f41d6b9 100644
--- a/src/server/tq_server.c
+++ b/src/server/tq_server.c
@@ -73,6 +73,31 @@ typedef volatile long atomic_int;
  * Server state
  * ============================================================ */
 
+/* ============================================================
+ * Per-session KV cache for multi-client chat reuse
+ *
+ * Each client identifies itself with X-Session-Id header (or the
+ * "user" field in the request body, OpenAI-compatible). Sessions are
+ * stored in a small LRU table; the least recently used is evicted
+ * when MAX_SESSIONS is reached.
+ *
+ * Without this, two concurrent chat clients would corrupt each
+ * other's KV cache. The inference_mutex still serializes per-token
+ * forward passes (single model weights), but the cache state is
+ * now per-session.
+ * ============================================================ */
+#define MAX_SESSIONS 16
+#define SESSION_ID_MAX 64
+
+typedef struct {
+    char id[SESSION_ID_MAX];  /* "" = unused slot */
+    tq_state_t* kv_state;
+    int* cached_tokens;
+    int n_cached;
+    int cached_capacity;
+    long last_used;           /* monotonic counter for LRU */
+} kv_session_t;
+
 struct tq_server {
     tq_server_config_t config;
     int listen_fd;
@@ -80,14 +105,54 @@ struct tq_server {
     atomic_int active_connections;   /* track concurrent threads */
     pthread_mutex_t inference_mutex; /* serialize inference (single model state) */
 
-    /* Persistent inference state — shared across requests for chat-mode
-     * KV cache reuse. The inference_mutex above serializes access. */
-    tq_state_t* kv_state;
-    int* cached_tokens;
-    int n_cached;
-    int cached_capacity;
+    kv_session_t sessions[MAX_SESSIONS];
+    long session_clock;
 };
 
+/* Find or allocate a session by id. Caller holds inference_mutex.
+ * Returns a pointer into server->sessions. Never NULL (LRU evicts). */
+static kv_session_t* get_or_create_session(tq_server_t* server,
+                                           const char* sid,
+                                           tq_type kv_type,
+                                           int value_quant_bits) {
+    if (!sid || !sid[0]) sid = "default";
+    server->session_clock++;
+
+    int empty_slot = -1;
+    int lru_slot = 0;
+    long lru_time = server->sessions[0].last_used;
+
+    for (int i = 0; i < MAX_SESSIONS; i++) {
+        if (server->sessions[i].id[0] == '\0') {
+            if (empty_slot < 0) empty_slot = i;
+            continue;
+        }
+        if (strncmp(server->sessions[i].id, sid, SESSION_ID_MAX) == 0) {
+            server->sessions[i].last_used = server->session_clock;
+            return &server->sessions[i];
+        }
+        if (server->sessions[i].last_used < lru_time) {
+            lru_time = server->sessions[i].last_used;
+            lru_slot = i;
+        }
+    }
+
+    /* Not found — pick empty slot or evict LRU */
+    int slot = empty_slot >= 0 ? empty_slot : lru_slot;
+    kv_session_t* s = &server->sessions[slot];
+
+    /* Free old session contents (if any) */
+    if (s->kv_state) tq_free_state(s->kv_state);
+    if (s->cached_tokens) free(s->cached_tokens);
+
+    memset(s, 0, sizeof(*s));
+    strncpy(s->id, sid, SESSION_ID_MAX - 1);
+    s->kv_state = tq_create_state_ex(
+        &server->config.model->config, kv_type, value_quant_bits);
+    s->last_used = server->session_clock;
+    return s;
+}
+
 /* Global server pointer for signal handler */
 static tq_server_t* g_server = NULL;
@@ -226,6 +291,10 @@ typedef struct {
 
     /* Built prompt */
     char* prompt;            /* heap-allocated */
+
+    /* Session id for KV cache reuse (OpenAI 'user' field).
+     * Empty = "default" session. */
+    char session_id[64];
 } chat_request_t;
 
 static void free_chat_request(chat_request_t* req) {
@@ -374,6 +443,13 @@ static int parse_chat_request(const char* body, chat_request_t* req) {
     v = json_find_key(body, "delta_kv");
     if (v) req->delta_kv = json_extract_bool(v);
 
+    /* OpenAI-compatible 'user' field doubles as our session id for KV
+     * cache reuse. Clients that pass the same user across turns get
+     * O(delta) prefill cost; clients that don't pass one share the
+     * "default" slot (still works for single-user demos). */
+    v = json_find_key(body, "user");
+    if (v) json_extract_string(v, req->session_id, sizeof(req->session_id));
+
     /* Parse messages */
     v = json_find_key(body, "messages");
     if (!v) {
@@ -673,18 +749,19 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
         gen_cfg.user_data = &sse_ctx;
         char output[1]; /* writes via callback, output buffer unused */
 
-        /* Use tq_generate_continue with persistent KV state for chat reuse:
-         * matches the longest common prefix of req.prompt against
-         * server->cached_tokens, prefills only the suffix. Turns chat
-         * latency from O(history^2) into O(new_tokens). */
-        if (!server->kv_state) {
-            server->kv_state = tq_create_state_ex(
-                &server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
-        }
+        /* Per-session KV cache reuse:
+         * - Sessions are keyed by req.session_id (OpenAI 'user' field).
+         * - Each session has its own kv_state + cached_tokens.
+         * - LRU evicts the least recently used when the table is full.
+         * - The longest common prefix between cached tokens and the new
+         *   prompt is reused; only the suffix is prefilled. */
+        kv_session_t* sess = get_or_create_session(server, req.session_id,
+                                                   gen_cfg.kv_type,
+                                                   gen_cfg.value_quant_bits);
         tq_generate_continue(server->config.model, server->config.tokenizer,
-                             server->kv_state, req.prompt, &gen_cfg,
-                             &server->cached_tokens, &server->n_cached,
-                             &server->cached_capacity,
+                             sess->kv_state, req.prompt, &gen_cfg,
+                             &sess->cached_tokens, &sess->n_cached,
+                             &sess->cached_capacity,
                              output, sizeof(output));
 
         /* Send final chunk with finish_reason */
@@ -715,14 +792,13 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
         gen_cfg.user_data = &collect;
         char output[1];
 
-        if (!server->kv_state) {
-            server->kv_state = tq_create_state_ex(
-                &server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
-        }
+        kv_session_t* sess = get_or_create_session(server, req.session_id,
+                                                   gen_cfg.kv_type,
+                                                   gen_cfg.value_quant_bits);
         tq_generate_continue(server->config.model, server->config.tokenizer,
-                             server->kv_state, req.prompt, &gen_cfg,
-                             &server->cached_tokens, &server->n_cached,
-                             &server->cached_capacity,
+                             sess->kv_state, req.prompt, &gen_cfg,
+                             &sess->cached_tokens, &sess->n_cached,
+                             &sess->cached_capacity,
                              output, sizeof(output));
 
         const char* content = collect.buf ? collect.buf : "";
@@ -1180,8 +1256,11 @@ void tq_server_stop(tq_server_t* server) {
 void tq_server_free(tq_server_t* server) {
     if (!server) return;
     pthread_mutex_destroy(&server->inference_mutex);
-    if (server->kv_state) tq_free_state(server->kv_state);
-    if (server->cached_tokens) free(server->cached_tokens);
+    /* Free all session KV caches */
+    for (int i = 0; i < MAX_SESSIONS; i++) {
+        if (server->sessions[i].kv_state) tq_free_state(server->sessions[i].kv_state);
+        if (server->sessions[i].cached_tokens) free(server->sessions[i].cached_tokens);
+    }
     if (g_server == server) g_server = NULL;
     free(server);
 }