Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "speculative.h"
#include "unicode.h"

#include <algorithm>
Expand Down Expand Up @@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) :
cparams.n_samplers = pimpl->samplers_seq_config.size();
}

// [TAG_RS_STATE_ROLLBACK_SUPPORT]
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
// currently this is not supported. so we disable the partial rollback
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
auto & types = params.speculative.types;

for (int i = 0; i < (int) types.size(); i++) {
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
continue;
}
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
continue;
}

cparams.n_rs_seq = 0;

LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
common_speculative_type_to_str(types[i]).c_str());

break;
}
}

llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
Expand Down
2 changes: 1 addition & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ struct common_params_speculative_draft {
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding

float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f

common_params_model mparams;

Expand Down
7 changes: 6 additions & 1 deletion common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
Expand Down Expand Up @@ -1494,6 +1494,11 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u

GGML_ASSERT(impl);

// TODO: currently only the implementation that generated the draft is used to accept it
// however, some implementations (such as MTP) need to also "see" the accepted tokens
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
// inform the implementation if the accepted tokens are from another implementation and
// pass the accepted tokens to all remaining implementations using `is_other == true`
{
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
if (n_accepted > 0) {
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2531,7 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32(

constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];

#if 1
template<short NSG>
Expand All @@ -2549,6 +2549,7 @@ kernel void kernel_gated_delta_net_impl(
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K

const uint tx = tpitg.x;
const uint ty = tpitg.y;
Expand All @@ -2562,8 +2563,6 @@ kernel void kernel_gated_delta_net_impl(

const float scale = 1.0f / sqrt((float)S_v);

const uint K = FC_gated_delta_net_K;

// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
Expand Down Expand Up @@ -2666,6 +2665,7 @@ kernel void kernel_gated_delta_net_impl(

#undef S_v
#undef G
#undef K
}

typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
Expand Down
11 changes: 10 additions & 1 deletion src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,15 @@ size_t llama_memory_recurrent::size_s_bytes() const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

// [TAG_RS_STATE_ROLLBACK_SUPPORT]
if (n_rs_seq != 0) {
for (uint32_t i = 0; i < rs_idx.size(); ++i) {
if (rs_idx[i] != 0) {
GGML_ABORT("recurrent state read/write is not supported with partial rollback");
}
}
}

std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;

Expand All @@ -743,7 +752,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
cell_ranges.emplace_back(cell_range_begin, size);
}

if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) {
GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
}

Expand Down
Loading