Conversation

@ichbinhandsome commented Dec 14, 2025

As discussed in #15902, EAGLE3 represents the current state of the art in speculative decoding and is widely adopted across the industry. Integrating EAGLE3 into llama.cpp significantly improves inference performance, achieving a 2–3× speedup, and strengthens llama.cpp's competitiveness among leading inference frameworks.
This enhancement is the result of close collaboration between the NVIDIA and GGML teams.

The following provides a brief overview of this PR:

EAGLE3 is an encoder-decoder-based speculative decoding method:

  • Extracts features from the target model at specific layers
  • Uses a feature fusion layer to compress the target features
  • Generates draft tokens with a single-layer decoder
  • Maps the draft vocabulary to the target vocabulary via the d2t tensor (see the sketch after this list)

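To make the last bullet concrete, here is a minimal sketch of the d2t mapping. This is a hypothetical helper, not the PR's code, and it assumes (as is common in the released EAGLE3 implementations) that d2t is a 1-D tensor over the draft vocabulary storing the offset from each draft token id to its target token id:

#include <cstdint>
#include <vector>

// Hypothetical helper, for illustration only: map a token id from the draft
// model's reduced vocabulary to the target model's vocabulary.
// Assumption: d2t[draft_id] stores the offset target_id - draft_id.
static int32_t eagle3_draft_to_target(int32_t draft_id, const std::vector<int32_t> & d2t) {
    return draft_id + d2t[draft_id];
}
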
Key changes:

  • Add LLM_ARCH_EAGLE3 architecture
  • Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
  • Add feature extraction from target model layers
  • Add g_embeddings handling for decoder input
  • Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
  • Add --eagle3 flag for speculative-simple example
  • Add EAGLE3 model conversion in convert_hf_to_gguf.py

EAGLE3 Architecture Overview:

┌─────────────────────────────────────────────────────────────────┐
│                    EAGLE3 Overview                              │
└─────────────────────────────────────────────────────────────────┘

  Target Model          EAGLE3 Encoder         EAGLE3 Decoder
  (LLaMA 8B)              (FC Layer)           (1-layer Transformer)
       │                      │                       │
       │                      │                       │
       ▼                      ▼                       ▼
┌─────────────┐        ┌─────────────┐        ┌─────────────────┐
│  Generate   │        │  Compress   │        │  Generate Draft │
│  Features   │───────►│  Features   │───────►│  Tokens Fast    │
│  [12288]    │        │  [4096]     │        │  [k tokens]     │
└─────────────┘        └─────────────┘        └────────┬────────┘
                                                       │
                                                       ▼
                                              ┌─────────────────┐
                                              │  Verify Drafts  │
                                              │  with Target    │
                                              └─────────────────┘
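
As a rough illustration of the encoder box above: the 12288-wide feature vector is consistent with three 4096-wide hidden states captured from the target model (LLaMA 8B) and compressed by a single FC (fusion) layer. The following is a hypothetical ggml-style sketch with illustrative names, not the actual graph in src/models/eagle3.cpp:

#include "ggml.h"

// Hypothetical sketch: concatenate three captured hidden states and compress
// them with the fusion (FC) weight to produce g_embeddings for the decoder.
static ggml_tensor * eagle3_fuse_features(
        ggml_context * ctx,
        ggml_tensor  * h_low,    // [4096, n_tokens] hidden state from an early layer
        ggml_tensor  * h_mid,    // [4096, n_tokens] hidden state from a middle layer
        ggml_tensor  * h_high,   // [4096, n_tokens] hidden state from a late layer
        ggml_tensor  * fc_w) {   // [12288, 4096] fusion layer weight
    ggml_tensor * feats = ggml_concat(ctx, h_low, h_mid,  0);  // [ 8192, n_tokens]
    feats               = ggml_concat(ctx, feats, h_high, 0);  // [12288, n_tokens]
    return ggml_mul_mat(ctx, fc_w, feats);                     // [ 4096, n_tokens]
}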

How to run EAGLE3 in llama.cpp

Requirements

This PR currently supports only two EAGLE3 models: the EAGLE3 draft models for LLaMA3.1-Instruct-8B and LLaMA3.3-Instruct-70B (the pairs used in the evaluation below).

Step 1: Convert Models to GGUF Format

  • Convert Target Model
TARGET_MODEL_HF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct"
TARGET_MODEL_GGUF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct_bf16.gguf"

python convert_hf_to_gguf.py \
    "${TARGET_MODEL_HF}" \
    --outtype bf16 \
    --outfile "${TARGET_MODEL_GGUF}"

  • Convert EAGLE3 Draft Model
TARGET_MODEL_HF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct"
EAGLE3_MODEL_HF="${MODELS_DIR}/EAGLE3-LLaMA3.1-Instruct-8B"
EAGLE3_MODEL_GGUF="${MODELS_DIR}/EAGLE3-LLaMA3.1-Instruct-8B_fp16.gguf"

python convert_hf_to_gguf.py \
    "${EAGLE3_MODEL_HF}" \
    --outtype f16 \
    --target-model-dir "${TARGET_MODEL_HF}" \
    --outfile "${EAGLE3_MODEL_GGUF}"

Step 2: Compile llama.cpp

cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release

Step 3: Run EAGLE3 Speculative Decoding

for prompt in \
    "Write a quicksort algorithm in Python. Write code only." \
    "Explain the Pythagorean theorem" \
    "Plan a 1 day trip to DC"; do
  echo "=== Prompt: $prompt ==="
    ./build/bin/llama-speculative-simple \
      -m "${TARGET_MODEL_GGUF}" \
      -md "${EAGLE3_MODEL_GGUF}" \
      --eagle3 -p "$prompt" -n 256 --draft 8 \
      --temp 0 --top-k 1 --seed 42 -ngl 99 -ngld 99 
done

Performance Evaluation (RTX A6000 48GB)

Note: Using the chat_template for each model version can improve acceptance rates. Always apply the model’s corresponding chat_template when constructing prompts.

  • LLaMA3.1-Instruct-8B with BF16, its EAGLE3 with FP16

    | Prompt                                                   | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
    |----------------------------------------------------------|----------------------|-----------------------|-------------|---------|
    | Write a quicksort algorithm in Python. Write code only.  | 44.5 t/s             | 146.2 t/s             | 80.6%       | 3.28x   |
    | Explain the Pythagorean theorem                          | 44.5 t/s             | 126.8 t/s             | 77.7%       | 2.85x   |
    | Plan a 1 day trip to DC                                  | 44.5 t/s             | 111.8 t/s             | 78.4%       | 2.51x   |

  • LLaMA3.1-Instruct-8B with Q4_K_M, its EAGLE3 with Q4_K_M

    | Prompt                                                   | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
    |----------------------------------------------------------|----------------------|-----------------------|-------------|---------|
    | Write a quicksort algorithm in Python. Write code only.  | 121.5 t/s            | 260.5 t/s             | 83.6%       | 2.14x   |
    | Explain the Pythagorean theorem                          | 121.4 t/s            | 232.4 t/s             | 78.6%       | 1.91x   |
    | Plan a 1 day trip to DC                                  | 121.4 t/s            | 186.8 t/s             | 71.5%       | 1.54x   |

  • LLaMA3.3-Instruct-70B with Q4_K_M, its EAGLE3 with Q4_K_M

    | Prompt                                                   | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
    |----------------------------------------------------------|----------------------|-----------------------|-------------|---------|
    | Write a quicksort algorithm in Python. Write code only.  | 15.6 t/s             | 31.7 t/s              | 67.7%       | 2.03x   |
    | Explain the Pythagorean theorem                          | 15.6 t/s             | 37.1 t/s              | 80.8%       | 2.38x   |
    | Plan a 1 day trip to DC                                  | 15.6 t/s             | 29.9 t/s              | 73.5%       | 1.92x   |
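
For reference, the Speedup column appears to be the plain throughput ratio of the EAGLE3 run over the llama-cli baseline, e.g. for the first BF16 row:

    speedup = 146.2 t/s / 44.5 t/s ≈ 3.3x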

Details of GGML backend modifications

In the Eagle3 decoder, two parallel inputs are processed:

input_embeds ──→ RMS_NORM ──┐
                            ├──→ CONCAT ──→ Transformer Decoder
g_embeddings ──→ RMS_NORM ──┘

When both RMS_NORM operations run in the same GPU split, a lack of synchronization causes buffer contention and race conditions (CPU execution is fine as it auto‑syncs between subgraphs).

Solution:
Use ggml_set_sync() to add a synchronization point after the first RMS_NORM, forcing the scheduler to create a split boundary and synchronize before continuing.

input_embeds ──→ RMS_NORM ──→ [SYNC] ──┐
                                       ├──→ CONCAT ──→ Transformer Decoder
g_embeddings ─────────────→ RMS_NORM ──┘
         (split 1)            |         (split 2)
                           barrier

This ensures correct execution and can be applied to any parallel path that needs synchronization, not just Eagle3.
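
For context, here is a minimal sketch of the pattern above, using existing ggml graph calls plus the ggml_set_sync() / GGML_TENSOR_FLAG_SYNC mechanism added by this PR. Variable names are illustrative and this is not the exact code in src/models/eagle3.cpp:

#include "ggml.h"

// Hypothetical sketch: build the two parallel decoder input paths and force a
// split boundary after the first RMS_NORM so its result is synchronized before
// the concat consumes both paths on the GPU.
static ggml_tensor * eagle3_decoder_input(
        ggml_context * ctx,
        ggml_tensor  * input_embeds,   // token embeddings of the drafted sequence
        ggml_tensor  * g_embeddings,   // fused target features from the encoder
        float          rms_eps) {
    ggml_tensor * a = ggml_rms_norm(ctx, input_embeds, rms_eps);
    ggml_tensor * b = ggml_rms_norm(ctx, g_embeddings, rms_eps);

    ggml_set_sync(a);  // barrier: split 1 must finish before split 2 reads `a`

    return ggml_concat(ctx, a, b, 0);
}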

Example results

  • Prompt: "Write a quicksort algorithm in Python. Write code only."
image
  • Prompt: "Explain the Pythagorean theorem"
image
  • Prompt: "Plan a 1 day trip to DC"
image

Future Steps

  • Support more Eagle3 models
  • Currently, EAGLE3 is integrated only into llama-speculative-simple; support may need to be extended to other examples/APIs where possible
  • Support context-dependent tree sampling (tree attention) as described in the EAGLE3 paper to improve the acceptance rate
  • Support batch processing (batch size > 1) with Eagle3 speculative decoding

Comment on lines +65 to +68

// Force a sync point between the two parallel RMS_NORM paths
// This prevents buffer reuse issues on GPU (EAGLE3 GPU fix)
ggml_set_sync(input_embeds_normed);
Member

It is very strange that you need to do this explicitly.

The ggml_concat operator (like every other ggml op) tracks the input tensors on which it depends, so it should not be possible for a buffer to be reused while the data in it is still pending a computation.

I think this sync should not be necessary, and if removing it causes some data corruption, the cause is something else that we should investigate in detail.

Can you confirm that removing this call still causes problems?

Author
@ichbinhandsome Dec 15, 2025

I just revalidated this: without calling ggml_set_sync, the buffer data gets overwritten, causing the acceptance rate to drop to nearly 3-4%. This issue only occurs on the GPU side; when running the draft model on the CPU, the acceptance rate remains stable and ggml_set_sync is not required.

The result buffers of the two RMS_NORM operations appear to conflict, with one being overwritten by invalid (garbage) values. ggml_set_sync is used to enforce synchronization between the two RMS_NORM operations on the GPU side.

Author
@ichbinhandsome Dec 15, 2025

I also tried using ggml_set_output on the two RMS_NORM results to avoid the buffer overwriting. However, once I set it, the buffer for the concatenated result got overwritten. I then tried setting that as well, but the subsequent Q, K, and V attention result buffers were still being overwritten. It seems there is an issue with buffer allocation in the scheduler when handling parallel inputs on the GPU, so I came up with this method to resolve the issue.

Member

Ok, I am able to reproduce the issue. Looking into this.

@ngxson (Collaborator) commented Dec 15, 2025

Judging by the description of this PR, I believe many models with multi-token prediction also use the same strategy of reusing hidden features from the main model.

It could be quite interesting to generalize this feature to support other models. I would expect some kind of sub-llama_context that allows both the main and draft models to share the same cgraph, avoiding the need to explicitly pass the intermediate embeddings through host memory.

@ggerganov (Member)

> It could be quite interesting to generalize this feature to support other models.

I will definitely be looking at refactoring the implementation to become more generic before merging it. The initial performance results are really great, but we'll need to work on cleaning up the code and reducing the special-casing in several places. I'll try to provide insights on how to do that in the coming days.

