
Conversation

@TianHao324
Contributor

Summary

Testing Done

image
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@TianHao324
Contributor Author

TianHao324 commented Dec 15, 2025

@Tcc0403 @zheliuyu Please review this revision.

@TianHao324
Contributor Author

@zheliuyu Also, could you please provide the performance data of the relevant operators on the GPU?

@zheliuyu
Contributor

@zheliuyu Also, could you please provide the performance data of the relevant operators on the GPU?

Of course. Could you please let me know which NPU you are using?

k_base = k_ptr + pid * k_row_stride

# Process in chunks to prevent UB overflow
for qh_block in range(0, n_qh, BLOCK_Q):
Contributor


L45-L100:
Why can't we just reuse the existing logic?

Collaborator


# left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
second_half_q_offsets = first_half_q_offsets + (hd // 2)
second_half_k_offsets = first_half_k_offsets + (hd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

I think we don't have to worry about head_dim being too large since head_dim in most LLMs is really small (qwen3/llama: 128, gemma3: 256). Simply looping over pad_n_qh or pad_n_kh with a smaller block size in a similar way to prevent UB overflow should be enough.
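
As a rough sketch of that looping idea (illustration only, not code from this PR; BLOCK_Q is an assumed compile-time block size), the existing offset/mask pattern could be tiled over the q-head dimension so that each iteration only materializes BLOCK_Q rows in UB:

# Sketch: tile the existing first-half q load over the head dimension.
for qh_start in range(0, n_qh, BLOCK_Q):
    qh_idx = qh_start + tl.arange(0, BLOCK_Q)
    first_half_q_offsets = qh_idx[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (qh_idx[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    # ... same pattern for the right half and for k, then rotate with cos/sin and store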

Contributor Author



I think we don't have to worry about head_dim being too large since head_dim in most LLMs is really small (qwen3/llama: 128, gemma3: 256). Simply looping over pad_n_qh or pad_n_kh with a smaller block size in a similar way to prevent UB overflow should be enough.

Edited: now we only perform block division on q and k.

Contributor Author


L45-L100: Why can't we just reuse the existing logic?

The existing logic can cause UB (Unified Buffer) overflow on the NPU.
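
For a back-of-the-envelope sense of the footprint involved (the shapes and dtype below are assumed for illustration, not taken from a specific model), the un-blocked logic loads four half-head tiles at once before counting cos/sin, masks, and the rotated outputs:

# Rough footprint of the four tiles the un-blocked kernel materializes at once.
n_qh, n_kh, hd, itemsize = 64, 8, 256, 4    # assumed head counts, head_dim, fp32 bytes
q_bytes = 2 * n_qh * (hd // 2) * itemsize   # q left/right half-head tiles
k_bytes = 2 * n_kh * (hd // 2) * itemsize   # k left/right half-head tiles
print(q_bytes + k_bytes)                    # 73728 bytes before cos/sin, masks and the
                                            # rotated outputs; compare against the UB_SIZE
                                            # reported for the target SoC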

@TianHao324 TianHao324 marked this pull request as draft December 16, 2025 01:36
@TianHao324 TianHao324 marked this pull request as ready for review December 23, 2025 10:27
) -> int:
    dev_props = torch.npu.get_device_properties(0).name
    tbe.common.platform.set_current_compile_soc_info(dev_props)
    ub_size_bytes = tbe.common.platform.get_soc_spec("UB_SIZE")
Collaborator


Is there another way to get UB_SIZE? Ideally we want to avoid introducing a new dependency if it's only for such info.
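
One possible direction, sketched below with placeholder SoC names and sizes (none of the table values are verified), is to keep a small per-SoC lookup and only fall back to tbe when the SoC is unknown:

import torch  # assumes torch_npu is installed and has patched torch.npu

# Placeholder table: SoC name -> UB size in bytes (illustrative values only).
_UB_SIZE_BYTES = {
    "Ascend910B2": 192 * 1024,
    "Ascend910B4": 192 * 1024,
}

def get_ub_size_bytes(default: int = 192 * 1024) -> int:
    name = torch.npu.get_device_properties(0).name
    if name in _UB_SIZE_BYTES:
        return _UB_SIZE_BYTES[name]
    try:
        # Fall back to tbe only when the SoC is not in the table.
        from tbe.common.platform import get_soc_spec, set_current_compile_soc_info
        set_current_compile_soc_info(name)
        return int(get_soc_spec("UB_SIZE"))
    except ImportError:
        return default  # conservative fallback when tbe is unavailable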

Comment on lines 107 to 117
# 3 * hd // 2 + (3 * hd + 2) * max_block <= max_elements
max_block = min(
    n_heads,
    int((max_elements - 3 * hd // 2) // (3 * hd + 2))
)

if max_block != triton.next_power_of_2(max_block):
    return triton.next_power_of_2(max_block) // 2
else:
    return triton.next_power_of_2(max_block)
Collaborator


This part is a little bit confusing.

Is the total UB cost an assumption, or real data collected from a profiling tool? I'm not familiar with NPU so I might be wrong, but in CUDA I believe only the q and k blocks are loaded into shared memory while the others are just register values.

Collaborator


Assuming the cost is correct, I suggest rewriting the conditions for readability.

Suggested change
- # 3 * hd // 2 + (3 * hd + 2) * max_block <= max_elements
- max_block = min(
-     n_heads,
-     int((max_elements - 3 * hd // 2) // (3 * hd + 2))
- )
- if max_block != triton.next_power_of_2(max_block):
-     return triton.next_power_of_2(max_block) // 2
- else:
-     return triton.next_power_of_2(max_block)
+ # 3 * hd // 2 + (3 * hd + 2) * max_block <= max_elements
+ max_block = int((max_elements - 3 * hd // 2) // (3 * hd + 2))
+ if max_block < n_heads:
+     return triton.next_power_of_2(max_block) // 2
+ else:
+     # n_heads is power of 2 in most models, so it's unlikely to exceed max_block
+     return triton.next_power_of_2(n_heads)
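
As a quick sanity check of the suggested version with assumed inputs (192 KiB UB, fp32 elements; the helper wrapper and the shapes below are illustrative only, not part of the PR):

import triton

def suggested_block_q(max_elements: int, hd: int, n_heads: int) -> int:
    # Same logic as the suggestion above, wrapped as a standalone helper for testing.
    # 3 * hd // 2 + (3 * hd + 2) * max_block <= max_elements
    max_block = int((max_elements - 3 * hd // 2) // (3 * hd + 2))
    if max_block < n_heads:
        return triton.next_power_of_2(max_block) // 2
    # n_heads is a power of 2 in most models, so it's unlikely to exceed max_block
    return triton.next_power_of_2(n_heads)

max_elements = (192 * 1024) // 4                            # assumed UB budget in fp32 elements
print(suggested_block_q(max_elements, hd=128, n_heads=32))  # -> 32, all heads fit in one block
print(suggested_block_q(8192, hd=128, n_heads=128))         # -> 16, block capped by the UB budget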

Collaborator


Another approach is hardcoding block_size for different NPU architectures based on autotuning.

Contributor Author

@TianHao324 TianHao324 Dec 24, 2025


This part is a little bit confusing.

Is the total UB cost an assumption, or real data collected from a profiling tool? I'm not familiar with NPU so I might be wrong, but in CUDA I believe only the q and k blocks are loaded into shared memory while the others are just register values.

The UB cost is estimated based on the actual code, not collected from profiling.

Total UB cost (in elements):
cos_vals, sin_vals, d_idx: hd // 2 each (3 tensors)
qh_idx, qh_mask: BLOCK each (2 tensors)
block_mask, offsets, q_left, q_right, new_left, new_right: BLOCK × (hd // 2) each (6 tensors)
Summing these (with BLOCK = max_block) gives 3 * hd // 2 + (3 * hd + 2) * max_block <= max_elements.
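
Restating that accounting as a quick check (the concrete hd and BLOCK values below are assumed for illustration):

hd, BLOCK = 128, 32                      # assumed example values
per_half_head = 3 * (hd // 2)            # cos_vals, sin_vals, d_idx
per_block     = 2 * BLOCK                # qh_idx, qh_mask
per_tile      = 6 * BLOCK * (hd // 2)    # block_mask, offsets, q_left/right, new_left/right
total = per_half_head + per_block + per_tile
assert total == 3 * hd // 2 + (3 * hd + 2) * BLOCK
print(total)                             # 12544 elements of UB for this BLOCK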

The corresponding functional segment is shown below (the logic for q and k is the same):

d_idx = tl.arange(0, hd // 2)
cos_vals = tl.load(cos + d_idx)
sin_vals = tl.load(sin + d_idx)

q_base = q_ptr + pid * q_row_stride
k_base = k_ptr + pid * k_row_stride

# Process in chunks to prevent UB overflow
for qh_block in range(0, n_qh, BLOCK_Q):
    qh_idx = tl.arange(0, BLOCK_Q) + qh_block

    qh_mask = qh_idx < n_qh
    block_mask = qh_mask[:, None]

    offsets = qh_idx[:, None] * hd + d_idx[None, :]

    q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
    q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)

    if not BACKWARD_PASS:
        new_left = q_left * cos_vals - q_right * sin_vals
        new_right = q_right * cos_vals + q_left * sin_vals
    else:
        new_left = q_left * cos_vals + q_right * sin_vals
        new_right = q_right * cos_vals - q_left * sin_vals

    tl.store(q_base + offsets, new_left, mask=block_mask)
    tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)

Contributor Author


Another approach is hardcoding block_size for different NPU architectures based on autotuning.

Because the UB size varies between different Ascend devices, such as B2 and B4, hardcoding may result in poor performance.

@Tcc0403
Collaborator

Tcc0403 commented Dec 29, 2025

Thanks for the contribution! Added you to the co-author list in #987

@Tcc0403 Tcc0403 closed this Dec 29, 2025
