46 changes: 33 additions & 13 deletions quant.h
@@ -15674,18 +15674,36 @@ int tq_generate_continue(tq_model_t* model,
return -1;
}

/* Encode new prompt */
int new_tokens[4096];
/* Heap-allocated prompt token buffer (was a 4096-entry stack array,
* which silently truncated after ~10 turns of accumulated chat history).
* Cap at the model's max_seq_len so we never exceed KV cache bounds. */
int max_prompt = model->config.max_seq_len > 0
? model->config.max_seq_len : 4096;
int* new_tokens = (int*)malloc((size_t)max_prompt * sizeof(int));
if (!new_tokens) return -1;
int n_new = 0;
if (tokenizer && prompt) {
int add_bos = (model->config.model_type == 1) ? 1 : 0;
n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
n_new = tq_encode(tokenizer, prompt, new_tokens, max_prompt, add_bos);
}
if (n_new <= 0) {
new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
n_new = 1;
}

/* Sliding window: drop oldest prompt tokens if the new prompt would
* leave no room for max_tokens of generation. Keeps the most recent
* tokens. Forces full reprefill since the prefix shifted. */
int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
int budget = max_prompt - reserve - 32;
if (budget < 64) budget = 64;
if (n_new > budget) {
int drop = n_new - budget;
memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int));
n_new = budget;
*n_cached_io = 0;
}

/* Find longest common prefix with the cached tokens.
* If the new prompt is just an extension of the cached one, we skip
* everything up to the LCP and only prefill the suffix. */
@@ -15694,27 +15712,21 @@ int tq_generate_continue(tq_model_t* model,

int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);

/* If the cached tokens go beyond the LCP (i.e., the new prompt diverges
* from history mid-way, e.g., user edited a previous message), we have
* to invalidate the divergent suffix. The simplest correct option is to
* roll the state position back to lcp. The KV cache itself doesn't need
* to be cleared — positions >= lcp will just be overwritten when we
* prefill the new suffix. */
int pos_start = lcp;

/* Prefill the new suffix */
/* Prefill the new suffix [lcp, n_new) */
for (int i = lcp; i < n_new; i++) {
tq_forward(model, state, new_tokens[i], i);
}
int pos = n_new;
int prefill_tokens = n_new - lcp;
int prefix_hit = lcp;

/* Save the n_new prompt tokens into the cache buffer (generated
* tokens are appended below). Grow the buffer if needed. */
int needed_cap = n_new + config->max_tokens + 16;
if (*cached_capacity_io < needed_cap) {
int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
if (!nb) return -1;
if (!nb) { free(new_tokens); return -1; }
*cached_tokens_io = nb;
*cached_capacity_io = new_cap;
cached_tokens = nb;
@@ -15825,6 +15837,14 @@ int tq_generate_continue(tq_model_t* model,
if (output && output_size > 0) {
output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
}

if (getenv("TQ_CHAT_DEBUG")) {
fprintf(stderr,
"[chat] prefix_hit=%d prefill=%d generated=%d cached=%d\n",
prefix_hit, prefill_tokens, generated, *n_cached_io);
}

free(new_tokens);
return generated;
}
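The prefix-reuse logic above leans on tq_lcp_int, whose body is not part of this diff. For orientation, a minimal sketch of what such a longest-common-prefix helper over two token arrays presumably looks like (the const qualifiers and exact signature are assumptions):

static int tq_lcp_int(const int* a, int n_a, const int* b, int n_b) {
    /* Count how many leading token ids the two sequences share. */
    int n = n_a < n_b ? n_a : n_b;
    int i = 0;
    while (i < n && a[i] == b[i]) {
        i++;
    }
    return i;
}

With a helper like this, an unchanged chat history extended by one new user turn yields lcp == n_cached, so only the new turn's tokens are prefilled.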

44 changes: 40 additions & 4 deletions src/engine/tq_generate.c
@@ -630,18 +630,40 @@ int tq_generate_continue(tq_model_t* model,
return -1;
}

/* Encode new prompt */
int new_tokens[4096];
/* Encode new prompt — use a heap buffer that grows on demand instead
* of a fixed stack array. The previous int new_tokens[4096] silently
* truncated long contexts (10+ turns of accumulated chat history).
* Cap at the model's max_seq_len so we never exceed KV cache bounds. */
int max_prompt = model->config.max_seq_len > 0
? model->config.max_seq_len : 4096;
int* new_tokens = (int*)malloc((size_t)max_prompt * sizeof(int));
if (!new_tokens) return -1;
int n_new = 0;
if (tokenizer && prompt) {
int add_bos = (model->config.model_type == 1) ? 1 : 0;
n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
n_new = tq_encode(tokenizer, prompt, new_tokens, max_prompt, add_bos);
}
if (n_new <= 0) {
new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
n_new = 1;
}

/* Sliding window: if the new prompt + reserved generation room would
* exceed max_seq_len, drop the oldest tokens from the front of the
* prompt. We keep the most recent (max_seq_len - max_tokens - 32) tokens.
* Note: this discards conversation history; ideally callers send
* pre-trimmed prompts, but this prevents catastrophic failure. */
int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
int budget = max_prompt - reserve - 32;
if (budget < 64) budget = 64;
if (n_new > budget) {
int drop = n_new - budget;
memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int));
n_new = budget;
/* Force full reprefill since the prefix shifted */
*n_cached_io = 0;
}

int n_cached = *n_cached_io;
int* cached_tokens = *cached_tokens_io;
int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);
@@ -652,12 +674,16 @@
}
int pos = n_new;

/* Track prefill metrics for observability */
int prefill_tokens = n_new - lcp;
int prefix_hit = lcp;

/* Grow cache buffer if needed */
int needed_cap = n_new + config->max_tokens + 16;
if (*cached_capacity_io < needed_cap) {
int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
if (!nb) return -1;
if (!nb) { free(new_tokens); return -1; }
*cached_tokens_io = nb;
*cached_capacity_io = new_cap;
cached_tokens = nb;
@@ -764,5 +790,15 @@ int tq_generate_continue(tq_model_t* model,
if (output && output_size > 0) {
output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
}

/* Log cache metrics: prefix_hit / prefill_tokens / generated.
* Useful for tuning chat clients that want to maximize KV reuse. */
if (getenv("TQ_CHAT_DEBUG")) {
fprintf(stderr,
"[chat] prefix_hit=%d prefill=%d generated=%d cached=%d\n",
prefix_hit, prefill_tokens, generated, *n_cached_io);
}

free(new_tokens);
return generated;
}
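Outside the diff itself, the caller-side contract is worth spelling out: kv_state and the cached-token triple must persist across calls, because whatever tq_generate_continue leaves in *cached_tokens_io / *n_cached_io is what the next call's prefix match runs against. A rough sketch of a multi-turn driver, with the argument order taken from the call sites later in this PR (model, tokenizer, and gen_cfg are built elsewhere and their setup is not shown here; build_prompt_for_turn is a hypothetical helper):

/* Persisted across turns — this is what enables KV prefix reuse. */
tq_state_t* kv_state = tq_create_state_ex(&model->config, gen_cfg.kv_type,
                                          gen_cfg.value_quant_bits);
int* cached_tokens = NULL;
int n_cached = 0;
int cached_capacity = 0;
char output[8192];

for (int turn = 0; turn < n_turns; turn++) {
    /* prompt = the full chat transcript rendered up to this turn. */
    const char* prompt = build_prompt_for_turn(turn);  /* hypothetical helper */
    tq_generate_continue(model, tokenizer, kv_state, prompt, &gen_cfg,
                         &cached_tokens, &n_cached, &cached_capacity,
                         output, sizeof(output));
    /* Only the suffix past the longest common prefix with the previous
     * turn is prefilled; set TQ_CHAT_DEBUG=1 to see the split on stderr. */
}

free(cached_tokens);
tq_free_state(kv_state);

On the first call cached_tokens is NULL and n_cached is 0, so the whole prompt is prefilled and the buffer is allocated by the realloc growth path inside the function.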
131 changes: 105 additions & 26 deletions src/server/tq_server.c
@@ -73,21 +73,86 @@ typedef volatile long atomic_int;
* Server state
* ============================================================ */

/* ============================================================
* Per-session KV cache for multi-client chat reuse
*
* Each client identifies itself with X-Session-Id header (or the
* "user" field in the request body, OpenAI-compatible). Sessions are
* stored in a small LRU table; the least recently used is evicted
* when MAX_SESSIONS is reached.
*
* Without this, two concurrent chat clients would corrupt each
* other's KV cache. The inference_mutex still serializes per-token
* forward passes (single model weights), but the cache state is
* now per-session.
* ============================================================ */
#define MAX_SESSIONS 16
#define SESSION_ID_MAX 64

typedef struct {
char id[SESSION_ID_MAX]; /* "" = unused slot */
tq_state_t* kv_state;
int* cached_tokens;
int n_cached;
int cached_capacity;
long last_used; /* monotonic counter for LRU */
} kv_session_t;

struct tq_server {
tq_server_config_t config;
int listen_fd;
atomic_int running;
atomic_int active_connections; /* track concurrent threads */
pthread_mutex_t inference_mutex; /* serialize inference (single model state) */

/* Persistent inference state — shared across requests for chat-mode
* KV cache reuse. The inference_mutex above serializes access. */
tq_state_t* kv_state;
int* cached_tokens;
int n_cached;
int cached_capacity;
kv_session_t sessions[MAX_SESSIONS];
long session_clock;
};

/* Find or allocate a session by id. Caller holds inference_mutex.
* Returns a pointer into server->sessions. Never NULL (LRU evicts). */
static kv_session_t* get_or_create_session(tq_server_t* server,
const char* sid,
tq_type kv_type,
int value_quant_bits) {
if (!sid || !sid[0]) sid = "default";
server->session_clock++;

int empty_slot = -1;
int lru_slot = 0;
long lru_time = server->sessions[0].last_used;

for (int i = 0; i < MAX_SESSIONS; i++) {
if (server->sessions[i].id[0] == '\0') {
if (empty_slot < 0) empty_slot = i;
continue;
}
if (strncmp(server->sessions[i].id, sid, SESSION_ID_MAX) == 0) {
server->sessions[i].last_used = server->session_clock;
return &server->sessions[i];
}
if (server->sessions[i].last_used < lru_time) {
lru_time = server->sessions[i].last_used;
lru_slot = i;
}
}

/* Not found — pick empty slot or evict LRU */
int slot = empty_slot >= 0 ? empty_slot : lru_slot;
kv_session_t* s = &server->sessions[slot];

/* Free old session contents (if any) */
if (s->kv_state) tq_free_state(s->kv_state);
if (s->cached_tokens) free(s->cached_tokens);

memset(s, 0, sizeof(*s));
strncpy(s->id, sid, SESSION_ID_MAX - 1);
s->kv_state = tq_create_state_ex(
&server->config.model->config, kv_type, value_quant_bits);
s->last_used = server->session_clock;
return s;
}

/* Global server pointer for signal handler */
static tq_server_t* g_server = NULL;

@@ -226,6 +291,10 @@ typedef struct {

/* Built prompt */
char* prompt; /* heap-allocated */

/* Session id for KV cache reuse (OpenAI 'user' field).
* Empty = "default" session. */
char session_id[64];
} chat_request_t;

static void free_chat_request(chat_request_t* req) {
@@ -374,6 +443,13 @@ static int parse_chat_request(const char* body, chat_request_t* req) {
v = json_find_key(body, "delta_kv");
if (v) req->delta_kv = json_extract_bool(v);

/* The OpenAI-compatible 'user' field doubles as our session id for KV
* cache reuse. Clients that pass the same user across turns get
* O(delta) prefill cost; clients that don't pass one share the
* "default" slot (which still works for single-user demos). */
v = json_find_key(body, "user");
if (v) json_extract_string(v, req->session_id, sizeof(req->session_id));

/* Parse messages */
v = json_find_key(body, "messages");
if (!v) {
@@ -673,18 +749,19 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
gen_cfg.user_data = &sse_ctx;

char output[1]; /* writes via callback, output buffer unused */
/* Use tq_generate_continue with persistent KV state for chat reuse:
* matches the longest common prefix of req.prompt against
* server->cached_tokens, prefills only the suffix. Turns chat
* latency from O(history^2) into O(new_tokens). */
if (!server->kv_state) {
server->kv_state = tq_create_state_ex(
&server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
}
/* Per-session KV cache reuse:
* - Sessions are keyed by req.session_id (OpenAI 'user' field).
* - Each session has its own kv_state + cached_tokens.
* - LRU evicts the least recently used when the table is full.
* - The longest common prefix between cached tokens and the new
* prompt is reused; only the suffix is prefilled. */
kv_session_t* sess = get_or_create_session(server, req.session_id,
gen_cfg.kv_type,
gen_cfg.value_quant_bits);
tq_generate_continue(server->config.model, server->config.tokenizer,
server->kv_state, req.prompt, &gen_cfg,
&server->cached_tokens, &server->n_cached,
&server->cached_capacity,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));

/* Send final chunk with finish_reason */
@@ -715,14 +792,13 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
gen_cfg.user_data = &collect;

char output[1];
if (!server->kv_state) {
server->kv_state = tq_create_state_ex(
&server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
}
kv_session_t* sess = get_or_create_session(server, req.session_id,
gen_cfg.kv_type,
gen_cfg.value_quant_bits);
tq_generate_continue(server->config.model, server->config.tokenizer,
server->kv_state, req.prompt, &gen_cfg,
&server->cached_tokens, &server->n_cached,
&server->cached_capacity,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));

const char* content = collect.buf ? collect.buf : "";
@@ -1180,8 +1256,11 @@ void tq_server_stop(tq_server_t* server) {
void tq_server_free(tq_server_t* server) {
if (!server) return;
pthread_mutex_destroy(&server->inference_mutex);
if (server->kv_state) tq_free_state(server->kv_state);
if (server->cached_tokens) free(server->cached_tokens);
/* Free all session KV caches */
for (int i = 0; i < MAX_SESSIONS; i++) {
if (server->sessions[i].kv_state) tq_free_state(server->sessions[i].kv_state);
if (server->sessions[i].cached_tokens) free(server->sessions[i].cached_tokens);
}
if (g_server == server) g_server = NULL;
free(server);
}
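From the client side, opting into per-session reuse only requires sending a stable identifier. The diff parses the OpenAI-compatible 'user' body field (the session comment also mentions an X-Session-Id header, whose parsing is not shown in these hunks). Assuming the server exposes its usual chat-completions route, a request body along these lines pins the caller to one kv_session_t slot, so each successive turn only prefills the newly appended messages:

{
  "user": "alice",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
  ]
}

Omitting "user" is still valid; such clients all share the "default" session slot, which is fine for single-user demos but gives up per-client reuse when several chat clients hit the server concurrently.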