Skip to content

Commit 01f720a

Browse files
fix
1 parent 1f0dcc3 commit 01f720a

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2388,7 +2388,11 @@ def forward(
23882388
self.rope_deltas = rope_deltas
23892389

23902390
else:
2391-
batch_size, seq_length = input_ids.shape
2391+
if inputs_embeds is not None:
2392+
batch_size, seq_length, _ = inputs_embeds.shape
2393+
else:
2394+
batch_size, seq_length = input_ids.shape
2395+
23922396
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
23932397
position_ids = torch.arange(seq_length, device=input_ids.device)
23942398
position_ids = position_ids.view(1, -1).expand(batch_size, -1)

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2588,7 +2588,11 @@ def forward(
25882588
self.rope_deltas = rope_deltas
25892589

25902590
else:
2591-
batch_size, seq_length = input_ids.shape
2591+
if inputs_embeds is not None:
2592+
batch_size, seq_length, _ = inputs_embeds.shape
2593+
else:
2594+
batch_size, seq_length = input_ids.shape
2595+
25922596
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
25932597
position_ids = torch.arange(seq_length, device=input_ids.device)
25942598
position_ids = position_ids.view(1, -1).expand(batch_size, -1)

0 commit comments

Comments
 (0)