Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 61 additions & 14 deletions src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ void llama_memory_recurrent::clear(bool data) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}

std::fill(rs_idx.begin(), rs_idx.end(), 0);
}

bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
Expand All @@ -156,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
p1 = std::numeric_limits<llama_pos>::max();
}

const bool rm_all = p0 == 0 && p1 == std::numeric_limits<llama_pos>::max();
if (rm_all) {
if (seq_id >= 0) {
set_rs_idx(seq_id, 0);
} else {
std::fill(rs_idx.begin(), rs_idx.end(), 0);
}
}

// models like Mamba or RWKV can't have a state partially erased at the end
// of the sequence because their state isn't preserved for previous tokens
if (seq_id >= (int64_t) size) {
Expand Down Expand Up @@ -719,16 +730,8 @@ size_t llama_memory_recurrent::size_s_bytes() const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

// [TAG_RS_STATE_ROLLBACK_SUPPORT]
if (n_rs_seq != 0) {
for (uint32_t i = 0; i < rs_idx.size(); ++i) {
if (rs_idx[i] != 0) {
GGML_ABORT("recurrent state read/write is not supported with partial rollback");
}
}
}

std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges_data; // logical source row ranges
uint32_t cell_count = 0;

// Count the number of cells with the specified seq_id
Expand All @@ -738,6 +741,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
const auto & cell = cells[i];
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
++cell_count;
uint32_t rs_idx_cur = 0;

if (n_rs_seq != 0) {
if (seq_id != -1) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size());
rs_idx_cur = rs_idx[seq_id];
} else {
bool has_rs_idx = false;
for (const llama_seq_id cell_seq_id : cell.seq_id) {
GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size());

const uint32_t seq_rs_idx = rs_idx[cell_seq_id];
if (!has_rs_idx) {
rs_idx_cur = seq_rs_idx;
has_rs_idx = true;
} else if (rs_idx_cur != seq_rs_idx) {
GGML_ABORT("cannot write shared recurrent state with different rollback indices");
}
}
}
}

const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cell.src >= 0 ? cell.src

I guess this was the part that I was missing in my approach. I will take a detailed look tomorrow and run some tests to verify that it works as expected.

if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) {
cell_ranges_data.emplace_back(cell_id, cell_id + 1);
} else {
cell_ranges_data.back().second++;
}

if (cell_range_begin == size) {
cell_range_begin = i;
}
Expand All @@ -763,10 +795,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
}
GGML_ASSERT(cell_count == cell_count_check);

cell_count_check = 0;
for (const auto & range : cell_ranges_data) {
cell_count_check += range.second - range.first;
}
GGML_ASSERT(cell_count == cell_count_check);

io.write(&cell_count, sizeof(cell_count));

state_write_meta(io, cell_ranges, seq_id);
state_write_data(io, cell_ranges);
state_write_data(io, cell_ranges_data);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
Expand All @@ -788,6 +826,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
}
throw std::runtime_error("failed to restore kv cache");
}

if (n_rs_seq != 0) {
if (seq_id == -1) {
std::fill(rs_idx.begin(), rs_idx.end(), 0);
} else {
set_rs_idx(seq_id, 0);
}
}
}

void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
Expand Down Expand Up @@ -830,7 +876,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
io.write(&r_size_row, sizeof(r_size_row));

// Write each range of cells of r_size_row length
// Write each logical cell row range. With pending recurrent rollback,
// the logical current state may live in a rollback snapshot plane.
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * r_size_row;
Expand All @@ -851,7 +898,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
io.write(&s_size_row, sizeof(s_size_row));

// Write each range of S tensor rows
// Write each logical cell row range. With pending recurrent rollback,
// the logical current state may live in a rollback snapshot plane.
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * s_size_row;
Expand All @@ -878,9 +926,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// Write GQA embedding size
io.write(&n_embd_s, sizeof(n_embd_s));

// For each row, we get the element values of each cell
// For each row, we get the element values of each logical cell
for (uint32_t j = 0; j < n_embd_s; ++j) {
// Write each range of cells of s_size_el length
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
Expand Down
3 changes: 3 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ llama_build_and_test(test-backend-sampler.cpp LABEL "model")
llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
set_tests_properties(test-state-restore-fragmented PROPERTIES FIXTURES_REQUIRED test-download-model)

llama_build_and_test(test-recurrent-state-rollback.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
set_tests_properties(test-recurrent-state-rollback PROPERTIES FIXTURES_REQUIRED test-download-model)

if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
Expand Down
185 changes: 185 additions & 0 deletions tests/test-recurrent-state-rollback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#include "arg.h"
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <clocale>
#include <cmath>
#include <cstdio>
#include <vector>

// Creates a context for `model` from the common parameters, overriding the
// fields this test depends on: a single sequence (n_seq_max = 1) and
// n_rs_seq = 8. The batch and micro-batch sizes are raised, if needed, to at
// least n_rs_seq + 1. Returns whatever llama_init_from_model() returns; the
// caller checks for nullptr.
static llama_context * make_ctx(const common_params & params, llama_model * model) {
    auto ctx_params = common_context_params_to_llama(params);

    ctx_params.n_seq_max = 1;
    ctx_params.n_rs_seq  = 8;

    // smallest batch size that still fits n_rs_seq + 1 tokens
    const uint32_t min_batch = (uint32_t) (ctx_params.n_rs_seq + 1);

    ctx_params.n_batch  = std::max(ctx_params.n_batch,  min_batch);
    ctx_params.n_ubatch = std::max(ctx_params.n_ubatch, min_batch);

    return llama_init_from_model(model, ctx_params);
}

// Decodes the first `count` entries of `tokens` as one batch on sequence 0,
// placing token i at position i. Every token is added with the last
// common_batch_add() flag set to false (contrast with decode_one(), which
// passes true).
//
// Returns true when llama_decode() returns 0. Returns false without decoding
// when `count` exceeds tokens.size(), since indexing would otherwise read past
// the end of the vector.
static bool decode_tokens(llama_context * ctx, const std::vector<llama_token> & tokens, uint32_t count) {
    if (count > tokens.size()) {
        return false;
    }

    llama_batch batch = llama_batch_init(count, 0, 1);
    for (uint32_t pos = 0; pos < count; ++pos) {
        common_batch_add(batch, tokens[pos], pos, { 0 }, false);
    }

    // free the batch before returning, whatever the decode result was
    const bool ok = llama_decode(ctx, batch) == 0;
    llama_batch_free(batch);
    return ok;
}

// Decodes the single token `tok` at position `pos` on sequence 0.
// Unlike decode_tokens(), the token is added with the last common_batch_add()
// flag set to true; the callers read logits right after this call, so the flag
// presumably requests logits output (confirm against common_batch_add).
// Returns true when llama_decode() returns 0.
static bool decode_one(llama_context * ctx, llama_token tok, llama_pos pos) {
    llama_batch single = llama_batch_init(1, 0, 1);
    common_batch_add(single, tok, pos, { 0 }, true);

    const auto status = llama_decode(ctx, single);

    // the batch is no longer needed once the decode call has returned
    llama_batch_free(single);

    return status == 0;
}

int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");

common_params params;
params.sampling.seed = 1234;
params.n_predict = 1;

common_init();

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}

ggml_backend_load_all();

common_init_result_ptr llama_init = common_init_from_params(params);
llama_model * model = llama_init->model();
if (model == nullptr) {
fprintf(stderr, "%s : failed to init model\n", __func__);
return 1;
}

if (!llama_model_is_recurrent(model) && !llama_model_is_hybrid(model)) {
fprintf(stderr, "%s : skipping for non-recurrent model\n", __func__);
return 0;
}

const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);

llama_context * ctx_src = make_ctx(params, model);
llama_context * ctx_dst = make_ctx(params, model);
if (ctx_src == nullptr || ctx_dst == nullptr) {
fprintf(stderr, "%s : failed to init contexts\n", __func__);
return 1;
}

if (llama_n_rs_seq(ctx_src) == 0) {
fprintf(stderr, "%s : skipping because n_rs_seq is disabled\n", __func__);
llama_free(ctx_src);
llama_free(ctx_dst);
return 0;
}

std::vector<llama_token> tokens = common_tokenize(ctx_src, "The quick brown fox jumps", true);
const uint32_t n_rs_seq = llama_n_rs_seq(ctx_src);
if (tokens.size() > n_rs_seq + 1) {
tokens.resize(n_rs_seq + 1);
}
if (tokens.size() < 2) {
fprintf(stderr, "%s : not enough prompt tokens\n", __func__);
return 1;
}
const uint32_t n_tokens = tokens.size();
const llama_token last_tok = tokens.back();
const llama_pos last_pos = (llama_pos) n_tokens - 2;

// Decode the full prompt on the source, then roll back the last position.
// Rollback leaves the recurrent memory in a snapshot state (rs_idx != 0).
if (!decode_tokens(ctx_src, tokens, n_tokens)) {
fprintf(stderr, "%s : failed to decode prompt\n", __func__);
return 1;
}
if (!llama_memory_seq_rm(llama_get_memory(ctx_src), 0, last_pos, -1)) {
fprintf(stderr, "%s : rollback failed\n", __func__);
return 1;
}

// Save the rolled-back state and restore it into a fresh context.
common_prompt_checkpoint ckpt;
ckpt.update_tgt(ctx_src, 0, 0);
ckpt.load_tgt(ctx_dst, 0, 0);

// Replay the rolled-back token on both contexts and compare logits.
if (!decode_one(ctx_src, last_tok, last_pos) ||
!decode_one(ctx_dst, last_tok, last_pos)) {
fprintf(stderr, "%s : replay failed\n", __func__);
return 1;
}

const float * logits_src = llama_get_logits_ith(ctx_src, 0);
const float * logits_dst = llama_get_logits_ith(ctx_dst, 0);
if (logits_src == nullptr || logits_dst == nullptr) {
fprintf(stderr, "%s : missing logits\n", __func__);
return 1;
}

constexpr float eps = 1e-5f;
for (int i = 0; i < n_vocab; ++i) {
if (std::fabs(logits_src[i] - logits_dst[i]) > eps) {
fprintf(stderr, "%s : logits mismatch at token %d (%g != %g)\n",
__func__, i, (double) logits_src[i], (double) logits_dst[i]);
return 1;
}
}

// Repeat the load into a context that already has its own rollback state:
// groups 1..n_rs_seq hold a *different* prompt's history, and rs_idx[0] is
// non-zero at load time. The restore must wipe that state and still match.
llama_context * ctx_dirty = make_ctx(params, model);
if (ctx_dirty == nullptr) {
fprintf(stderr, "%s : failed to init dirty ctx\n", __func__);
return 1;
}

std::vector<llama_token> noise = tokens;
for (auto & t : noise) {
t = (t + 1) % n_vocab;
if (t < 0) {
t = 0;
}
}
if (!decode_tokens(ctx_dirty, noise, n_tokens)) {
fprintf(stderr, "%s : dirty prompt decode failed\n", __func__);
return 1;
}
if (!llama_memory_seq_rm(llama_get_memory(ctx_dirty), 0, last_pos, -1)) {
fprintf(stderr, "%s : dirty rollback failed\n", __func__);
return 1;
}

ckpt.load_tgt(ctx_dirty, 0, 0);

if (!decode_one(ctx_dirty, last_tok, last_pos)) {
fprintf(stderr, "%s : dirty replay failed\n", __func__);
return 1;
}

const float * logits_dirty = llama_get_logits_ith(ctx_dirty, 0);
if (logits_dirty == nullptr) {
fprintf(stderr, "%s : missing dirty logits\n", __func__);
return 1;
}

for (int i = 0; i < n_vocab; ++i) {
if (std::fabs(logits_src[i] - logits_dirty[i]) > eps) {
fprintf(stderr, "%s : dirty-ctx logits mismatch at token %d (%g != %g)\n",
__func__, i, (double) logits_src[i], (double) logits_dirty[i]);
return 1;
}
}

fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__);
llama_free(ctx_src);
llama_free(ctx_dst);
llama_free(ctx_dirty);
return 0;
}
Loading