diff --git a/bindings/python/quantcpp/__init__.py b/bindings/python/quantcpp/__init__.py
index ff092e3..88bdb6e 100644
--- a/bindings/python/quantcpp/__init__.py
+++ b/bindings/python/quantcpp/__init__.py
@@ -383,6 +383,89 @@ def _run():
         if error_box[0] is not None:
             raise error_box[0]
 
+    def chat(self, prompt: str) -> Iterator[str]:
+        """Multi-turn chat with KV cache reuse.
+
+        Like ``generate()``, but the KV cache persists across calls. When you
+        re-send the conversation history each turn, only the new tokens are
+        prefilled — turn N's latency is O(new_tokens), not O(history^2).
+
+        Pass ``prompt=None`` to reset the chat session (performed eagerly;
+        the returned iterator is empty).
+
+        Falls back to ``generate()`` on older library builds without the
+        ``quant_chat`` symbol.
+        """
+        self._ensure_open()
+        lib = get_lib()
+
+        # The reset must happen eagerly, not inside a generator body —
+        # otherwise ``m.chat(None)`` would be a no-op until iterated.
+        if prompt is None:
+            if hasattr(lib, "quant_chat"):
+                with self._lock:
+                    lib.quant_chat(self._ctx, None, ON_TOKEN_CB(0), None)
+            return iter(())
+
+        if not hasattr(lib, "quant_chat"):
+            # Older library — silently fall back to non-reusing generate.
+            return self.generate(prompt)
+
+        if self._chat:
+            prompt = self._apply_chat_template(prompt)
+
+        return self._chat_stream(lib, prompt)
+
+    def _chat_stream(self, lib, prompt: str) -> Iterator[str]:
+        """Stream tokens from ``quant_chat`` running on a worker thread."""
+        tokens = []
+        done = threading.Event()
+        error_box = [None]
+
+        def _on_token(text_ptr, _user_data):
+            if text_ptr:
+                tokens.append(text_ptr.decode("utf-8", errors="replace"))
+
+        # Hold a reference to the ctypes callback for the call's lifetime
+        # so it is not garbage-collected while C code may still invoke it.
+        cb = ON_TOKEN_CB(_on_token)
+
+        def _run():
+            try:
+                with self._lock:
+                    lib.quant_chat(self._ctx, prompt.encode("utf-8"), cb, None)
+            except Exception as e:  # propagated to the consumer below
+                error_box[0] = e
+            finally:
+                done.set()
+
+        thread = threading.Thread(target=_run, daemon=True)
+        thread.start()
+
+        yielded = 0
+        while not done.is_set() or yielded < len(tokens):
+            if yielded < len(tokens):
+                yield tokens[yielded]
+                yielded += 1
+            else:
+                done.wait(timeout=0.01)
+
+        # Drain tokens that arrived between the last check and done.set().
+        while yielded < len(tokens):
+            yield tokens[yielded]
+            yielded += 1
+
+        if error_box[0] is not None:
+            raise error_box[0]
+
+    def reset_chat(self) -> None:
+        """Reset the chat KV cache. Next chat() call starts fresh."""
+        self._ensure_open()
+        lib = get_lib()
+        if hasattr(lib, "quant_chat"):
+            with self._lock:
+                lib.quant_chat(self._ctx, None, ON_TOKEN_CB(0), None)
+
     def save_context(self, path: str) -> None:
         """Save the current KV cache to disk.
 
diff --git a/bindings/python/quantcpp/_binding.py b/bindings/python/quantcpp/_binding.py
index 7eb0957..6a750ce 100644
--- a/bindings/python/quantcpp/_binding.py
+++ b/bindings/python/quantcpp/_binding.py
@@ -132,6 +132,20 @@ def _setup_signatures(lib: ctypes.CDLL) -> None:
     ]
     lib.quant_generate.restype = ctypes.c_int
 
+    # int quant_chat(quant_ctx* ctx, const char* prompt,
+    #               void (*on_token)(const char*, void*), void* user_data)
+    # Multi-turn chat with KV cache reuse — avoids the O(n^2) prefill cost
+    # of quant_generate when the user re-sends conversation history.
+    # Optional: only present in single-header builds (>= v0.13).
+    if hasattr(lib, "quant_chat"):
+        lib.quant_chat.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_char_p,
+            ON_TOKEN_CB,
+            ctypes.c_void_p,
+        ]
+        lib.quant_chat.restype = ctypes.c_int
+
     # char* quant_ask(quant_ctx* ctx, const char* prompt)
     lib.quant_ask.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
     lib.quant_ask.restype = ctypes.c_void_p  # We use c_void_p so we can free()
diff --git a/bindings/python/quantcpp/cli.py b/bindings/python/quantcpp/cli.py
index a9506d1..954d7fc 100644
--- a/bindings/python/quantcpp/cli.py
+++ b/bindings/python/quantcpp/cli.py
@@ -152,15 +152,23 @@ def cmd_run(args):
         print()
     else:
         print("quantcpp \u2014 type your message, Ctrl+C to exit", file=sys.stderr)
+        # Multi-turn chat: accumulate history as ChatML so the model sees
+        # prior turns. m.chat() reuses the KV cache via prefix-match, so
+        # repeating the history is cheap (O(new tokens), not O(n^2)).
+        history = ""
         try:
             while True:
                 question = input("\nYou: ")
                 if not question.strip():
                     continue
+                history += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
                 print("AI: ", end="", flush=True)
-                for tok in m.generate(question):
+                reply_buf = []
+                for tok in m.chat(history):
                     print(tok, end="", flush=True)
+                    reply_buf.append(tok)
                 print()
+                history += "".join(reply_buf) + "<|im_end|>\n"
         except (KeyboardInterrupt, EOFError):
             print("\nBye!", file=sys.stderr)
 
diff --git a/quant.h b/quant.h
index b951702..15ad224 100644
--- a/quant.h
+++ b/quant.h
@@ -62,6 +62,13 @@ int quant_generate(quant_ctx* ctx, const char* prompt,
                    void (*on_token)(const char* text, void* user_data),
                    void* user_data);
 
+// Multi-turn chat with KV cache reuse (O(delta) per turn instead of O(n^2)).
+// Subsequent calls only re-prefill the suffix that diverges from history.
+// Pass prompt = NULL to reset the chat session. Returns tokens generated.
+int quant_chat(quant_ctx* ctx, const char* prompt,
+               void (*on_token)(const char* text, void* user_data),
+               void* user_data);
+
 // Generate and return full response as string. Caller must free().
 char* quant_ask(quant_ctx* ctx, const char* prompt);
 
@@ -1729,7 +1736,15 @@ struct quant_ctx {
     tq_state_t* state;
     tq_tokenizer_t* tokenizer;
     tq_gen_config_t config;
-    int n_ctx_tokens;   /* number of tokens currently in KV cache */
+    int n_ctx_tokens;   /* number of tokens currently in KV cache */
+    /* Prefix-match cache for chat history reuse:
+     * stores the actual token IDs that are committed to the KV cache,
+     * so the next quant_chat() can skip the matching prefix and
+     * only prefill the diverging suffix. Critical for chat mode where
+     * each turn re-sends the entire conversation history. */
+    int* cached_tokens;
+    int n_cached;
+    int cached_capacity;
 };
 
 // ============================================================================
@@ -15624,6 +15639,199 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
     return generated;
 }
 
+/* ============================================================================
+ * tq_generate_continue — reuse an existing tq_state_t across calls.
+ *
+ * Unlike tq_generate (which allocates and frees its own state on every call),
+ * this function takes a caller-managed state plus a record of which tokens
+ * are currently committed to the KV cache. It computes the longest common
+ * prefix between the cached tokens and the new prompt, then prefills only
+ * the diverging suffix. After generation, *cached_tokens_io and
+ * *n_cached_io are updated to reflect the new cache contents.
+ *
+ * This turns chat mode from O(n^2) (full re-prefill every turn) into
+ * O(delta) (only the new tokens of each turn).
+ *
+ * Returns the number of tokens generated, or -1 on error.
+ * ============================================================================ */
+static int tq_lcp_int(const int* a, int na, const int* b, int nb) {
+    int lim = na < nb ? na : nb;
+    int i = 0;
+    while (i < lim && a[i] == b[i]) i++;
+    return i;
+}
+
+int tq_generate_continue(tq_model_t* model,
+                         tq_tokenizer_t* tokenizer,
+                         tq_state_t* state,
+                         const char* prompt,
+                         tq_gen_config_t* config,
+                         int** cached_tokens_io,   /* in/out: cached prefix tokens */
+                         int* n_cached_io,         /* in/out: cached count */
+                         int* cached_capacity_io,  /* in/out: allocated capacity */
+                         char* output, int output_size) {
+    if (!model || !state || !config || !cached_tokens_io || !n_cached_io ||
+        !cached_capacity_io) {
+        return -1;
+    }
+
+    /* Encode new prompt */
+    int new_tokens[4096];
+    int n_new = 0;
+    if (tokenizer && prompt) {
+        int add_bos = (model->config.model_type == 1) ? 1 : 0;
+        n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
+    }
+    if (n_new <= 0) {
+        new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
+        n_new = 1;
+    }
+    /* The KV cache cannot hold more than max_seq_len positions; refuse a
+     * prompt that would overflow it in the prefill loop below. */
+    if (n_new > model->config.max_seq_len) return -1;
+
+    /* Find longest common prefix with the cached tokens.
+     * If the new prompt extends the cached one, everything up to the LCP
+     * is skipped and only the suffix is prefilled. If the new prompt
+     * diverges mid-way (e.g. the user edited an earlier message), the
+     * positions >= lcp are simply overwritten by the prefill below — the
+     * KV cache needs no explicit invalidation. */
+    int n_cached = *n_cached_io;
+    int* cached_tokens = *cached_tokens_io;
+    int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);
+
+    /* If the whole prompt is already cached, state->logits still holds
+     * logits from an older step (or garbage on a fresh state). Re-forward
+     * the last prompt token so sampling always sees logits conditioned on
+     * the full prompt. */
+    if (lcp == n_new) lcp = n_new - 1;
+
+    /* Prefill the new suffix [lcp, n_new) */
+    for (int i = lcp; i < n_new; i++) {
+        tq_forward(model, state, new_tokens[i], i);
+    }
+    int pos = n_new;
+
+    /* Save the new prompt into the cache buffer (generated tokens are
+     * appended below). Grow the buffer if needed. */
+    int needed_cap = n_new + config->max_tokens + 16;
+    if (*cached_capacity_io < needed_cap) {
+        int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
+        int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
+        if (!nb) return -1;
+        *cached_tokens_io = nb;
+        *cached_capacity_io = new_cap;
+        cached_tokens = nb;
+    }
+    memcpy(cached_tokens, new_tokens, (size_t)n_new * sizeof(int));
+    *n_cached_io = n_new;
+    n_cached = n_new;
+
+    /* --- generation loop (mirrors tq_generate's loop) --- */
+    int vocab_size = model->config.vocab_size;
+    float rep_penalty = config->rep_penalty;
+    int rep_window = config->rep_window;
+    if (rep_window > 64) rep_window = 64;
+    int recent_tokens[64];
+    int recent_count = 0;
+    for (int i = (n_new > rep_window ? n_new - rep_window : 0); i < n_new; i++) {
+        recent_tokens[recent_count % 64] = new_tokens[i];
+        recent_count++;
+    }
+
+    if (rep_penalty > 1.0f) {
+        int window = recent_count < rep_window ? recent_count : rep_window;
+        for (int r = 0; r < window; r++) {
+            int idx = (recent_count - 1 - r) % 64;
+            if (idx < 0) idx += 64;
+            int tok = recent_tokens[idx];
+            if (tok >= 0 && tok < vocab_size && state->logits) {
+                if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
+                else state->logits[tok] *= rep_penalty;
+            }
+        }
+    }
+
+    uint64_t rng_state = config->rng_seed ? (uint64_t)config->rng_seed
+                                          : (uint64_t)time(NULL);
+    int next_token = tq_sample_topp(state->logits, vocab_size,
+                                    config->temperature, config->top_p,
+                                    &rng_state);
+
+    int generated = 0;
+    int output_pos = 0;
+    int prev_token = new_tokens[n_new - 1];
+
+    int eos_tokens[] = {
+        1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046,
+    };
+    int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);
+
+    while (generated < config->max_tokens) {
+        int is_eos = 0;
+        for (int e = 0; e < n_eos; e++) {
+            if (next_token == eos_tokens[e]) { is_eos = 1; break; }
+        }
+        if (is_eos) break;
+
+        if (pos >= model->config.max_seq_len) break; /* simple stop, no shift */
+
+        /* Decode + stream */
+        if (tokenizer) {
+            const char* piece = tq_decode(tokenizer, prev_token, next_token);
+            int should_stop = 0;
+            if (piece) {
+                if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
+                    strstr(piece, "<|start_header_id|>")) {
+                    should_stop = 1; piece = "";
+                }
+            }
+            if (should_stop) break;
+            int piece_len = (int)strlen(piece ? piece : "");
+            if (config->on_token && piece) config->on_token(piece, config->user_data);
+            if (output && piece && output_pos + piece_len < output_size - 1) {
+                memcpy(output + output_pos, piece, piece_len);
+                output_pos += piece_len;
+            }
+        }
+
+        /* Append generated token to cache record */
+        if (n_cached < *cached_capacity_io) {
+            cached_tokens[n_cached++] = next_token;
+            *n_cached_io = n_cached;
+        }
+
+        prev_token = next_token;
+        tq_forward(model, state, next_token, pos);
+        pos++;
+        generated++;
+
+        if (rep_penalty > 1.0f) {
+            int window = recent_count < rep_window ? recent_count : rep_window;
+            for (int r = 0; r < window; r++) {
+                int idx = (recent_count - 1 - r) % 64;
+                if (idx < 0) idx += 64;
+                int tok = recent_tokens[idx];
+                if (tok >= 0 && tok < vocab_size) {
+                    if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
+                    else state->logits[tok] *= rep_penalty;
+                }
+            }
+        }
+
+        next_token = tq_sample_topp(state->logits, vocab_size,
+                                    config->temperature, config->top_p,
+                                    &rng_state);
+        recent_tokens[recent_count % 64] = next_token;
+        recent_count++;
+    }
+
+    if (output && output_size > 0) {
+        output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
+    }
+    return generated;
+}
+
 // ============================================================================
 
 // ============================================================================
@@ -15957,9 +16165,70 @@ void quant_free_ctx(quant_ctx* ctx) {
 void quant_free_ctx(quant_ctx* ctx) {
     if (!ctx) return;
     tq_free_state(ctx->state);
     tq_free_tokenizer(ctx->tokenizer);
+    if (ctx->cached_tokens) free(ctx->cached_tokens);
     free(ctx);
 }
 
+/* ----------------------------------------------------------------------
+ * quant_chat — chat-mode generate that reuses the KV cache across calls.
+ *
+ * Unlike quant_generate (which resets the state on every call and so makes
+ * each turn O(history_length)), quant_chat keeps the state alive between
+ * calls. The first call prefills and generates as normal. Subsequent
+ * calls compute the longest common prefix between the new prompt and the
+ * previously processed tokens, skip the matched prefix, and only prefill
+ * the diverging suffix.
+ *
+ * Result: turn N's prefill cost is O(new tokens this turn), not
+ * O(total history).
+ *
+ * Reset behavior: pass a NULL prompt to wipe the cache (start a new chat).
+ * Returns the number of tokens generated, or -1 on error.
+ * ---------------------------------------------------------------------- */
+int quant_chat(quant_ctx* ctx, const char* prompt,
+               void (*on_token)(const char* text, void* user_data),
+               void* user_data) {
+    if (!ctx || !ctx->model) return -1;
+
+    /* NULL prompt = reset the chat (clear cache + state) */
+    if (!prompt) {
+        tq_free_state(ctx->state);
+        ctx->state = tq_create_state_ex(&ctx->model->config,
+                                        ctx->config.kv_type,
+                                        ctx->config.value_quant_bits);
+        if (ctx->cached_tokens) free(ctx->cached_tokens);
+        ctx->cached_tokens = NULL;
+        ctx->n_cached = 0;
+        ctx->cached_capacity = 0;
+        ctx->n_ctx_tokens = 0;
+        /* Report allocation failure; the next call retries the alloc. */
+        return ctx->state ? 0 : -1;
+    }
+
+    if (!ctx->state) {
+        ctx->state = tq_create_state_ex(&ctx->model->config,
+                                        ctx->config.kv_type,
+                                        ctx->config.value_quant_bits);
+        if (!ctx->state) return -1;
+    }
+
+    ctx->config.on_token = on_token;
+    ctx->config.user_data = user_data;
+
+    /* NOTE(review): 64 KiB stack buffer — confirm thread stack headroom
+     * matches what quant_generate already assumes. */
+    char output[65536];
+    int n = tq_generate_continue(
+        ctx->model, ctx->tokenizer, ctx->state, prompt, &ctx->config,
+        &ctx->cached_tokens, &ctx->n_cached, &ctx->cached_capacity,
+        output, sizeof(output));
+
+    /* Keep bookkeeping in sync even when zero tokens were generated
+     * (immediate EOS) — the prompt tokens are still in the cache. */
+    if (n >= 0) ctx->n_ctx_tokens = ctx->n_cached;
+    return n;
+}
+
 void quant_free_model(quant_model* model) {
     tq_free_model((tq_model_t*)model);
 }
diff --git a/src/engine/tq_generate.c b/src/engine/tq_generate.c
index 208b4df..64f8515 100644
--- a/src/engine/tq_generate.c
+++ b/src/engine/tq_generate.c
@@ -601,3 +601,191 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
     tq_free_state(state);
     return generated;
 }
+
+/* ============================================================================
+ * tq_generate_continue — chat-mode generation with KV cache reuse.
+ *
+ * Caller-managed state: state and cached_tokens persist across calls.
+ * Each call computes the longest common prefix between cached_tokens and
+ * the new prompt, prefills only the diverging suffix, and updates the
+ * cache record. Turns chat from O(history^2) into O(new_tokens_per_turn).
+ * ============================================================================ */
+static int tq_lcp_int(const int* a, int na, const int* b, int nb) {
+    int lim = na < nb ? na : nb;
+    int i = 0;
+    while (i < lim && a[i] == b[i]) i++;
+    return i;
+}
+
+int tq_generate_continue(tq_model_t* model,
+                         tq_tokenizer_t* tokenizer,
+                         tq_state_t* state,
+                         const char* prompt,
+                         tq_gen_config_t* config,
+                         int** cached_tokens_io,   /* in/out: cached prefix tokens */
+                         int* n_cached_io,         /* in/out: cached count */
+                         int* cached_capacity_io,  /* in/out: allocated capacity */
+                         char* output, int output_size) {
+    if (!model || !state || !config || !cached_tokens_io || !n_cached_io ||
+        !cached_capacity_io) {
+        return -1;
+    }
+
+    /* Encode new prompt */
+    int new_tokens[4096];
+    int n_new = 0;
+    if (tokenizer && prompt) {
+        int add_bos = (model->config.model_type == 1) ? 1 : 0;
+        n_new = tq_encode(tokenizer, prompt, new_tokens, 4096, add_bos);
+    }
+    if (n_new <= 0) {
+        new_tokens[0] = (model->config.model_type == 1) ? 2 : 1;
+        n_new = 1;
+    }
+    /* The KV cache cannot hold more than max_seq_len positions; refuse a
+     * prompt that would overflow it in the prefill loop below. */
+    if (n_new > model->config.max_seq_len) return -1;
+
+    /* Find longest common prefix with the cached tokens.
+     * If the new prompt extends the cached one, everything up to the LCP
+     * is skipped and only the suffix is prefilled. If the new prompt
+     * diverges mid-way (e.g. the user edited an earlier message), the
+     * positions >= lcp are simply overwritten by the prefill below — the
+     * KV cache needs no explicit invalidation. */
+    int n_cached = *n_cached_io;
+    int* cached_tokens = *cached_tokens_io;
+    int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new);
+
+    /* If the whole prompt is already cached, state->logits still holds
+     * logits from an older step (or garbage on a fresh state). Re-forward
+     * the last prompt token so sampling always sees logits conditioned on
+     * the full prompt. */
+    if (lcp == n_new) lcp = n_new - 1;
+
+    /* Prefill the new suffix [lcp, n_new) */
+    for (int i = lcp; i < n_new; i++) {
+        tq_forward(model, state, new_tokens[i], i);
+    }
+    int pos = n_new;
+
+    /* Save the new prompt into the cache buffer (generated tokens are
+     * appended below). Grow the buffer if needed. */
+    int needed_cap = n_new + config->max_tokens + 16;
+    if (*cached_capacity_io < needed_cap) {
+        int new_cap = needed_cap < 4096 ? 4096 : needed_cap;
+        int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
+        if (!nb) return -1;
+        *cached_tokens_io = nb;
+        *cached_capacity_io = new_cap;
+        cached_tokens = nb;
+    }
+    memcpy(cached_tokens, new_tokens, (size_t)n_new * sizeof(int));
+    *n_cached_io = n_new;
+    n_cached = n_new;
+
+    /* --- generation loop (mirrors tq_generate's loop) --- */
+    int vocab_size = model->config.vocab_size;
+    float rep_penalty = config->rep_penalty;
+    int rep_window = config->rep_window;
+    if (rep_window > 64) rep_window = 64;
+    int recent_tokens[64];
+    int recent_count = 0;
+    for (int i = (n_new > rep_window ? n_new - rep_window : 0); i < n_new; i++) {
+        recent_tokens[recent_count % 64] = new_tokens[i];
+        recent_count++;
+    }
+
+    if (rep_penalty > 1.0f) {
+        int window = recent_count < rep_window ? recent_count : rep_window;
+        for (int r = 0; r < window; r++) {
+            int idx = (recent_count - 1 - r) % 64;
+            if (idx < 0) idx += 64;
+            int tok = recent_tokens[idx];
+            if (tok >= 0 && tok < vocab_size && state->logits) {
+                if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
+                else state->logits[tok] *= rep_penalty;
+            }
+        }
+    }
+
+    /* NOTE(review): quant.h's copy of this function uses uint64_t here;
+     * tq_sample_topp's parameter type must match both — verify. */
+    unsigned long long rng_state = config->rng_seed ? (unsigned long long)config->rng_seed
+                                                    : (unsigned long long)time(NULL);
+    int next_token = tq_sample_topp(state->logits, vocab_size,
+                                    config->temperature, config->top_p,
+                                    &rng_state);
+
+    int generated = 0;
+    int output_pos = 0;
+    int prev_token = new_tokens[n_new - 1];
+
+    int eos_tokens[] = {
+        1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046,
+    };
+    int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);
+
+    while (generated < config->max_tokens) {
+        int is_eos = 0;
+        for (int e = 0; e < n_eos; e++) {
+            if (next_token == eos_tokens[e]) { is_eos = 1; break; }
+        }
+        if (is_eos) break;
+
+        if (pos >= model->config.max_seq_len) break; /* simple stop, no shift */
+
+        /* Decode + stream */
+        if (tokenizer) {
+            const char* piece = tq_decode(tokenizer, prev_token, next_token);
+            int should_stop = 0;
+            if (piece) {
+                if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
+                    strstr(piece, "<|start_header_id|>")) {
+                    should_stop = 1; piece = "";
+                }
+            }
+            if (should_stop) break;
+            int piece_len = (int)strlen(piece ? piece : "");
+            if (config->on_token && piece) config->on_token(piece, config->user_data);
+            if (output && piece && output_pos + piece_len < output_size - 1) {
+                memcpy(output + output_pos, piece, piece_len);
+                output_pos += piece_len;
+            }
+        }
+
+        /* Append generated token to cache record */
+        if (n_cached < *cached_capacity_io) {
+            cached_tokens[n_cached++] = next_token;
+            *n_cached_io = n_cached;
+        }
+
+        prev_token = next_token;
+        tq_forward(model, state, next_token, pos);
+        pos++;
+        generated++;
+
+        if (rep_penalty > 1.0f) {
+            int window = recent_count < rep_window ? recent_count : rep_window;
+            for (int r = 0; r < window; r++) {
+                int idx = (recent_count - 1 - r) % 64;
+                if (idx < 0) idx += 64;
+                int tok = recent_tokens[idx];
+                if (tok >= 0 && tok < vocab_size) {
+                    if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
+                    else state->logits[tok] *= rep_penalty;
+                }
+            }
+        }
+
+        next_token = tq_sample_topp(state->logits, vocab_size,
+                                    config->temperature, config->top_p,
+                                    &rng_state);
+        recent_tokens[recent_count % 64] = next_token;
+        recent_count++;
+    }
+
+    if (output && output_size > 0) {
+        output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
+    }
+    return generated;
+}
diff --git a/src/server/tq_server.c b/src/server/tq_server.c
index 5f27edb..73b811c 100644
--- a/src/server/tq_server.c
+++ b/src/server/tq_server.c
@@ -18,6 +18,18 @@
 #include <string.h>
 #include <stdlib.h>
 #include <signal.h>
+
+/* Forward decl: defined in src/engine/tq_generate.c.
+ * Not yet exposed in turboquant.h since it's a chat-mode helper. */
+extern int tq_generate_continue(tq_model_t* model,
+                                tq_tokenizer_t* tokenizer,
+                                tq_state_t* state,
+                                const char* prompt,
+                                tq_gen_config_t* config,
+                                int** cached_tokens_io,
+                                int* n_cached_io,
+                                int* cached_capacity_io,
+                                char* output, int output_size);
 #if defined(_MSC_VER)
 #include <intrin.h>
 typedef volatile long atomic_int;
@@ -67,6 +79,13 @@ struct tq_server {
     atomic_int running;
     atomic_int active_connections;  /* track concurrent threads */
     pthread_mutex_t inference_mutex;  /* serialize inference (single model state) */
+
+    /* Persistent inference state — shared across requests for chat-mode
+     * KV cache reuse. The inference_mutex above serializes access.
+     * NOTE(review): assumes tq_server is zero-initialized on allocation
+     * so kv_state/cached_tokens start NULL — verify the alloc site. */
+    tq_state_t* kv_state;
+    int* cached_tokens;
+    int n_cached;
+    int cached_capacity;
 };
 
 /* Global server pointer for signal handler */
@@ -653,9 +674,25 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
     gen_cfg.on_token = sse_token_callback;
     gen_cfg.user_data = &sse_ctx;
 
-    char output[1];  /* tq_generate writes to output, but we use callback */
-    tq_generate(server->config.model, server->config.tokenizer,
-                req.prompt, &gen_cfg, output, sizeof(output));
+    char output[1];  /* writes via callback, output buffer unused */
+    /* Use tq_generate_continue with persistent KV state for chat reuse:
+     * matches the longest common prefix of req.prompt against
+     * server->cached_tokens, prefills only the suffix. Turns chat
+     * latency from O(history^2) into O(new_tokens).
+     * NOTE(review): caller must hold inference_mutex here — verify. */
+    if (!server->kv_state) {
+        server->kv_state = tq_create_state_ex(
+            &server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
+    }
+    if (server->kv_state) {  /* skip inference (not crash) if alloc failed */
+        tq_generate_continue(server->config.model, server->config.tokenizer,
+                             server->kv_state, req.prompt, &gen_cfg,
+                             &server->cached_tokens, &server->n_cached,
+                             &server->cached_capacity,
+                             output, sizeof(output));
+    }
 
     /* Send final chunk with finish_reason */
     char final_chunk[SSE_CHUNK_SIZE];
@@ -685,8 +723,19 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
     gen_cfg.user_data = &collect;
 
     char output[1];
-    tq_generate(server->config.model, server->config.tokenizer,
-                req.prompt, &gen_cfg, output, sizeof(output));
+    if (!server->kv_state) {
+        server->kv_state = tq_create_state_ex(
+            &server->config.model->config, gen_cfg.kv_type, gen_cfg.value_quant_bits);
+    }
+    if (server->kv_state) {  /* skip inference (not crash) if alloc failed */
+        tq_generate_continue(server->config.model, server->config.tokenizer,
+                             server->kv_state, req.prompt, &gen_cfg,
+                             &server->cached_tokens, &server->n_cached,
+                             &server->cached_capacity,
+                             output, sizeof(output));
+    }
 
     const char* content = collect.buf ? collect.buf : "";
 
@@ -1143,6 +1192,8 @@ void tq_server_stop(tq_server_t* server) {
 void tq_server_free(tq_server_t* server) {
     if (!server) return;
     pthread_mutex_destroy(&server->inference_mutex);
+    if (server->kv_state) tq_free_state(server->kv_state);
+    if (server->cached_tokens) free(server->cached_tokens);
     if (g_server == server) g_server = NULL;
     free(server);
 }