Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions src/engine/tq_generate.c
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,255 @@ int tq_generate_continue(tq_model_t* model,
free(new_tokens);
return generated;
}

/* ============================================================================
* tq_generate_chat_text — text-prefix matching for chat reuse
*
* Solves the BPE re-tokenization issue: when the model generates response
* tokens via sample_topp, those token IDs may not match what tq_encode()
* produces from the same response text in the next turn's prompt. The
* token-level LCP in tq_generate_continue truncates at that boundary.
*
* This function tracks the *text* of the last prompt (which includes the
* model's response from previous turns, accumulated by the caller). On the
* next call, if the new prompt starts with cached_text byte-for-byte, the
* entire cached state is valid — we tokenize only the new SUFFIX text and
* prefill those tokens at positions [n_cached..]. No LCP, no truncation.
*
* After generation, *cached_text_io is updated to:
* prompt + (generated tokens decoded back to text)
* so the next call can fast-path again.
*
* Caller owns *cached_text_io (must free with free()).
* Pass cached_text_io == NULL to disable text-prefix tracking and behave
* exactly like tq_generate_continue.
* ============================================================================ */

/* Accumulator threaded through the generation callback so the generated
 * text can be appended to cached_text after the call completes. */
typedef struct {
    char* buf;        /* accumulated generated text, NUL-terminated (heap) */
    size_t len;       /* bytes used in buf, excluding the terminating NUL */
    size_t cap;       /* allocated size of buf */
    void (*user_cb)(const char*, void*);  /* caller's original on_token */
    void* user_data;                      /* caller's original user_data */
} chat_accum_t;

/* Token-callback shim: appends each decoded piece to ctx->buf (growing the
 * buffer geometrically), then forwards the piece to the caller's callback.
 *
 * On realloc failure the piece cannot be buffered, but the user callback is
 * still invoked so the caller's visible token stream stays complete — the
 * old behavior dropped the token from BOTH the buffer and the stream.
 * (A buffering failure still means cached_text may undercount the response;
 * the next turn then falls back to the slow path, which is safe.) */
static void chat_accum_callback(const char* tok, void* u) {
    chat_accum_t* ctx = (chat_accum_t*)u;
    if (!tok) return;
    size_t tlen = strlen(tok);
    if (ctx->len + tlen + 1 > ctx->cap) {
        size_t new_cap = (ctx->cap + tlen + 64) * 2;
        char* nb = (char*)realloc(ctx->buf, new_cap);
        if (!nb) {
            /* Out of memory: skip buffering, but do not drop the token
             * from the user-visible stream. */
            if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
            return;
        }
        ctx->buf = nb;
        ctx->cap = new_cap;
    }
    memcpy(ctx->buf + ctx->len, tok, tlen);
    ctx->len += tlen;
    ctx->buf[ctx->len] = '\0';
    if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
}

/* Chat-mode generation with text-prefix KV-cache reuse.
 *
 * Returns the number of generated tokens, or -1 on error.
 * See the block comment above for the full contract; in short:
 * if `prompt` starts byte-for-byte with *cached_text_io, only the suffix
 * text is tokenized and prefilled; otherwise falls back to the token-LCP
 * path in tq_generate_continue. On success *cached_text_io is replaced
 * with prompt + generated-text (caller frees with free()). */
int tq_generate_chat_text(tq_model_t* model,
                          tq_tokenizer_t* tokenizer,
                          tq_state_t* state,
                          const char* prompt,
                          tq_gen_config_t* config,
                          char** cached_text_io,
                          int** cached_tokens_io,
                          int* n_cached_io,
                          int* cached_capacity_io,
                          char* output, int output_size) {
    if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io || !prompt) {
        return -1;
    }

    /* --- 1. Check for text-level prefix match --- */
    int matched_text_len = 0;   /* bytes of prompt covered by cached_text */
    int prefix_pos = 0;         /* tokens already in KV cache that we trust */

    if (cached_text_io && *cached_text_io && *n_cached_io > 0) {
        size_t cached_len = strlen(*cached_text_io);
        if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) {
            matched_text_len = (int)cached_len;
            prefix_pos = *n_cached_io;
        } else if (getenv("TQ_CHAT_DEBUG")) {
            /* Find where they diverge to help diagnose */
            size_t diverge = 0;
            size_t plen = strlen(prompt);
            size_t lim = cached_len < plen ? cached_len : plen;
            while (diverge < lim && (*cached_text_io)[diverge] == prompt[diverge]) diverge++;
            fprintf(stderr,
                "[chat-text] no match: cached_len=%zu prompt_len=%zu diverge_at=%zu\n"
                " cached[%zu..]: %.40s\n"
                " prompt[%zu..]: %.40s\n",
                cached_len, plen, diverge,
                diverge, *cached_text_io + diverge,
                diverge, prompt + diverge);
        }
    }

    /* Wrap user callback to capture generated text into a buffer for the
     * next call's cached_text update. The caller's callback still runs
     * (forwarded by chat_accum_callback). */
    chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
                           .user_cb = config->on_token,
                           .user_data = config->user_data };
    void (*orig_cb)(const char*, void*) = config->on_token;
    void* orig_ud = config->user_data;
    config->on_token = chat_accum_callback;
    config->user_data = &accum;

    int generated = 0;

    if (matched_text_len > 0) {
        /* --- Fast path: text prefix matches --- */
        const char* suffix = prompt + matched_text_len;
        int max_prompt = model->config.max_seq_len > 0
                         ? model->config.max_seq_len : 4096;
        int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int));
        if (!suffix_toks) {
            config->on_token = orig_cb; config->user_data = orig_ud;
            return -1;
        }
        int n_suffix = 0;
        /* Guard tokenizer here as well: the generation loop below already
         * tolerates tokenizer == NULL, so encoding must not deref NULL.
         * A NULL tokenizer simply means no suffix tokens to prefill. */
        if (*suffix != '\0' && tokenizer) {
            n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0);
            if (n_suffix < 0) n_suffix = 0;
        }

        /* Sliding window if needed (drop from start of cached) */
        int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
        if (prefix_pos + n_suffix + reserve + 32 > max_prompt) {
            /* Force a full reprefill — simpler than partial cache shift.
             *
             * BUGFIX: keep the accumulating callback installed while
             * calling tq_generate_continue. The old code restored the
             * caller's callback first, so accum stayed empty and the
             * update_cache step recorded cached_text = prompt WITHOUT
             * the response. On the next turn the text prefix would still
             * match (the new prompt starts with the old prompt), but
             * prefix_pos counted the response tokens too, so the response
             * text would be re-tokenized and prefilled a second time. */
            free(suffix_toks);
            *n_cached_io = 0;
            if (cached_text_io && *cached_text_io) {
                free(*cached_text_io); *cached_text_io = NULL;
            }
            generated = tq_generate_continue(model, tokenizer, state, prompt, config,
                                             cached_tokens_io, n_cached_io, cached_capacity_io,
                                             output, output_size);
            /* fall-through path captures cached_text below */
            goto update_cache;
        }

        /* Grow cache buffer */
        int needed = prefix_pos + n_suffix + reserve + 16;
        if (*cached_capacity_io < needed) {
            int new_cap = needed < 4096 ? 4096 : needed;
            int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
            if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; }
            *cached_tokens_io = nb;
            *cached_capacity_io = new_cap;
        }

        /* Append suffix tokens to cache + prefill at correct positions */
        int* cached = *cached_tokens_io;
        for (int i = 0; i < n_suffix; i++) {
            cached[prefix_pos + i] = suffix_toks[i];
            tq_forward(model, state, suffix_toks[i], prefix_pos + i);
        }
        *n_cached_io = prefix_pos + n_suffix;
        free(suffix_toks);

        if (getenv("TQ_CHAT_DEBUG")) {
            fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n",
                    matched_text_len, n_suffix);
        }

        /* --- Run generation loop directly --- */
        int vocab_size = model->config.vocab_size;
        int n_cached = *n_cached_io;
        int pos = n_cached;
        int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;

        unsigned long long rng_state = config->rng_seed
            ? (unsigned long long)config->rng_seed : (unsigned long long)time(NULL);
        /* First token is sampled from the logits left by the last prefill
         * forward (or by the previous turn's final forward when n_suffix
         * is 0, i.e. the prompt exactly equals cached_text). */
        int next_token = tq_sample_topp(state->logits, vocab_size,
                                        config->temperature, config->top_p,
                                        &rng_state);

        int output_pos = 0;
        /* Common EOS / end-of-turn ids across supported model families —
         * presumably covering llama/gemma/qwen vocabularies; verify when
         * adding a new model family. */
        int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 };
        int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);

        while (generated < config->max_tokens) {
            int is_eos = 0;
            for (int e = 0; e < n_eos; e++) {
                if (next_token == eos_tokens[e]) { is_eos = 1; break; }
            }
            if (is_eos) break;
            if (pos >= model->config.max_seq_len) break;

            const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : "";
            int should_stop = 0;
            if (piece) {
                /* Textual end-of-turn markers that may decode as plain text */
                if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
                    strstr(piece, "<|start_header_id|>")) {
                    should_stop = 1; piece = "";
                }
            }
            if (should_stop) break;

            int piece_len = (int)strlen(piece ? piece : "");
            if (config->on_token && piece) config->on_token(piece, config->user_data);
            if (output && piece && output_pos + piece_len < output_size - 1) {
                memcpy(output + output_pos, piece, piece_len);
                output_pos += piece_len;
            }

            /* Record the token so the next turn's prefix_pos stays in sync
             * with the KV cache (capacity was reserved above). */
            if (n_cached < *cached_capacity_io) {
                cached[n_cached++] = next_token;
                *n_cached_io = n_cached;
            }

            prev_token = next_token;
            tq_forward(model, state, next_token, pos);
            pos++;
            generated++;

            next_token = tq_sample_topp(state->logits, vocab_size,
                                        config->temperature, config->top_p,
                                        &rng_state);
        }

        if (output && output_size > 0) {
            output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
        }
    } else {
        /* --- Slow path: no text-prefix match, use token LCP fallback --- */
        if (getenv("TQ_CHAT_DEBUG")) {
            fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n");
        }
        generated = tq_generate_continue(
            model, tokenizer, state, prompt, config,
            cached_tokens_io, n_cached_io, cached_capacity_io,
            output, output_size);
    }

update_cache:
    /* Restore the original callback before returning to caller */
    config->on_token = orig_cb;
    config->user_data = orig_ud;

    /* Update cached_text = prompt + generated text. The next call can
     * fast-path against this if its prompt starts with this string. */
    if (cached_text_io) {
        size_t plen = strlen(prompt);
        size_t glen = accum.len;
        size_t new_len = plen + glen;
        char* nt = (char*)malloc(new_len + 1);
        if (nt) {
            memcpy(nt, prompt, plen);
            if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
            nt[new_len] = '\0';
            if (*cached_text_io) free(*cached_text_io);
            *cached_text_io = nt;
        }
    }
    free(accum.buf);

    return generated;
}
62 changes: 43 additions & 19 deletions src/server/tq_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include <stdarg.h>
#include <stdbool.h>

/* Forward decl: defined in src/engine/tq_generate.c.
* Not yet exposed in turboquant.h since it's a chat-mode helper. */
/* Forward decls: defined in src/engine/tq_generate.c.
* Not yet exposed in turboquant.h since they're chat-mode helpers. */
extern int tq_generate_continue(tq_model_t* model,
tq_tokenizer_t* tokenizer,
tq_state_t* state,
Expand All @@ -30,6 +30,18 @@ extern int tq_generate_continue(tq_model_t* model,
int* n_cached_io,
int* cached_capacity_io,
char* output, int output_size);

/* Text-prefix matching variant — solves BPE re-tokenization mismatch. */
extern int tq_generate_chat_text(tq_model_t* model,
tq_tokenizer_t* tokenizer,
tq_state_t* state,
const char* prompt,
tq_gen_config_t* config,
char** cached_text_io,
int** cached_tokens_io,
int* n_cached_io,
int* cached_capacity_io,
char* output, int output_size);
#if defined(_MSC_VER)
#include <intrin.h>
typedef volatile long atomic_int;
Expand Down Expand Up @@ -95,6 +107,7 @@ typedef struct {
int* cached_tokens;
int n_cached;
int cached_capacity;
char* cached_text; /* prompt + generated, for text-prefix matching */
long last_used; /* monotonic counter for LRU */
} kv_session_t;

Expand Down Expand Up @@ -144,6 +157,7 @@ static kv_session_t* get_or_create_session(tq_server_t* server,
/* Free old session contents (if any) */
if (s->kv_state) tq_free_state(s->kv_state);
if (s->cached_tokens) free(s->cached_tokens);
if (s->cached_text) free(s->cached_text);

memset(s, 0, sizeof(*s));
strncpy(s->id, sid, SESSION_ID_MAX - 1);
Expand Down Expand Up @@ -237,15 +251,22 @@ static const char* json_extract_string(const char* p, char* buf, int buf_size) {
/* Find a key in JSON and return pointer to value (past the colon).
* Simple scan — works for flat or lightly nested objects. */
static const char* json_find_key(const char* json, const char* key) {
/* Find a "key": pattern. Naive scan: locate every "key" occurrence
* and verify the next non-whitespace char is ':'. This skips false
* matches where "key" appears as a *value* (e.g., {"role":"user"}
* collides with json_find_key("user") if we don't check the colon). */
char pattern[256];
snprintf(pattern, sizeof(pattern), "\"%s\"", key);
const char* p = strstr(json, pattern);
if (!p) return NULL;
p += strlen(pattern);
p = json_skip_ws(p);
if (*p != ':') return NULL;
p++;
return json_skip_ws(p);
size_t plen = strlen(pattern);
const char* p = json;
while ((p = strstr(p, pattern)) != NULL) {
const char* after = json_skip_ws(p + plen);
if (*after == ':') {
return json_skip_ws(after + 1);
}
p += plen; /* skip past this false match and keep searching */
}
return NULL;
}

/* Extract a number (int or float) from current position */
Expand Down Expand Up @@ -758,11 +779,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
kv_session_t* sess = get_or_create_session(server, req.session_id,
gen_cfg.kv_type,
gen_cfg.value_quant_bits);
tq_generate_continue(server->config.model, server->config.tokenizer,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));
tq_generate_chat_text(server->config.model, server->config.tokenizer,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_text,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));

/* Send final chunk with finish_reason */
char final_chunk[SSE_CHUNK_SIZE];
Expand Down Expand Up @@ -795,11 +817,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
kv_session_t* sess = get_or_create_session(server, req.session_id,
gen_cfg.kv_type,
gen_cfg.value_quant_bits);
tq_generate_continue(server->config.model, server->config.tokenizer,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));
tq_generate_chat_text(server->config.model, server->config.tokenizer,
sess->kv_state, req.prompt, &gen_cfg,
&sess->cached_text,
&sess->cached_tokens, &sess->n_cached,
&sess->cached_capacity,
output, sizeof(output));

const char* content = collect.buf ? collect.buf : "";

Expand Down Expand Up @@ -1260,6 +1283,7 @@ void tq_server_free(tq_server_t* server) {
for (int i = 0; i < MAX_SESSIONS; i++) {
if (server->sessions[i].kv_state) tq_free_state(server->sessions[i].kv_state);
if (server->sessions[i].cached_tokens) free(server->sessions[i].cached_tokens);
if (server->sessions[i].cached_text) free(server->sessions[i].cached_text);
}
if (g_server == server) g_server = NULL;
free(server);
Expand Down
Loading