Fix FP16 overflow in GQA attention and concat_past_present buffer overflow#4677
aditya-dl wants to merge 3 commits into ROCm:develop
Conversation
Unit tests need to be added and the CI failure fixed.
Pull request overview
This PR fixes two root causes of garbage output in Qwen1.5-architecture models during FP16 inference: (1) FP16 overflow in the dot→softmax attention chain, and (2) a buffer overflow in concat_past_present during prompt processing when the sequence length exceeds the past cache size.
Changes:
- Extends `find_softmax_base_ops` in `rewrite_reduce.cpp` to walk backward from softmax through `mul`/`where`/`broadcast`/`convert` to find a feeding `dot` instruction, upcasting the entire range to FP32 (with bool inputs excluded).
- Fixes `concat_past_present` buffer sizing across the operator definition, GPU lowering, JIT compiler, and GPU kernel so that the output buffer is properly sized when `sequence_length > past_cache_sequence_length`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/rewrite_reduce.cpp` | Adds backward walk from softmax to dot for FP32 upcast range extension; skips bool inputs in conversion |
| `src/include/migraphx/op/concat_past_present.hpp` | Updates `compute_shape` to return a larger shape when needed; uses `std::max` for `present_buffer_sequence_length` |
| `src/targets/gpu/lowering.cpp` | Allocates a properly-sized GPU buffer when the output shape exceeds the past cache shape |
| `src/targets/gpu/jit/concat_past_present.cpp` | Adjusts JIT compiler output shape to match the larger buffer when needed |
| `src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp` | GPU kernel uses `max(past_seq, seq_len)` for the present buffer sequence length |
I think you should put the softmax change in a separate PR, as I don't think we will merge the `concat_past_present` change.
6f03700 to 9da5d9f
@pfultz2 Updated this PR with only the softmax change.
Regarding the CI failures: the fix gates the backward walk on the presence of `where` in the chain. Without `where`, the `dot` stays in FP16 and CK fused attention handles precision internally. With `where` (all GQA models using the GroupQueryAttention ONNX op: Qwen, Llama, Phi, DeepSeek), the `dot` is upcast to FP32. The `where` ops block `propagate_precision` from merging converts (it is a multi-input op), which prevents `fuse_attention` from matching, so the ops run as separate FP32 kernels, fixing the FP16 overflow. Added a new test (`softmax_dot_no_where_preserves_half`) verifying the `dot` is NOT upcast when `where` is absent.
```cpp
{
    auto current     = inp;
    bool found_where = false;
    for(int depth = 0; depth < 10; ++depth)
    {
        auto name = current->name();
        if(name == "dot")
        {
            if(found_where)
                range_start = current;
            break;
        }
        if(name == "where")
        {
            found_where = true;
            current     = current->inputs()[2]; // attention data is 3rd arg
        }
        else if(name == "mul" || name == "broadcast" ||
                name == "multibroadcast" || name == "convert")
            current = current->inputs()[0];
        else
            break;
    }
}
```
Would be cleaner to break this out into another function
```cpp
{
    auto current     = inp;
    bool found_where = false;
    for(int depth = 0; depth < 10; ++depth)
```
Is this depth arbitrary or is there an expected value?
```cpp
        else
            break;
    }
}
```
This does a lot in one function and it's hard to tell what it is doing. This can be written with `unfold`, `std::find_if` and `std::adjacent_find` to make it a lot cleaner:

```cpp
auto p = unfold(inp, [](instruction_ref x) -> std::optional<instruction_ref> {
    if(x->inputs().size() == 1)
        return x->inputs().front();
    if(contains({"add", "mul"}, x->name()))
    {
        auto it = std::find_if(x->inputs().begin(), x->inputs().end(), [](instruction_ref input) {
            return not input->can_eval();
        });
        if(it == x->inputs().end())
            return nullopt;
        return *it;
    }
    if(x->name() == "where")
        return x->inputs().at(2); // attention data is 3rd arg
    return nullopt;
});
auto where_it = std::find_if(p.begin(), p.end(), [&](instruction_ref x) {
    return x->name() == "where";
});
if(where_it != p.end())
{
    range_start = std::adjacent_find(where_it, p.end(), [&](instruction_ref, instruction_ref x) {
        return x->name() == "dot";
    });
}
```

```cpp
    EXPECT(all_of(chain_ops, [](auto ins) {
        return ins->get_shape().type() == migraphx::shape::float_type;
    }));
}
```
Tests should be written with the expected module:

```cpp
module m1;
{
    // Create test module
}
run_pass(m1);
module m2;
{
    // Create expected module
}
EXPECT(m1 == m2);
```
The `found_where` gate doesn't fully resolve the CI failure. The core issue: what is the recommended way to prevent `fuse_attention` from fusing when the softmax chain has been upcast to FP32? For example, would a type check in `find_attention::apply()` be appropriate? I'll also address the code review comments (unfold refactor, test restructure, depth parameter) once the CI issue is resolved.
This doesn't look like a data type error. It looks like the graph was changed in a way that rocMLIR can no longer trace it back to the gemm.
We don't want to prevent attention fusion when FP32 softmax is used. We need to make sure this works by either fixing the graph we give to rocMLIR or updating rocMLIR to support this case.
@pfultz2 can you advise on a way to debug this and identify the exact root cause? This doesn't replicate on my local system with RDNA GPUs.
Motivation
Qwen1.5-architecture models produce garbage output when running FP16 inference through MIGraphX. Two root causes were identified: FP16 overflow in the dot-to-softmax attention chain, and a buffer overflow in `concat_past_present` during prompt processing.
Technical Details
FP16 overflow fix: Extends `find_softmax_base_ops` to walk backwards through the attention chain (`mul`, `where`, `broadcast`, `convert`) to find the feeding `dot` instruction. The entire dot-to-softmax range is upcast to FP32, preventing overflow in attention score computation. Bool-type inputs (`where` conditions) are excluded from conversion.

Changelog Category
Add a `CHANGELOG.md` entry for any option other than Not Applicable