
Commit 8ea958d

ngxson and tdakhran authored
model : add ASR support for LFM2-Audio-1.5B (conformer) (#18106)
* ASR with LFM2-Audio-1.5B
* Set rope_theta
* Fix comment
* Remove rope_theta setting
* Address PR feedback
* rename functions to conformer
* remove some redundant ggml_cont
* fix missing tensor
* add prefix "a." for conv tensors
* remove redundant reshape
* clean up
* add test model

---------

Co-authored-by: Tarek Dakhran <[email protected]>
1 parent f9ec885 commit 8ea958d

17 files changed, +669 −29 lines

common/arg.cpp

Lines changed: 1 addition & 1 deletion

@@ -1196,7 +1196,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.system_prompt = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION}));
+    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD}));
     add_opt(common_arg(
         {"--perf"},
         {"--no-perf"},

convert_hf_to_gguf.py

Lines changed: 88 additions & 4 deletions

@@ -712,6 +712,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool):
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
+        if "lfm" in config:
+            # rename for LFM2-Audio
+            config["text_config"] = config["lfm"]
         return config
 
     @classmethod
@@ -9713,19 +9716,25 @@ def set_gguf_parameters(self):
         self._add_feed_forward_length()
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
-        if is_vision_tensor:
-            # skip vision tensors
+        if self._is_vision_tensor(name) or self._is_audio_tensor(name):
+            # skip multimodal tensors
             return []
 
-        name = name.replace("language_model.", "")
+        name = name.replace("language_model.", "") # vision
+        name = name.replace("lfm.", "model.") # audio
 
         # conv op requires 2d tensor
         if 'conv.conv' in name:
             data_torch = data_torch.squeeze(1)
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    def _is_vision_tensor(self, name: str) -> bool:
+        return "vision_tower" in name or "multi_modal_projector" in name
+
+    def _is_audio_tensor(self, name: str):
+        return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])
+
 
 @ModelBase.register("Lfm2MoeForCausalLM")
 class LFM2MoeModel(TextModel):
@@ -9831,6 +9840,81 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [] # skip other tensors
 
 
+@ModelBase.register("Lfm2AudioForConditionalGeneration")
+class LFM2AudioModel(MmprojModel):
+    has_vision_encoder = False
+    has_audio_encoder = True
+    model_name = "Lfm2AudioEncoder"
+
+    _batch_norm_tensors: list[dict[str, Tensor]] | None = None
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("encoder")
+
+    def set_gguf_parameters(self):
+        assert self.hparams_audio is not None
+        self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
+        self.hparams_audio["intermediate_size"] = self.hparams_audio["d_model"]
+        self.hparams_audio["num_attention_heads"] = self.hparams_audio["n_heads"]
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2A)
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # skip language model tensors
+        if name.startswith("lfm."):
+            return []
+
+        # for training only
+        if any(p in name for p in ["audio_loss_weight"]):
+            return []
+
+        # for audio output
+        if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]):
+            return []
+
+        # fold running_mean, running_var and eps into weight and bias for batch_norm
+        if "batch_norm" in name:
+            if self._batch_norm_tensors is None:
+                self._batch_norm_tensors = [{} for _ in range(self.block_count)]
+            assert bid is not None
+            self._batch_norm_tensors[bid][name] = data_torch
+
+            if len(self._batch_norm_tensors[bid]) < 5:
+                return []
+
+            weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"]
+            bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"]
+            running_mean = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_mean"]
+            running_var = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_var"]
+            eps = 1e-5 # default value
+
+            a = weight / torch.sqrt(running_var + eps)
+            b = bias - running_mean * a
+            return [
+                (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a),
+                (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b),
+            ]
+
+        # reshape conv weights
+        if name.startswith("conformer.pre_encode.conv.") and name.endswith(".bias"):
+            data_torch = data_torch[:, None, None]
+        if "conv.depthwise_conv" in name and name.endswith(".weight"):
+            assert data_torch.shape[1] == 1
+            data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
+        if "conv.pointwise_conv" in name and name.endswith(".weight"):
+            assert data_torch.shape[2] == 1
+            data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("SmallThinkerForCausalLM")
 class SmallThinkerModel(TextModel):
     model_arch = gguf.MODEL_ARCH.SMALLTHINKER
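
Note: the batch-norm fold above relies on eval-mode batch norm being a per-channel affine map: y = (x − running_mean) / sqrt(running_var + eps) · weight + bias, which collapses to a·x + b with a = weight / sqrt(running_var + eps) and b = bias − running_mean · a, so once all five batch-norm tensors for a block have been collected, only the two folded tensors need to be written to the GGUF. A minimal standalone check of that identity (scalar, hypothetical values, not taken from the model):

// Sketch: verify that the (a, b) fold used in LFM2AudioModel.modify_tensors
// reproduces eval-mode batch norm for one channel. Values are hypothetical.
#include <cmath>
#include <cstdio>

int main() {
    const float eps          = 1e-5f; // same default the converter assumes
    const float weight       = 0.8f;
    const float bias         = 0.1f;
    const float running_mean = 0.3f;
    const float running_var  = 2.0f;

    const float a = weight / std::sqrt(running_var + eps); // folded scale
    const float b = bias - running_mean * a;               // folded shift

    const float x    = 1.5f;
    const float bn   = (x - running_mean) / std::sqrt(running_var + eps) * weight + bias;
    const float fold = a * x + b;
    std::printf("bn = %.6f, folded = %.6f\n", bn, fold); // prints equal values
    return 0;
}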

ggml/src/ggml-cuda/ssm-conv.cu

Lines changed: 14 additions & 20 deletions

@@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);
 
-    if (n_t <= 32) {
-        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-        if (nc == 4) {
-            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
-            ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+    auto launch_kernel = [&](auto NC) {
+        constexpr int kNC = decltype(NC)::value;
+        if (n_t <= 32) {
+            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+            ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
-        }
-    } else {
-        if (nc == 4) {
-            const int64_t split_n_t = 32;
-            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
-                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
             const int64_t split_n_t = 32;
             dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+            ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
         }
+    };
+
+    switch (nc) {
+        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
+        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
+        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
     }
 }
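
Note: the refactor deduplicates the short- and long-token launch paths. A generic lambda receives a std::integral_constant, so the runtime kernel size nc becomes a constexpr that can still be used as a template argument, and the newly supported size 9 (presumably the conformer's depthwise conv width) is a one-line case. A host-only sketch of the same dispatch pattern, with a hypothetical conv_kernel standing in for the CUDA launches:

// Sketch of the runtime-to-compile-time dispatch used above: the generic
// lambda's std::integral_constant argument carries the value as a type,
// so decltype(NC)::value is constexpr. Plain C++, no CUDA required.
#include <cstdio>
#include <type_traits>

template <int NC>
static void conv_kernel(int n_t) {
    std::printf("kernel size %d, %d tokens\n", NC, n_t);
}

static void dispatch(int nc, int n_t) {
    auto launch_kernel = [&](auto NC) {
        constexpr int kNC = decltype(NC)::value; // compile-time copy of nc
        conv_kernel<kNC>(n_t);
    };
    switch (nc) {
        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
        default: std::printf("unsupported kernel size %d\n", nc);
    }
}

int main() {
    dispatch(9, 64); // prints: kernel size 9, 64 tokens
    return 0;
}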

gguf-py/gguf/constants.py

Lines changed: 52 additions & 0 deletions

@@ -690,6 +690,8 @@ class MODEL_TENSOR(IntEnum):
     V_TOK_EOI = auto() # cogvlm
     # audio (mtmd)
     A_ENC_EMBD_POS = auto()
+    A_ENC_EMBD_NORM = auto()
+    A_ENC_EMBD_TO_LOGITS = auto()
     A_ENC_CONV1D = auto()
     A_PRE_NORM = auto()
     A_POST_NORM = auto()
@@ -700,8 +702,13 @@ class MODEL_TENSOR(IntEnum):
     A_ENC_OUTPUT = auto()
     A_ENC_OUTPUT_NORM = auto()
     A_ENC_FFN_UP = auto()
+    A_ENC_FFN_NORM = auto()
     A_ENC_FFN_GATE = auto()
     A_ENC_FFN_DOWN = auto()
+    A_ENC_FFN_UP_1 = auto()
+    A_ENC_FFN_NORM_1 = auto()
+    A_ENC_FFN_GATE_1 = auto()
+    A_ENC_FFN_DOWN_1 = auto()
     A_MMPROJ = auto()
     A_MMPROJ_FC = auto()
     A_MM_NORM_PRE = auto()
@@ -713,6 +720,16 @@ class MODEL_TENSOR(IntEnum):
     NEXTN_HNORM = auto()
     NEXTN_SHARED_HEAD_HEAD = auto()
     NEXTN_SHARED_HEAD_NORM = auto()
+    # lfm2 audio
+    A_ENC_NORM_CONV = auto()
+    A_ENC_LINEAR_POS = auto()
+    A_ENC_POS_BIAS_U = auto()
+    A_ENC_POS_BIAS_V = auto()
+    A_ENC_OUT = auto()
+    A_ENC_CONV_DW = auto() # SSM conv
+    A_ENC_CONV_NORM = auto() # SSM conv
+    A_ENC_CONV_PW1 = auto()
+    A_ENC_CONV_PW2 = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -1064,7 +1081,10 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_TOK_BOI: "v.boi",
     MODEL_TENSOR.V_TOK_EOI: "v.eoi",
     # audio (mtmd)
+    # note: all audio tensor names must use prefix "a." or "mm.a."
     MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
+    MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
+    MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
     MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
     MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
     MODEL_TENSOR.A_POST_NORM: "a.post_ln",
@@ -1074,13 +1094,28 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
     MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
     MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
+    MODEL_TENSOR.A_ENC_FFN_NORM: "a.blk.{bid}.ffn_norm",
     MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
     MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
     MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
+    MODEL_TENSOR.A_ENC_FFN_NORM_1: "a.blk.{bid}.ffn_norm_1",
+    MODEL_TENSOR.A_ENC_FFN_UP_1: "a.blk.{bid}.ffn_up_1",
+    MODEL_TENSOR.A_ENC_FFN_GATE_1: "a.blk.{bid}.ffn_gate_1",
+    MODEL_TENSOR.A_ENC_FFN_DOWN_1: "a.blk.{bid}.ffn_down_1",
     MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
     MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
     MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
+    # lfm2 audio
+    MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
+    MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
+    MODEL_TENSOR.A_ENC_POS_BIAS_U: "a.blk.{bid}.pos_bias_u",
+    MODEL_TENSOR.A_ENC_POS_BIAS_V: "a.blk.{bid}.pos_bias_v",
+    MODEL_TENSOR.A_ENC_OUT: "a.pre_encode.out",
+    MODEL_TENSOR.A_ENC_CONV_DW: "a.blk.{bid}.conv_dw",
+    MODEL_TENSOR.A_ENC_CONV_NORM: "a.blk.{bid}.conv_norm",
+    MODEL_TENSOR.A_ENC_CONV_PW1: "a.blk.{bid}.conv_pw1",
+    MODEL_TENSOR.A_ENC_CONV_PW2: "a.blk.{bid}.conv_pw2",
     # NextN/MTP
     MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
     MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
@@ -1145,6 +1180,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_TOK_EOI,
         # audio
         MODEL_TENSOR.A_ENC_EMBD_POS,
+        MODEL_TENSOR.A_ENC_EMBD_NORM,
+        MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
         MODEL_TENSOR.A_ENC_CONV1D,
         MODEL_TENSOR.A_PRE_NORM,
         MODEL_TENSOR.A_POST_NORM,
@@ -1154,13 +1191,27 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.A_ENC_INPUT_NORM,
         MODEL_TENSOR.A_ENC_OUTPUT,
         MODEL_TENSOR.A_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.A_ENC_FFN_NORM,
         MODEL_TENSOR.A_ENC_FFN_UP,
         MODEL_TENSOR.A_ENC_FFN_GATE,
         MODEL_TENSOR.A_ENC_FFN_DOWN,
+        MODEL_TENSOR.A_ENC_FFN_NORM_1,
+        MODEL_TENSOR.A_ENC_FFN_UP_1,
+        MODEL_TENSOR.A_ENC_FFN_GATE_1,
+        MODEL_TENSOR.A_ENC_FFN_DOWN_1,
         MODEL_TENSOR.A_MMPROJ,
         MODEL_TENSOR.A_MMPROJ_FC,
         MODEL_TENSOR.A_MM_NORM_PRE,
         MODEL_TENSOR.A_MM_NORM_MID,
+        MODEL_TENSOR.A_ENC_NORM_CONV,
+        MODEL_TENSOR.A_ENC_LINEAR_POS,
+        MODEL_TENSOR.A_ENC_POS_BIAS_U,
+        MODEL_TENSOR.A_ENC_POS_BIAS_V,
+        MODEL_TENSOR.A_ENC_OUT,
+        MODEL_TENSOR.A_ENC_CONV_DW,
+        MODEL_TENSOR.A_ENC_CONV_NORM,
+        MODEL_TENSOR.A_ENC_CONV_PW1,
+        MODEL_TENSOR.A_ENC_CONV_PW2,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -3363,6 +3414,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
    JANUS_PRO = "janus_pro"
+    LFM2A = "lfm2a" # audio
     GLM4V = "glm4v"
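
Note: per the comment added in the diff, every audio tensor name must carry the "a." (or "mm.a.") prefix, and per-block entries use a "{bid}" placeholder that gets formatted with the block index when tensors are written. A small illustration of the resulting names (plain string formatting, not a gguf-py API call):

// Sketch: expanding the "{bid}" templates above into concrete GGUF tensor
// names, e.g. the per-block depthwise-conv weight of the conformer encoder.
#include <cstdio>

int main() {
    for (int bid = 0; bid < 3; ++bid) {
        char name[64];
        std::snprintf(name, sizeof(name), "a.blk.%d.conv_dw", bid); // mirrors "a.blk.{bid}.conv_dw"
        std::printf("%s\n", name); // a.blk.0.conv_dw, a.blk.1.conv_dw, ...
    }
    return 0;
}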
