Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions bindings/python/quantcpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,79 @@ def _run():
if error_box[0] is not None:
raise error_box[0]

def chat(self, prompt: str) -> Iterator[str]:
    """Multi-turn chat with KV cache reuse.

    Like ``generate()``, but the KV cache persists across calls. When you
    re-send the conversation history each turn, only the new tokens are
    prefilled — turn N's latency is O(new_tokens), not O(history^2).

    Pass ``prompt=None`` to reset the chat session. The reset happens
    immediately at call time (the returned iterator is empty), so a plain
    ``m.chat(None)`` works without having to iterate the result.

    Falls back to ``generate()`` on older library builds without the
    ``quant_chat`` symbol; on such builds a reset request is a no-op.

    Args:
        prompt: Conversation text to send, or ``None`` to reset the session.

    Returns:
        An iterator yielding decoded text pieces as the native callback
        delivers them. Generation starts when iteration starts.

    Raises:
        Exception: Any exception caught in the worker thread is re-raised
            from the iterator once all received tokens have been yielded.
    """
    self._ensure_open()

    if prompt is None:
        # Handle the reset eagerly and before the fallback check: the old
        # generator-based code deferred it until iteration and, on builds
        # without quant_chat, ran generate("") instead of doing nothing.
        self.reset_chat()
        return iter(())

    lib = get_lib()

    if not hasattr(lib, "quant_chat"):
        # Older library — silently fall back to non-reusing generate
        return self.generate(prompt)

    if self._chat:
        prompt = self._apply_chat_template(prompt)

    def _stream():
        # Re-check at iteration time: the iterator may be consumed long
        # after chat() returned.
        self._ensure_open()

        tokens = []
        done = threading.Event()
        error_box = [None]

        def _on_token(text_ptr, _user_data):
            if text_ptr:
                tokens.append(text_ptr.decode("utf-8", errors="replace"))

        # Keep a reference to the callback object for the duration of the
        # native call so it is not garbage-collected while in use.
        cb = ON_TOKEN_CB(_on_token)

        def _run():
            try:
                with self._lock:
                    lib.quant_chat(self._ctx, prompt.encode("utf-8"), cb, None)
            except Exception as e:
                error_box[0] = e
            finally:
                done.set()

        thread = threading.Thread(target=_run, daemon=True)
        thread.start()

        # The loop only exits once the worker is done AND every collected
        # token has been yielded, so no separate drain pass is required.
        # The short timed wait keeps this thread responsive between tokens
        # (presumably so Ctrl+C is handled promptly — confirm before
        # replacing it with an untimed blocking wait).
        yielded = 0
        while not done.is_set() or yielded < len(tokens):
            if yielded < len(tokens):
                yield tokens[yielded]
                yielded += 1
            else:
                done.wait(timeout=0.01)

        if error_box[0] is not None:
            raise error_box[0]

    return _stream()

def reset_chat(self) -> None:
    """Drop the chat KV cache so the next ``chat()`` call starts fresh.

    Does nothing when the loaded library does not export ``quant_chat``.
    """
    self._ensure_open()
    native = get_lib()
    if not hasattr(native, "quant_chat"):
        return

    # The reset request is a quant_chat call with no prompt and a
    # callback built from address 0.
    with self._lock:
        native.quant_chat(self._ctx, None, ON_TOKEN_CB(0), None)

def save_context(self, path: str) -> None:
"""Save the current KV cache to disk.

Expand Down
14 changes: 14 additions & 0 deletions bindings/python/quantcpp/_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ def _setup_signatures(lib: ctypes.CDLL) -> None:
]
lib.quant_generate.restype = ctypes.c_int

# int quant_chat(quant_ctx* ctx, const char* prompt,
# void (*on_token)(const char*, void*), void* user_data)
# Multi-turn chat with KV cache reuse — avoids the O(n^2) prefill cost
# of quant_generate when the user re-sends conversation history.
# Optional: only present in single-header builds (>= v0.13).
if hasattr(lib, "quant_chat"):
lib.quant_chat.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ON_TOKEN_CB,
ctypes.c_void_p,
]
lib.quant_chat.restype = ctypes.c_int

# char* quant_ask(quant_ctx* ctx, const char* prompt)
lib.quant_ask.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
lib.quant_ask.restype = ctypes.c_void_p # We use c_void_p so we can free()
Expand Down
10 changes: 9 additions & 1 deletion bindings/python/quantcpp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,23 @@ def cmd_run(args):
print()
else:
print("quantcpp \u2014 type your message, Ctrl+C to exit", file=sys.stderr)
# Multi-turn chat: accumulate history as ChatML so the model sees
# prior turns. m.chat() reuses the KV cache via prefix-match, so
# repeating the history is cheap (O(new tokens), not O(n^2)).
history = ""
try:
while True:
question = input("\nYou: ")
if not question.strip():
continue
history += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
print("AI: ", end="", flush=True)
for tok in m.generate(question):
reply_buf = []
for tok in m.chat(history):
print(tok, end="", flush=True)
reply_buf.append(tok)
print()
history += "".join(reply_buf) + "<|im_end|>\n"
except (KeyboardInterrupt, EOFError):
print("\nBye!", file=sys.stderr)

Expand Down
Loading
Loading