@@ -9,6 +9,7 @@
 #include "llama-model.h"
 
 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <limits>
 #include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context( |
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+            float mscale = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
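
For reference, a minimal standalone sketch of the computation in the hunk above. The inputs are illustrative, not read from any model: a hypothetical 40x context extension (`rope_freq_scale = 1/40`) and `mscale == mscale_all_dims == 0.707f`, in the spirit of the DeepSeek-V2 config linked in the patch. When the two mscale values match (as in the DEEPSEEK2 special case), the ratio collapses to 1.0f and only the cancellation term remains:

```cpp
// sketch only: mirrors the patch logic outside of llama_context
#include <cmath>
#include <cstdio>

// same correction as the lambda in the patch: no-op (1.0f) when the context
// is not extended (scale <= 1), logarithmic attention scaling otherwise
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
}

int main() {
    const float rope_freq_scale = 1.0f / 40.0f; // hypothetical: 40x context extension
    const float factor          = 1.0f / rope_freq_scale;
    const float mscale          = 0.707f;       // hypothetical config value
    const float mscale_all_dims = 0.707f;       // hypothetical config value

    // ratio of the two corrections, as in the patch
    float yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

    // cancel the 1.0f + 0.1f*logf(1.0f/freq_scale) factor that the ggml rope
    // op applies internally (see the ops.cpp link in the patch)
    yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));

    // with matching mscale values the ratio is 1.0, so this prints ~0.7305
    printf("yarn_attn_factor = %.4f\n", yarn_attn_factor);
    return 0;
}
```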
|