From aac0805169584fbd85be1f001a701125b0e3a4e0 Mon Sep 17 00:00:00 2001 From: quantumaikr Date: Sun, 12 Apr 2026 09:09:04 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20text-prefix=20matching=20=E2=80=94=20by?= =?UTF-8?q?pass=20BPE=20re-tokenization=20in=20chat=20reuse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to PR #49. The token-level LCP path in tq_generate_continue has a fundamental limitation: model-generated tokens (sample_topp) and text-encoded tokens (tq_encode of the response in the next turn) can diverge due to BPE merge non-roundtripping. This caps per-turn LCP at the prompt boundary (~10 tokens), so longer histories still incur mostly-full reprefill. Fix: tq_generate_chat_text() — text-level prefix matching. How it works: 1. Each session stores the entire prompt+response text from the previous call (cached_text). 2. On a new request, check if the new prompt starts with cached_text byte-for-byte. If yes, the cached state is byte-equivalent valid. 3. Tokenize ONLY the suffix (new_prompt[strlen(cached_text):]) and prefill those tokens at positions [n_cached..n_cached + n_suffix). 4. Run generation. The accumulated output text gets appended to cached_text via a tee callback for the next call. 5. If text prefix doesn't match, fall back to tq_generate_continue (token LCP path). Bug fix bundled: json_find_key("user") was matching the value in {"role":"user"} instead of the top-level "user" key. Result: every request used the "default" session, so multi-session was effectively broken (cross-pollution). The fix scans for "key": (with colon) to disambiguate from value matches. 
Measured (SmolLM2-135M, single thread, real chat replay): Single user, 10-turn accumulation: PR #48 (token LCP only): turn 10 → 3700 ms PR #49 (above + multi-session): turn 10 → 3700 ms (LCP still capped) This PR (text-prefix path): turn 10 → 739 ms (5x) alice + bob interleaved, 5 turns each (real assistant replay): PR #49: alice 5 = 2412 ms, bob 5 = 2357 ms Now: alice 5 = 498 ms, bob 5 = 462 ms (5x) The growth that remains (~50ms/turn) is the unavoidable O(n) cost of the attention computation over the full context — KV prefill is now truly O(new tokens per turn), not O(full history per turn). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/engine/tq_generate.c | 252 +++++++++++++++++++++++++++++++++++++++ src/server/tq_server.c | 62 +++++++--- 2 files changed, 295 insertions(+), 19 deletions(-) diff --git a/src/engine/tq_generate.c b/src/engine/tq_generate.c index fc5ea3b..cca9e9b 100644 --- a/src/engine/tq_generate.c +++ b/src/engine/tq_generate.c @@ -802,3 +802,255 @@ int tq_generate_continue(tq_model_t* model, free(new_tokens); return generated; } + +/* ============================================================================ + * tq_generate_chat_text — text-prefix matching for chat reuse + * + * Solves the BPE re-tokenization issue: when the model generates response + * tokens via sample_topp, those token IDs may not match what tq_encode() + * produces from the same response text in the next turn's prompt. The + * token-level LCP in tq_generate_continue truncates at that boundary. + * + * This function tracks the *text* of the last prompt (which includes the + * model's response from previous turns, accumulated by the caller). On the + * next call, if the new prompt starts with cached_text byte-for-byte, the + * entire cached state is valid — we tokenize only the new SUFFIX text and + * prefill those tokens at positions [n_cached..]. No LCP, no truncation. 
+ * + * After generation, *cached_text_io is updated to: + * prompt + (generated tokens decoded back to text) + * so the next call can fast-path again. + * + * Caller owns *cached_text_io (must free with free()). + * Pass cached_text_io == NULL to disable text-prefix tracking and behave + * exactly like tq_generate_continue. + * ============================================================================ */ + +typedef struct { + char* buf; + size_t len; + size_t cap; + void (*user_cb)(const char*, void*); + void* user_data; +} chat_accum_t; + +static void chat_accum_callback(const char* tok, void* u) { + chat_accum_t* ctx = (chat_accum_t*)u; + if (!tok) return; + size_t tlen = strlen(tok); + if (ctx->len + tlen + 1 > ctx->cap) { + size_t new_cap = (ctx->cap + tlen + 64) * 2; + char* nb = (char*)realloc(ctx->buf, new_cap); + if (!nb) return; + ctx->buf = nb; + ctx->cap = new_cap; + } + memcpy(ctx->buf + ctx->len, tok, tlen); + ctx->len += tlen; + ctx->buf[ctx->len] = '\0'; + if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); +} + +int tq_generate_chat_text(tq_model_t* model, + tq_tokenizer_t* tokenizer, + tq_state_t* state, + const char* prompt, + tq_gen_config_t* config, + char** cached_text_io, + int** cached_tokens_io, + int* n_cached_io, + int* cached_capacity_io, + char* output, int output_size) { + if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io || !prompt) { + return -1; + } + + /* --- 1. 
Check for text-level prefix match --- */ + int matched_text_len = 0; + int prefix_pos = 0; /* tokens already in KV cache that we trust */ + + if (cached_text_io && *cached_text_io && *n_cached_io > 0) { + size_t cached_len = strlen(*cached_text_io); + if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) { + matched_text_len = (int)cached_len; + prefix_pos = *n_cached_io; + } else if (getenv("TQ_CHAT_DEBUG")) { + /* Find where they diverge to help diagnose */ + size_t diverge = 0; + size_t plen = strlen(prompt); + size_t lim = cached_len < plen ? cached_len : plen; + while (diverge < lim && (*cached_text_io)[diverge] == prompt[diverge]) diverge++; + fprintf(stderr, + "[chat-text] no match: cached_len=%zu prompt_len=%zu diverge_at=%zu\n" + " cached[%zu..]: %.40s\n" + " prompt[%zu..]: %.40s\n", + cached_len, plen, diverge, + diverge, *cached_text_io + diverge, + diverge, prompt + diverge); + } + } + + /* Wrap user callback to capture generated text into a buffer for the + * next call's cached_text update. */ + chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, + .user_cb = config->on_token, + .user_data = config->user_data }; + void (*orig_cb)(const char*, void*) = config->on_token; + void* orig_ud = config->user_data; + config->on_token = chat_accum_callback; + config->user_data = &accum; + + int generated = 0; + + if (matched_text_len > 0) { + /* --- Fast path: text prefix matches --- */ + const char* suffix = prompt + matched_text_len; + int max_prompt = model->config.max_seq_len > 0 + ? model->config.max_seq_len : 4096; + int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int)); + if (!suffix_toks) { + config->on_token = orig_cb; config->user_data = orig_ud; + return -1; + } + int n_suffix = 0; + if (*suffix != '\0') { + n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0); + if (n_suffix < 0) n_suffix = 0; + } + + /* Sliding window if needed (drop from start of cached) */ + int reserve = config->max_tokens > 0 ? 
config->max_tokens : 256; + if (prefix_pos + n_suffix + reserve + 32 > max_prompt) { + /* Force a full reprefill — simpler than partial cache shift */ + free(suffix_toks); + config->on_token = orig_cb; config->user_data = orig_ud; + *n_cached_io = 0; + if (cached_text_io && *cached_text_io) { + free(*cached_text_io); *cached_text_io = NULL; + } + int n2 = tq_generate_continue(model, tokenizer, state, prompt, config, + cached_tokens_io, n_cached_io, cached_capacity_io, + output, output_size); + /* NOTE(review): orig callback was restored above, so accum stays empty and update_cache records the prompt text only (no generated text) — verify the next call's prefix match cannot mis-align with n_cached, which includes generated tokens */ + generated = n2; + goto update_cache; + } + + /* Grow cache buffer */ + int needed = prefix_pos + n_suffix + reserve + 16; + if (*cached_capacity_io < needed) { + int new_cap = needed < 4096 ? 4096 : needed; + int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int)); + if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; } + *cached_tokens_io = nb; + *cached_capacity_io = new_cap; + } + + /* Append suffix tokens to cache + prefill at correct positions */ + int* cached = *cached_tokens_io; + for (int i = 0; i < n_suffix; i++) { + cached[prefix_pos + i] = suffix_toks[i]; + tq_forward(model, state, suffix_toks[i], prefix_pos + i); + } + *n_cached_io = prefix_pos + n_suffix; + free(suffix_toks); + + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n", + matched_text_len, n_suffix); + } + + /* --- Run generation loop directly --- */ + int vocab_size = model->config.vocab_size; + int n_cached = *n_cached_io; + int pos = n_cached; + int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1; + + unsigned long long rng_state = config->rng_seed + ?
(unsigned long long)config->rng_seed : (unsigned long long)time(NULL); + int next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + + int output_pos = 0; + int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 }; + int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]); + + while (generated < config->max_tokens) { + int is_eos = 0; + for (int e = 0; e < n_eos; e++) { + if (next_token == eos_tokens[e]) { is_eos = 1; break; } + } + if (is_eos) break; + if (pos >= model->config.max_seq_len) break; + + const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : ""; + int should_stop = 0; + if (piece) { + if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") || + strstr(piece, "<|start_header_id|>")) { + should_stop = 1; piece = ""; + } + } + if (should_stop) break; + + int piece_len = (int)strlen(piece ? piece : ""); + if (config->on_token && piece) config->on_token(piece, config->user_data); + if (output && piece && output_pos + piece_len < output_size - 1) { + memcpy(output + output_pos, piece, piece_len); + output_pos += piece_len; + } + + if (n_cached < *cached_capacity_io) { + cached[n_cached++] = next_token; + *n_cached_io = n_cached; + } + + prev_token = next_token; + tq_forward(model, state, next_token, pos); + pos++; + generated++; + + next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + } + + if (output && output_size > 0) { + output[output_pos < output_size ? 
output_pos : output_size - 1] = '\0'; + } + } else { + /* --- Slow path: no text-prefix match, use token LCP fallback --- */ + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n"); + } + generated = tq_generate_continue( + model, tokenizer, state, prompt, config, + cached_tokens_io, n_cached_io, cached_capacity_io, + output, output_size); + } + +update_cache: + /* Restore the original callback before returning to caller */ + config->on_token = orig_cb; + config->user_data = orig_ud; + + /* Update cached_text = prompt + generated text. The next call can + * fast-path against this if its prompt starts with this string. */ + if (cached_text_io) { + size_t plen = strlen(prompt); + size_t glen = accum.len; + size_t new_len = plen + glen; + char* nt = (char*)malloc(new_len + 1); + if (nt) { + memcpy(nt, prompt, plen); + if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); + nt[new_len] = '\0'; + if (*cached_text_io) free(*cached_text_io); + *cached_text_io = nt; + } + } + if (accum.buf) free(accum.buf); + + return generated; +} diff --git a/src/server/tq_server.c b/src/server/tq_server.c index f41d6b9..898c3cb 100644 --- a/src/server/tq_server.c +++ b/src/server/tq_server.c @@ -19,8 +19,8 @@ #include #include -/* Forward decl: defined in src/engine/tq_generate.c. - * Not yet exposed in turboquant.h since it's a chat-mode helper. */ +/* Forward decls: defined in src/engine/tq_generate.c. + * Not yet exposed in turboquant.h since they're chat-mode helpers. */ extern int tq_generate_continue(tq_model_t* model, tq_tokenizer_t* tokenizer, tq_state_t* state, @@ -30,6 +30,18 @@ extern int tq_generate_continue(tq_model_t* model, int* n_cached_io, int* cached_capacity_io, char* output, int output_size); + +/* Text-prefix matching variant — solves BPE re-tokenization mismatch. 
*/ +extern int tq_generate_chat_text(tq_model_t* model, + tq_tokenizer_t* tokenizer, + tq_state_t* state, + const char* prompt, + tq_gen_config_t* config, + char** cached_text_io, + int** cached_tokens_io, + int* n_cached_io, + int* cached_capacity_io, + char* output, int output_size); #if defined(_MSC_VER) #include typedef volatile long atomic_int; @@ -95,6 +107,7 @@ typedef struct { int* cached_tokens; int n_cached; int cached_capacity; + char* cached_text; /* prompt + generated, for text-prefix matching */ long last_used; /* monotonic counter for LRU */ } kv_session_t; @@ -144,6 +157,7 @@ static kv_session_t* get_or_create_session(tq_server_t* server, /* Free old session contents (if any) */ if (s->kv_state) tq_free_state(s->kv_state); if (s->cached_tokens) free(s->cached_tokens); + if (s->cached_text) free(s->cached_text); memset(s, 0, sizeof(*s)); strncpy(s->id, sid, SESSION_ID_MAX - 1); @@ -237,15 +251,22 @@ static const char* json_extract_string(const char* p, char* buf, int buf_size) { /* Find a key in JSON and return pointer to value (past the colon). * Simple scan — works for flat or lightly nested objects. */ static const char* json_find_key(const char* json, const char* key) { + /* Find a "key": pattern. Naive scan: locate every "key" occurrence + * and verify the next non-whitespace char is ':'. This skips false + * matches where "key" appears as a *value* (e.g., {"role":"user"} + * collides with json_find_key("user") if we don't check the colon). 
*/ char pattern[256]; snprintf(pattern, sizeof(pattern), "\"%s\"", key); - const char* p = strstr(json, pattern); - if (!p) return NULL; - p += strlen(pattern); - p = json_skip_ws(p); - if (*p != ':') return NULL; - p++; - return json_skip_ws(p); + size_t plen = strlen(pattern); + const char* p = json; + while ((p = strstr(p, pattern)) != NULL) { + const char* after = json_skip_ws(p + plen); + if (*after == ':') { + return json_skip_ws(after + 1); + } + p += plen; /* skip past this false match and keep searching */ + } + return NULL; } /* Extract a number (int or float) from current position */ @@ -758,11 +779,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod kv_session_t* sess = get_or_create_session(server, req.session_id, gen_cfg.kv_type, gen_cfg.value_quant_bits); - tq_generate_continue(server->config.model, server->config.tokenizer, - sess->kv_state, req.prompt, &gen_cfg, - &sess->cached_tokens, &sess->n_cached, - &sess->cached_capacity, - output, sizeof(output)); + tq_generate_chat_text(server->config.model, server->config.tokenizer, + sess->kv_state, req.prompt, &gen_cfg, + &sess->cached_text, + &sess->cached_tokens, &sess->n_cached, + &sess->cached_capacity, + output, sizeof(output)); /* Send final chunk with finish_reason */ char final_chunk[SSE_CHUNK_SIZE]; @@ -795,11 +817,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod kv_session_t* sess = get_or_create_session(server, req.session_id, gen_cfg.kv_type, gen_cfg.value_quant_bits); - tq_generate_continue(server->config.model, server->config.tokenizer, - sess->kv_state, req.prompt, &gen_cfg, - &sess->cached_tokens, &sess->n_cached, - &sess->cached_capacity, - output, sizeof(output)); + tq_generate_chat_text(server->config.model, server->config.tokenizer, + sess->kv_state, req.prompt, &gen_cfg, + &sess->cached_text, + &sess->cached_tokens, &sess->n_cached, + &sess->cached_capacity, + output, sizeof(output)); const char* content = 
collect.buf ? collect.buf : ""; @@ -1260,6 +1283,7 @@ void tq_server_free(tq_server_t* server) { for (int i = 0; i < MAX_SESSIONS; i++) { if (server->sessions[i].kv_state) tq_free_state(server->sessions[i].kv_state); if (server->sessions[i].cached_tokens) free(server->sessions[i].cached_tokens); + if (server->sessions[i].cached_text) free(server->sessions[i].cached_text); } if (g_server == server) g_server = NULL; free(server);