7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -3124,6 +3124,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));
add_opt(common_arg(
{"--eagle3"},
"use EAGLE3 speculative decoding with the draft model",
[](common_params & params) {
params.speculative.eagle3 = true;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
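For orientation (not part of this diff): the new flag only sets params.speculative.eagle3. A caller in one of the speculative examples could then use it to pick the EAGLE3 initializer declared later in this change; the context names ctx_enc, ctx_dec and ctx_dft below are placeholders, not code from this PR.

struct common_speculative * spec = params.speculative.eagle3
    ? common_speculative_init_eagle3(ctx_tgt, ctx_enc, ctx_dec) // assumed EAGLE3 encoder/decoder contexts
    : common_speculative_init(ctx_tgt, ctx_dft);                // existing draft-model path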
2 changes: 2 additions & 0 deletions common/common.h
@@ -242,6 +242,8 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)

bool eagle3 = false; // use EAGLE3 speculative decoding
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

199 changes: 199 additions & 0 deletions common/speculative.cpp
@@ -22,6 +22,11 @@ struct common_speculative {
llama_tokens prompt_dft;
bool vocab_dft_compatible = true; // whether retokenization is needed
std::map<std::string, std::string> tgt_dft_replacements = {};

// EAGLE3 specific
struct llama_context * eagle3_encoder = nullptr;
struct llama_context * eagle3_decoder = nullptr;
int32_t eagle3_n_past = 0; // number of verified positions in decoder KV cache
};

struct common_speculative * common_speculative_init(
@@ -74,13 +79,50 @@ struct common_speculative * common_speculative_init(
return result;
}

struct common_speculative * common_speculative_init_eagle3(
struct llama_context * ctx_tgt,
struct llama_context * ctx_encoder,
struct llama_context * ctx_decoder) {

auto * result = new common_speculative {
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ nullptr, // Not used for EAGLE3
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_decoder), 0, 1),
/* .prompt_dft = */ {},
/* .vocab_dft_compatible = */ true, // EAGLE3 uses same vocab
/* .tgt_dft_replacements = */ {},
/* .eagle3_encoder = */ ctx_encoder,
/* .eagle3_decoder = */ ctx_decoder,
};

// Initialize sampler for EAGLE3 decoder
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10; // set to 1 for greedy sampling (argmax) to match vLLM's default behavior; top_k > 1 always gives a higher acceptance rate for EAGLE3
params.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
result->smpl = common_sampler_init(llama_get_model(ctx_decoder), params);
}

return result;
}
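
For illustration only (not part of this diff): a minimal sketch of how the two EAGLE3 contexts might be created and handed to common_speculative_init_eagle3. The assumption that encoder and decoder are two contexts over the same draft model file, the file name, and the embeddings flag are all guesses; the actual wiring lives outside the lines shown here.

// sketch under assumptions, not taken from this PR
llama_model_params mparams = llama_model_default_params();
llama_model * model_dft = llama_model_load_from_file("eagle3-draft.gguf", mparams);

llama_context_params cparams = llama_context_default_params();
cparams.embeddings = true; // assumption: encoder/decoder embeddings are read via llama_get_embeddings()

llama_context * ctx_enc = llama_init_from_model(model_dft, cparams);
llama_context * ctx_dec = llama_init_from_model(model_dft, cparams);

struct common_speculative * spec = common_speculative_init_eagle3(ctx_tgt, ctx_enc, ctx_dec);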

void common_speculative_free(struct common_speculative * spec) {
if (spec == nullptr) {
return;
}

common_sampler_free(spec->smpl);

// EAGLE3 cleanup
if (spec->eagle3_encoder) {
llama_free(spec->eagle3_encoder);
}
if (spec->eagle3_decoder) {
llama_free(spec->eagle3_decoder);
}

llama_batch_free(spec->batch);

delete spec;
@@ -181,12 +223,169 @@ static std::string replace_to_tgt(
return result;
}

// EAGLE3 Draft Generation with KV Cache Reuse
//
// ============================================================================
// EXAMPLE: Two rounds of speculative decoding
// ============================================================================
//
// ROUND 1 (Initial):
// Prompt: [t0, t1, t2, t3, t4], target generates t5
// prompt_tgt = [t0, t1, t2, t3, t4], id_last = t5 (GENERATED)
// n = 5, n_past = 0, n_new = 5
//
// Step 1: Encoder
// features: [f0, f1, f2, f3, f4] → g_embeddings: [g0, g1, g2, g3, g4]
//
// Step 2: Decoder batch (positions 0-4)
// tokens: [t1, t2, t3, t4, t5] ← prompt[1:] + id_last
// g_embd: [g0, g1, g2, g3, g4]
// positions: [0, 1, 2, 3, 4 ]
// → KV cache: [0, 1, 2, 3, 4]
// → sample d1 from logits[4]
//
// Step 3: Autoregressive (positions 5, 6, ...)
// pos 5: token=d1, g_embd=prenorm[4] → KV cache: [0,1,2,3,4,5] → d2
// pos 6: token=d2, g_embd=prenorm → KV cache: [0,1,2,3,4,5,6] → d3
//
// Output: [d1, d2, d3]
// Update: n_past = 5 (verified positions from batch decode)
//
// ROUND 2 (assuming d1 accepted, d2/d3 rejected):
// prompt_tgt = [t0, t1, t2, t3, t4, t5, d1], id_last = t6 (new target output)
// n = 7, n_past = 5, n_new = 2
//
// Step 1: Clear KV cache [5, inf) - remove draft positions
// KV cache: [0, 1, 2, 3, 4] (reuse from round 1!)
//
// Step 2: Encoder (only new tokens)
// features: [f5, f6] → g_embeddings: [g5, g6]
//
// Step 3: Decoder batch (only new positions 5-6)
// tokens: [d1, t6] (prompt_tgt[6], id_last)
// g_embd: [g5, g6]
// positions: [5, 6 ]
// → KV cache: [0,1,2,3,4] + [5,6] = [0,1,2,3,4,5,6]
// → sample d1' from logits[1] (last position in batch)
//
// Step 4: Autoregressive...
//
// ============================================================================
//
// Key insight: Decoder KV cache stores K/V computed from (tok_embd + g_embd).
// For verified positions, both tok_embd and g_embd are fixed (encoder output),
// so KV cache can be reused. Draft positions use prenorm as g_embd, which
// differs from encoder output, so they must be cleared and recomputed.
//
static llama_tokens gen_eagle3_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
llama_token id_last) {

auto * ctx_tgt = spec->ctx_tgt;
auto * ctx_encoder = spec->eagle3_encoder;
auto * ctx_decoder = spec->eagle3_decoder;
auto * smpl = spec->smpl;
auto & batch = spec->batch;

const int n_embd = llama_model_n_embd(llama_get_model(ctx_encoder));
const int n = (int)prompt_tgt.size();
const int n_new = n - spec->eagle3_n_past;

GGML_ASSERT(n >= 1 && "prompt_tgt is empty");
GGML_ASSERT(n_new >= 1 && "must have at least 1 new token");

// Clear draft positions from decoder KV cache [n_past, inf)
llama_memory_seq_rm(llama_get_memory(ctx_decoder), 0, spec->eagle3_n_past, -1);

// Encoder: features → g_embeddings
const float * features = llama_get_eagle3_target_features(ctx_tgt);
GGML_ASSERT(features && "no target features");

llama_batch enc_batch = {
/*.n_tokens =*/ n_new,
/*.token =*/ nullptr,
/*.embd =*/ const_cast<float*>(features),
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
GGML_ASSERT(llama_encode(ctx_encoder, enc_batch) == 0);

const float * g_embd = llama_get_embeddings(ctx_encoder);
GGML_ASSERT(g_embd && "encoder output failed");

// Decoder batch: process new tokens with KV cache reuse
llama_set_eagle3_g_embeddings(ctx_decoder, g_embd, n_embd, n_new);

common_batch_clear(batch);
for (int i = 0; i < n_new; i++) {
const int pos = spec->eagle3_n_past + i;
const llama_token tok = (pos < n - 1) ? prompt_tgt[pos + 1] : id_last;
common_batch_add(batch, tok, pos, {0}, true);
}

GGML_ASSERT(llama_decode(ctx_decoder, batch) == 0);

spec->eagle3_n_past = n; // update verified positions

// Sample draft tokens
llama_tokens result;
common_sampler_reset(smpl);

// Sample and check probability (consistent with standard speculative decoding)
auto sample_and_check = [&](int idx) -> bool {
common_sampler_sample(smpl, ctx_decoder, idx);

const auto * cur_p = common_sampler_get_candidates(smpl, true);
const llama_token id = cur_p->data[0].id;

common_sampler_accept(smpl, id, true);
result.push_back(id);

return cur_p->data[0].p >= params.p_min;
};

// First draft token from batch decode
if (!sample_and_check(n_new - 1)) {
return result;
}

// Autoregressive: use prenorm as g_embd (-1 = last output)
const float * prenorm = llama_get_embeddings_ith(ctx_decoder, -1);

for (int i = 1; i < params.n_draft; i++) {
GGML_ASSERT(prenorm && "prenorm failed");
llama_set_eagle3_g_embeddings(ctx_decoder, prenorm, n_embd, 1);

common_batch_clear(batch);
common_batch_add(batch, result.back(), n - 1 + i, {0}, true);
GGML_ASSERT(llama_decode(ctx_decoder, batch) == 0);

prenorm = llama_get_embeddings_ith(ctx_decoder, -1);

if (!sample_and_check(0)) {
break;
}
}

return result;
}

llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last) {

// EAGLE3 path
if (spec->eagle3_encoder && spec->eagle3_decoder) {
return gen_eagle3_draft(spec, params, prompt_tgt_main_model, id_last);
}

// Standard draft model path
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
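For illustration only (not part of this diff): regardless of which path common_speculative_gen_draft takes, the caller consumes the result the same way as in the existing speculative examples: decode id_last plus the draft on the target model in one batch, then accept draft tokens for as long as they match the target's own sampling. The n_past counter and the batch sizing below are assumptions.

llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

llama_batch batch_tgt = llama_batch_init(1 + (int) draft.size(), 0, 1);
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
for (size_t i = 0; i < draft.size(); i++) {
    common_batch_add(batch_tgt, draft[i], n_past + 1 + (int) i, { 0 }, true);
}
llama_decode(ctx_tgt, batch_tgt); // target verifies the whole draft in a single pass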
7 changes: 7 additions & 0 deletions common/speculative.h
@@ -17,6 +17,13 @@ struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft
);

// EAGLE3: Initialize speculative decoding with EAGLE3 encoder and decoder contexts
struct common_speculative * common_speculative_init_eagle3(
struct llama_context * ctx_tgt,
struct llama_context * ctx_encoder,
struct llama_context * ctx_decoder
);

void common_speculative_free(struct common_speculative * spec);

bool common_speculative_are_compatible(