Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 229 additions & 1 deletion quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,10 @@ struct quant_ctx {
int* cached_tokens;
int n_cached;
int cached_capacity;
/* Text-prefix cache: stores the entire prompt + generated response
* text from the last call, allowing the next call to bypass BPE
* re-tokenization issues by matching at the byte level. */
char* cached_text;
};

// ============================================================================
Expand Down Expand Up @@ -15848,6 +15852,225 @@ int tq_generate_continue(tq_model_t* model,
return generated;
}

/* ============================================================================
* tq_generate_chat_text — text-prefix matching for chat reuse
*
* Solves the BPE re-tokenization issue: when the model generates response
* tokens via sample_topp, those token IDs may not match what tq_encode()
* produces from the same response text in the next turn's prompt. The
* token-level LCP in tq_generate_continue truncates at that boundary.
*
* This function tracks the *text* of the last prompt+response. On the next
* call, if the new prompt starts with cached_text byte-for-byte, the entire
* cached state is valid — tokenize ONLY the new SUFFIX text and prefill
* those tokens at positions [n_cached..]. No LCP, no truncation.
*
* Pass cached_text_io == NULL to disable text-prefix tracking.
* ============================================================================ */

/* Accumulator interposed on config->on_token during generation: appends
 * every decoded piece to a growable NUL-terminated buffer (so the full
 * response text can be cached afterwards) and then forwards the piece to
 * the caller's original callback. */
typedef struct {
    char* buf;        /* accumulated text; NUL-terminated once non-NULL */
    size_t len;       /* bytes used, excluding the terminator */
    size_t cap;       /* bytes allocated */
    void (*user_cb)(const char*, void*);  /* caller's original on_token */
    void* user_data;                      /* caller's original user_data */
} chat_accum_t;

/* on_token shim: grow buf as needed, append tok, then forward to the
 * user's callback. On allocation failure the piece is dropped from the
 * accumulator only — it is still forwarded, so the user-visible stream
 * is never silently truncated by OOM. */
static void chat_accum_callback(const char* tok, void* u) {
    chat_accum_t* ctx = (chat_accum_t*)u;
    if (!tok) return;
    size_t tlen = strlen(tok);
    if (ctx->len + tlen + 1 > ctx->cap) {
        /* Geometric growth with a small floor to avoid many tiny reallocs. */
        size_t new_cap = (ctx->cap + tlen + 64) * 2;
        char* nb = (char*)realloc(ctx->buf, new_cap);
        if (!nb) {
            /* BUGFIX: previously returned without forwarding, dropping the
             * token from the caller's stream as well as from the cache. */
            if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
            return;
        }
        ctx->buf = nb;
        ctx->cap = new_cap;
    }
    memcpy(ctx->buf + ctx->len, tok, tlen);
    ctx->len += tlen;
    ctx->buf[ctx->len] = '\0';
    if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
}

/* Generate a chat response, reusing KV-cache state when the new prompt
 * starts byte-for-byte with the cached prompt+response text of the last
 * call. Returns the number of generated tokens, or -1 on error.
 * cached_text_io may be NULL to disable text-prefix tracking; the other
 * *_io parameters carry the token cache across calls. */
int tq_generate_chat_text(tq_model_t* model,
                          tq_tokenizer_t* tokenizer,
                          tq_state_t* state,
                          const char* prompt,
                          tq_gen_config_t* config,
                          char** cached_text_io,
                          int** cached_tokens_io,
                          int* n_cached_io,
                          int* cached_capacity_io,
                          char* output, int output_size) {
    if (!model || !state || !config || !cached_tokens_io || !n_cached_io ||
        !cached_capacity_io || !prompt) {
        return -1;
    }

    /* How many leading bytes of `prompt` are already represented in the
     * KV cache, and the token position where new material is prefilled. */
    int matched_text_len = 0;
    int prefix_pos = 0;

    if (cached_text_io && *cached_text_io && *n_cached_io > 0) {
        size_t cached_len = strlen(*cached_text_io);
        if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) {
            matched_text_len = (int)cached_len;
            prefix_pos = *n_cached_io;
        }
    }

    /* Interpose chat_accum_callback so the generated response text is
     * captured for the next turn's prefix match; the caller's callback is
     * forwarded from inside the shim and restored at update_cache on every
     * exit path below. */
    chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
                           .user_cb = config->on_token,
                           .user_data = config->user_data };
    void (*orig_cb)(const char*, void*) = config->on_token;
    void* orig_ud = config->user_data;
    config->on_token = chat_accum_callback;
    config->user_data = &accum;

    int generated = 0;

    if (matched_text_len > 0) {
        /* FAST PATH: tokenize only the text after the cached prefix. */
        const char* suffix = prompt + matched_text_len;
        int max_prompt = model->config.max_seq_len > 0
                             ? model->config.max_seq_len : 4096;
        int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int));
        if (!suffix_toks) {
            config->on_token = orig_cb; config->user_data = orig_ud;
            return -1;
        }
        int n_suffix = 0;
        if (*suffix != '\0') {
            /* NOTE(review): tokenizer is used unchecked here, while the
             * decode loop below guards against NULL — confirm callers
             * always pass a tokenizer when a non-empty suffix exists. */
            n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0);
            if (n_suffix < 0) n_suffix = 0;
        }

        /* If the conversation no longer fits in the context window, drop
         * the cache entirely and fall back to a full re-prefill. */
        int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
        if (prefix_pos + n_suffix + reserve + 32 > max_prompt) {
            free(suffix_toks);
            *n_cached_io = 0;
            if (cached_text_io && *cached_text_io) {
                free(*cached_text_io); *cached_text_io = NULL;
            }
            /* BUGFIX: keep the accumulator callback installed during this
             * fallback generation (previously it was restored first, so the
             * response text never reached cached_text; the next turn would
             * then prefill duplicate response tokens over stale KV state).
             * The slow path below already keeps it installed — this makes
             * the two fallbacks consistent. update_cache restores it. */
            int n2 = tq_generate_continue(model, tokenizer, state, prompt, config,
                                          cached_tokens_io, n_cached_io, cached_capacity_io,
                                          output, output_size);
            generated = n2;
            goto update_cache;
        }

        /* Ensure the token cache can hold prefix + suffix + generation. */
        int needed = prefix_pos + n_suffix + reserve + 16;
        if (*cached_capacity_io < needed) {
            int new_cap = needed < 4096 ? 4096 : needed;
            int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
            if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; }
            *cached_tokens_io = nb;
            *cached_capacity_io = new_cap;
        }

        /* Prefill only the new suffix tokens at positions [prefix_pos..). */
        int* cached = *cached_tokens_io;
        for (int i = 0; i < n_suffix; i++) {
            cached[prefix_pos + i] = suffix_toks[i];
            tq_forward(model, state, suffix_toks[i], prefix_pos + i);
        }
        *n_cached_io = prefix_pos + n_suffix;
        free(suffix_toks);

        if (getenv("TQ_CHAT_DEBUG")) {
            fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n",
                    matched_text_len, n_suffix);
        }

        /* Generation loop. state->logits currently reflect the last
         * prefilled token (or, when the suffix was empty, the previous
         * call's final forward). */
        int vocab_size = model->config.vocab_size;
        int n_cached = *n_cached_io;
        int pos = n_cached;
        int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;  /* 1 = BOS fallback */

        uint64_t rng_state = config->rng_seed
                                 ? (uint64_t)config->rng_seed : (uint64_t)time(NULL);
        int next_token = tq_sample_topp(state->logits, vocab_size,
                                        config->temperature, config->top_p,
                                        &rng_state);

        int output_pos = 0;
        /* Union of common EOS token ids across supported model families. */
        int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 };
        int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);

        while (generated < config->max_tokens) {
            int is_eos = 0;
            for (int e = 0; e < n_eos; e++) {
                if (next_token == eos_tokens[e]) { is_eos = 1; break; }
            }
            if (is_eos) break;
            if (pos >= model->config.max_seq_len) break;

            const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : "";
            /* Textual stop markers that are not single EOS token ids. */
            int should_stop = 0;
            if (piece) {
                if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
                    strstr(piece, "<|start_header_id|>")) {
                    should_stop = 1; piece = "";
                }
            }
            if (should_stop) break;

            int piece_len = (int)strlen(piece ? piece : "");
            if (config->on_token && piece) config->on_token(piece, config->user_data);
            if (output && piece && output_pos + piece_len < output_size - 1) {
                memcpy(output + output_pos, piece, piece_len);
                output_pos += piece_len;
            }

            /* Capacity was reserved above (reserve >= max_tokens), so this
             * append cannot overflow in practice; guard is belt-and-braces. */
            if (n_cached < *cached_capacity_io) {
                cached[n_cached++] = next_token;
                *n_cached_io = n_cached;
            }

            prev_token = next_token;
            tq_forward(model, state, next_token, pos);
            pos++;
            generated++;

            next_token = tq_sample_topp(state->logits, vocab_size,
                                        config->temperature, config->top_p,
                                        &rng_state);
        }

        if (output && output_size > 0) {
            output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
        }
    } else {
        /* SLOW PATH: no byte-level prefix match; defer to the token-LCP
         * continuation. The accumulator stays installed so the response
         * text is still captured for the next turn. */
        if (getenv("TQ_CHAT_DEBUG")) {
            fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n");
        }
        generated = tq_generate_continue(
            model, tokenizer, state, prompt, config,
            cached_tokens_io, n_cached_io, cached_capacity_io,
            output, output_size);
    }

update_cache:
    config->on_token = orig_cb;
    config->user_data = orig_ud;

    /* Record prompt + response as the new cached text so the next turn's
     * prompt can match it byte-for-byte. BUGFIX: skipped on failure so a
     * transcript that was never processed is not associated with an
     * inconsistent KV/token cache. */
    if (cached_text_io && generated >= 0) {
        size_t plen = strlen(prompt);
        size_t glen = accum.len;
        size_t new_len = plen + glen;
        char* nt = (char*)malloc(new_len + 1);
        if (nt) {
            memcpy(nt, prompt, plen);
            if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
            nt[new_len] = '\0';
            if (*cached_text_io) free(*cached_text_io);
            *cached_text_io = nt;
        }
    }
    if (accum.buf) free(accum.buf);

    return generated;
}

// ============================================================================

// ============================================================================
Expand Down Expand Up @@ -16182,6 +16405,7 @@ void quant_free_ctx(quant_ctx* ctx) {
tq_free_state(ctx->state);
tq_free_tokenizer(ctx->tokenizer);
if (ctx->cached_tokens) free(ctx->cached_tokens);
if (ctx->cached_text) free(ctx->cached_text);
free(ctx);
}

Expand Down Expand Up @@ -16217,6 +16441,7 @@ int quant_chat(quant_ctx* ctx, const char* prompt,
ctx->n_cached = 0;
ctx->cached_capacity = 0;
ctx->n_ctx_tokens = 0;
if (ctx->cached_text) { free(ctx->cached_text); ctx->cached_text = NULL; }
return 0;
}

Expand All @@ -16231,8 +16456,11 @@ int quant_chat(quant_ctx* ctx, const char* prompt,
ctx->config.user_data = user_data;

char output[65536];
int n = tq_generate_continue(
/* Use the text-prefix path so chat replays bypass BPE re-tokenization
* issues. Falls back to token-LCP path if text prefix doesn't match. */
int n = tq_generate_chat_text(
ctx->model, ctx->tokenizer, ctx->state, prompt, &ctx->config,
&ctx->cached_text,
&ctx->cached_tokens, &ctx->n_cached, &ctx->cached_capacity,
output, sizeof(output));

Expand Down
2 changes: 1 addition & 1 deletion wasm/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ emcc "$SCRIPT_DIR/quant_wasm.c" \
-s ALLOW_MEMORY_GROWTH=1 \
-s MAXIMUM_MEMORY=4GB \
-s INITIAL_MEMORY=256MB \
-s EXPORTED_FUNCTIONS='["_main","_wasm_load_model","_wasm_generate","_wasm_generate_async","_wasm_model_info","_wasm_is_ready","_malloc","_free"]' \
-s EXPORTED_FUNCTIONS='["_main","_wasm_load_model","_wasm_generate","_wasm_generate_async","_wasm_reset_chat","_wasm_model_info","_wasm_is_ready","_malloc","_free"]' \
-s EXPORTED_RUNTIME_METHODS='["UTF8ToString","allocateUTF8","FS","ccall","cwrap"]' \
-s FORCE_FILESYSTEM=1 \
-s MODULARIZE=0 \
Expand Down
35 changes: 29 additions & 6 deletions wasm/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ <h2>Run an <span>LLM</span> in your browser</h2>
return `<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`;
}

/* Multi-turn chat history. Sent on every turn so the model has context.
* The C side's quant_chat() does text-prefix matching: turn N's prefill
* is O(new tokens since last call), not O(full history). */
/* Multi-turn chat transcript, replayed on every turn so the model keeps
 * context. The C side's quant_chat() text-prefix matching makes turn N's
 * prefill O(new tokens since last call), not O(full history). */
let chatHistory = '';

/* Clear the JS-side transcript and, when the WASM module exports the
 * hook, reset the C-side caches as well. */
function resetChatSession() {
  chatHistory = '';
  const canResetWasm =
    typeof Module !== 'undefined' && !!Module._wasm_reset_chat;
  if (canResetWasm) {
    Module._wasm_reset_chat();
  }
}

function stopGeneration() { stopRequested = true; }

async function generate() {
Expand All @@ -405,11 +417,14 @@ <h2>Run an <span>LLM</span> in your browser</h2>

addMessage('user', text);
const aDiv = addMessage('assistant', '');
aDiv.innerHTML = '<span class="thinking"><span class="spinner"></span> Processing prompt (may take a few seconds)...</span>';
const isFirstTurn = chatHistory.length === 0;
aDiv.innerHTML = isFirstTurn
? '<span class="thinking"><span class="spinner"></span> Processing prompt (first turn — may take a few seconds)...</span>'
: '<span class="thinking"><span class="spinner"></span> Generating...</span>';
let output = '', count = 0;
const t0 = performance.now();
document.getElementById('statTokens').textContent = '';
document.getElementById('statSpeed').textContent = 'processing prompt...';
document.getElementById('statSpeed').textContent = isFirstTurn ? 'processing prompt...' : 'generating...';

Module.onToken = (tok) => {
output += tok; count++;
Expand All @@ -433,11 +448,13 @@ <h2>Run an <span>LLM</span> in your browser</h2>

await new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)));

const prompt = getChatPrompt(text);
/* Build the full ChatML prompt by appending this turn to the history.
* The C side's quant_chat() does text-prefix matching, so the previous
* turns are reused from the KV cache — only the new user message gets
* prefilled. Turn N's latency: O(new user message), not O(full history). */
chatHistory += `<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`;
const prompt = chatHistory;

// Use ccall with async:true — this is the correct way to call
// ASYNCIFY-enabled C functions. Module._fn() direct calls do NOT
// return Promises; only ccall({async:true}) does.
try {
await Module.ccall(
'wasm_generate_async',
Expand All @@ -450,6 +467,12 @@ <h2>Run an <span>LLM</span> in your browser</h2>
console.error('generate error:', e);
}

/* Append the model's response to history so the next turn matches
* the cached_text prefix exactly (byte-for-byte). */
if (output) {
chatHistory += `${output}<|im_end|>\n`;
}

if (!output && !count) {
aDiv.innerHTML = '<em style="color:#555">No output. Try a different prompt.</em>';
}
Expand Down
2 changes: 1 addition & 1 deletion wasm/quant.js

Large diffs are not rendered by default.

Binary file modified wasm/quant.wasm
Binary file not shown.
Loading
Loading