111 changes: 76 additions & 35 deletions convert_hf_to_gguf.py
@@ -862,6 +862,14 @@ def set_gguf_parameters(self):
logger.warning(f"Unknown RoPE type: {rope_type}")
logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")

if "mrope_section" in self.rope_parameters:
mrope_section = self.rope_parameters["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
logger.info(f"gguf: mrope sections: {mrope_section[:4]}")

if (rope_theta := rope_params.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
@@ -3739,9 +3747,6 @@ class Qwen2VLModel(TextModel):

def set_gguf_parameters(self):
super().set_gguf_parameters()
mrope_section = self.hparams["rope_scaling"]["mrope_section"]
mrope_section += [0] * max(0, 4 - len(mrope_section))
self.gguf_writer.add_rope_dimension_sections(mrope_section)

def set_vocab(self):
try:
Expand Down Expand Up @@ -4377,6 +4382,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
class Glm4VVisionModel(Qwen3VLVisionModel):
def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)

hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
if hidden_act == "gelu":
self.gguf_writer.add_vision_use_gelu(True)
elif hidden_act == "silu":
self.gguf_writer.add_vision_use_silu(True)

rms_norm_eps = self.hparams_vision.get("rms_norm_eps", 1e-5)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.")
if name.startswith("visual.merger."):
return [(self.map_tensor_name(name), data_torch)]
return super().modify_tensors(data_torch, name, bid)
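A short sketch of how the renaming above lines up with the new glm4v entries added to tensor_mapping.py below (the checkpoint tensor name is hypothetical):

# Sketch only: routing of a GLM4V merger tensor through modify_tensors.
name = "model.visual.merger.gate_proj.weight"     # hypothetical HF checkpoint name
name = name.replace("model.visual.", "visual.")   # -> "visual.merger.gate_proj.weight"
# "visual.merger.gate_proj" is now listed under MODEL_TENSOR.V_MM_GATE in
# tensor_mapping.py, so map_tensor_name() resolves merger tensors directly;
# everything else falls through to the Qwen3VLVisionModel handling.
assert name.startswith("visual.merger.")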


@ModelBase.register("Qwen3VLForConditionalGeneration")
class Qwen3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3VL
@@ -4385,20 +4414,6 @@ def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -4417,22 +4432,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -7795,6 +7794,15 @@ def prepare_tensors(self):
@ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
class Glm4Model(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4
use_mrope = False
partial_rotary_factor = 0.5

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 0.5)
if "mrope_section" in self.rope_parameters:
self.use_mrope = True
logger.info("Q/K weight will need to be permuted for M-RoPE")

def set_vocab(self):
from transformers import AutoTokenizer
@@ -7816,17 +7824,49 @@ def set_gguf_parameters(self):
super().set_gguf_parameters()
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.partial_rotary_factor))

@staticmethod
def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, partial_rotary_factor: float) -> Tensor:
orig_shape = weights.shape
if len(orig_shape) == 1:
weights = weights.unsqueeze(1) # [out_dim, 1]
if len(weights.shape) != 2:
raise ValueError("Only 1D and 2D tensors are supported.")
n_effective_heads = weights.shape[0] // head_dim
if n_head_kv is not None and n_effective_heads != n_head:
if n_effective_heads != n_head_kv:
raise AssertionError(f"Mismatch in effective heads: computed {n_effective_heads}, expected {n_head} or {n_head_kv}")
rotary_dim = int(head_dim * partial_rotary_factor)
if rotary_dim % 2 != 0:
raise ValueError("rotary_dim must be even.")
reshaped = weights.reshape(n_effective_heads, head_dim, -1)
rot_part = reshaped[:, :rotary_dim, :]
non_rot_part = reshaped[:, rotary_dim:, :]
permuted_rot = torch.cat((rot_part[:, ::2, :], rot_part[:, 1::2, :]), dim=1)
combined = torch.cat((permuted_rot, non_rot_part), dim=1)
result = combined.reshape(weights.shape)
return result if len(orig_shape) != 1 else result.squeeze(1)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part of Glm4v
return []
elif name.startswith("model.language_model."):
name = name.replace("language_model.", "") # for Glm4v
if self.use_mrope:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams["num_key_value_heads"]
n_embd = self.hparams["hidden_size"]
head_dim = n_embd // n_head
# because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor)
return super().modify_tensors(data_torch, name, bid)
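The comment above notes that llama.cpp's M-RoPE kernel only supports Neox ordering; the stand-alone sketch below (toy sizes, a single head with head_dim = 8 and partial_rotary_factor = 0.5, chosen purely for illustration) shows the row reordering that normal_to_neox applies to the rotary slice of each head, while the non-rotary rows stay in place:

# Illustrative sketch only, not part of the conversion script.
import torch

head_dim = 8
rotary_dim = int(head_dim * 0.5)            # first 4 rows of the head carry RoPE
w = torch.arange(head_dim).unsqueeze(1)     # one head, one output column: rows 0..7

rot, non_rot = w[:rotary_dim], w[rotary_dim:]
neox_rot = torch.cat((rot[::2], rot[1::2]), dim=0)   # interleaved -> evens first, then odds
neox = torch.cat((neox_rot, non_rot), dim=0)

print(w.squeeze(1).tolist())      # [0, 1, 2, 3, 4, 5, 6, 7]  (normal / interleaved order)
print(neox.squeeze(1).tolist())   # [0, 2, 1, 3, 4, 5, 6, 7]  (Neox order; rows 4..7 untouched)

Applied per head to the q_proj/k_proj rows, this matches the Neox ordering the M-RoPE kernel expects, which is why the permutation only happens when use_mrope is set.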


@ModelBase.register("Glm4MoeForCausalLM")
@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE

@@ -7893,6 +7933,7 @@ def set_gguf_parameters(self):

_experts: list[dict[str, Tensor]] | None = None

# note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already
def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
7 changes: 7 additions & 0 deletions gguf-py/gguf/constants.py
@@ -642,6 +642,7 @@ class MODEL_TENSOR(IntEnum):
V_MMPROJ_PEG = auto()
V_ENC_EMBD_CLS = auto()
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_NORM = auto()
V_ENC_EMBD_POS = auto()
V_ENC_INPUT_NORM = auto()
V_ENC_ATTN_QKV = auto()
@@ -660,6 +661,7 @@ class MODEL_TENSOR(IntEnum):
V_LAYER_SCALE_2 = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
V_MM_POST_NORM = auto()
V_MM_INP_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3
@@ -1014,6 +1016,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
@@ -1032,6 +1035,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
@@ -1092,6 +1096,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MMPROJ_PEG,
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_NORM,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_ATTN_QKV,
@@ -1110,6 +1115,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_LAYER_SCALE_2,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ,
MODEL_TENSOR.V_MM_INP_NORM,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
@@ -3328,6 +3334,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
GLM4V = "glm4v"


# Items here are (block size, type size)
15 changes: 15 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1207,6 +1207,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
"model.vision.linear_proj.linear_proj", # cogvlm
"visual.merger.proj", # glm4v
),

MODEL_TENSOR.V_MMPROJ_MLP: (
@@ -1240,6 +1241,10 @@ class TensorNameMap:
"model.vision.patch_embedding.proj", # cogvlm
),

MODEL_TENSOR.V_ENC_EMBD_NORM: (
"visual.post_conv_layernorm", # glm4v
),

MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"model.vision_tower.embeddings.position_embeddings", # Intern-S1
@@ -1249,6 +1254,7 @@ class TensorNameMap:
"vision_tower.patch_embed.pos_emb", # kimi-vl
"visual.pos_embed", # qwen3vl
"model.vision.patch_embedding.position_embedding", # cogvlm
"visual.embeddings.position_embedding", # glm4v
),

MODEL_TENSOR.V_ENC_ATTN_QKV: (
@@ -1404,6 +1410,11 @@ class TensorNameMap:
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
"visual.post_layernorm", # glm4v
),

MODEL_TENSOR.V_MM_POST_NORM: (
"visual.merger.post_projection_norm", # glm4v
),

MODEL_TENSOR.V_MM_INP_PROJ: (
@@ -1473,6 +1484,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf
"patch_merger.merging_layer", # mistral
"visual.downsample", # glm4v
),

MODEL_TENSOR.V_DS_NORM: (
@@ -1493,14 +1505,17 @@ class TensorNameMap:

MODEL_TENSOR.V_MM_UP: (
"model.vision.linear_proj.dense_h_to_4h", # cogvlm
"visual.merger.up_proj", # glm4v
),

MODEL_TENSOR.V_MM_DOWN: (
"model.vision.linear_proj.dense_4h_to_h", # cogvlm
"visual.merger.down_proj", # glm4v
),

MODEL_TENSOR.V_MM_GATE: (
"model.vision.linear_proj.gate_proj", # cogvlm
"visual.merger.gate_proj", # glm4v
),

MODEL_TENSOR.V_TOK_BOI: (
4 changes: 4 additions & 0 deletions src/llama-hparams.cpp
@@ -230,3 +230,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

return false;
}

bool llama_hparams::use_mrope() const {
return rope_sections[0] > 0 && rope_sections[1] > 0;
}
2 changes: 2 additions & 0 deletions src/llama-hparams.h
@@ -270,6 +270,8 @@ struct llama_hparams {
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);

bool use_mrope() const;
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
15 changes: 10 additions & 5 deletions src/llama-model.cpp
@@ -1688,7 +1688,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GLM4:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
switch (hparams.n_layer) {
case 40: type = LLM_TYPE_9B; break;
case 61: type = LLM_TYPE_32B; break;
@@ -1697,8 +1698,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GLM4_MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);

// MoE parameters
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
@@ -7761,7 +7763,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GLM4:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_GRANITE_HYBRID:
@@ -7823,7 +7824,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_SMALLTHINKER:
case LLM_ARCH_GLM4_MOE:
case LLM_ARCH_SEED_OSS:
case LLM_ARCH_GROVEMOE:
case LLM_ARCH_APERTUS:
@@ … @@
case LLM_ARCH_QWEN3VLMOE:
return LLAMA_ROPE_TYPE_IMROPE;

case LLM_ARCH_GLM4:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM;
case LLM_ARCH_GLM4_MOE:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
Review comment from the PR author on lines +7843 to +7846:
Because the 2 models (vision and non-vision) are mostly the same except for the rope mode, I was quite lazy and didn't duplicate it into a new arch (which would involve quite a lot of copy-paste code)

I hope we can eventually de-duplicate some of this code via #18051

In the meantime, lmk if you're OK with keeping this hack or if a new arch is still preferable @ggerganov @CISC


// all model arches should be listed explicitly here
case LLM_ARCH_UNKNOWN:
GGML_ABORT("unknown architecture");
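Since the arch enums stay shared, the effective rope type for GLM4 and GLM4_MOE now depends only on whether rope_dimension_sections were written at conversion time, i.e. whether the HF config carried an mrope_section. A small Python restatement of the selection logic above, for clarity only (not code from this PR):

# Mirrors the new llama_model_rope_type() behaviour for GLM4 / GLM4_MOE.
def glm4_rope_type(rope_sections: list[int], is_moe: bool) -> str:
    # llama_hparams::use_mrope(): the time and height sections must both be non-zero
    use_mrope = len(rope_sections) >= 2 and rope_sections[0] > 0 and rope_sections[1] > 0
    if use_mrope:
        return "MROPE"
    return "NEOX" if is_moe else "NORM"

assert glm4_rope_type([16, 24, 24, 0], is_moe=False) == "MROPE"  # hypothetical Glm4v sections
assert glm4_rope_type([], is_moe=True) == "NEOX"                 # text-only GLM4_MOE, unchanged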