Skip to content

Fix FP16 overflow in GQA attention and concat_past_present buffer overflow#4677

Open
aditya-dl wants to merge 3 commits intoROCm:developfrom
aditya-dl:fix-fp16-overflow-gqa-attention
Open

Fix FP16 overflow in GQA attention and concat_past_present buffer overflow#4677
aditya-dl wants to merge 3 commits intoROCm:developfrom
aditya-dl:fix-fp16-overflow-gqa-attention

Conversation

@aditya-dl
Copy link
Copy Markdown
Contributor

@aditya-dl aditya-dl commented Mar 16, 2026

Motivation

Qwen1.5-architecture models produce garbage output when running FP16 inference through MIGraphX. Two root causes were identified:

  1. The dot-to-softmax attention chain overflows in FP16 precision

Technical Details

FP16 overflow fix: Extends find_softmax_base_ops to walk backwards through the attention chain (mul, where, broadcast, convert) to find the feeding dot instruction. The entire dot-to-softmax range is upcast to FP32, preventing overflow in attention score computation. Bool-type inputs (where conditions) are excluded from conversion.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@aditya-dl aditya-dl requested a review from causten as a code owner March 16, 2026 19:27
@pfultz2
Copy link
Copy Markdown
Collaborator

pfultz2 commented Mar 17, 2026

Unit tests need to be added and the CI failure fixed.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes two root causes of garbage output in Qwen1.5-architecture models during FP16 inference: (1) FP16 overflow in the dot→softmax attention chain, and (2) a buffer overflow in concat_past_present during prompt processing when the sequence length exceeds the past cache size.

Changes:

  • Extends find_softmax_base_ops in rewrite_reduce.cpp to walk backward from softmax through mul/where/broadcast/convert to find a feeding dot instruction, upcasting the entire range to FP32 (with bool inputs excluded).
  • Fixes concat_past_present buffer sizing across the operator definition, GPU lowering, JIT compiler, and GPU kernel so that the output buffer is properly sized when sequence_length > past_cache_sequence_length.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
src/rewrite_reduce.cpp Adds backward walk from softmax to dot for FP32 upcast range extension; skips bool inputs in conversion
src/include/migraphx/op/concat_past_present.hpp Updates compute_shape to return larger shape when needed; uses std::max for present_buffer_sequence_length
src/targets/gpu/lowering.cpp Allocates properly-sized GPU buffer when output shape exceeds past cache shape
src/targets/gpu/jit/concat_past_present.cpp Adjusts JIT compiler output shape to match larger buffer when needed
src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp GPU kernel uses max(past_seq, seq_len) for present buffer sequence length

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

@pfultz2
Copy link
Copy Markdown
Collaborator

pfultz2 commented Mar 18, 2026

I think you should put the softmax change in a separate PR as I dont think we will merge the concat_past_present change.

@aditya-dl aditya-dl force-pushed the fix-fp16-overflow-gqa-attention branch from 6f03700 to 9da5d9f Compare March 19, 2026 17:39
@aditya-dl
Copy link
Copy Markdown
Contributor Author

@pfultz2 Updated this PR with only the softmax change.

@aditya-dl
Copy link
Copy Markdown
Contributor Author

The CI failures in test_ck_gemm_softmax_gemm_0 and test_ck_gemm_softmax_gemm_1 were caused by the backward walk unconditionally extending the FP32 upcast to the upstream dot. For non-masked attention patterns (dot→mul→softmax), propagate_precision merges the FP32 converts into a clean chain, and fuse_attention fuses it into an MLIR attention module. rocMLIR's high-level pipeline then fails because it can't handle FP32 tosa.reduce_max inside fused attention.

The fix gates the backward walk on the presence of where in the chain. Without where, the dot stays in FP16 and CK fused attention handles precision internally. With where (all GQA models using the GroupQueryAttention ONNX op — Qwen, Llama, Phi, DeepSeek), the dot is upcast to FP32. The where ops block propagate_precision from merging converts (multi-input op), which prevents fuse_attention from matching, so ops run as separate FP32 kernels — fixing the FP16 overflow.

Added a new test (softmax_dot_no_where_preserves_half) verifying the dot is NOT upcast when where is absent.

Comment on lines +103 to +126
{
auto current = inp;
bool found_where = false;
for(int depth = 0; depth < 10; ++depth)
{
auto name = current->name();
if(name == "dot")
{
if(found_where)
range_start = current;
break;
}
if(name == "where")
{
found_where = true;
current = current->inputs()[2]; // attention data is 3rd arg
}
else if(name == "mul" || name == "broadcast" ||
name == "multibroadcast" || name == "convert")
current = current->inputs()[0];
else
break;
}
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be cleaner to break this out into another function

{
auto current = inp;
bool found_where = false;
for(int depth = 0; depth < 10; ++depth)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this depth arbitrary or is there an expected value?

else
break;
}
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does a lot in one function and its hard to tell what it is doing. This can be written with unfold, std::find_if and std::adjacent_find to make it a lot cleaner:

auto p = unfold(inp, [](instruction_ref x) -> std::optional<instruction_ref> {
    if(x->inputs().size() == 1)
        return x->inputs().front();
    if(contains({"add", "mul"}, x->name()))
    {
        auto it = std::find_if(x->inputs().begin(), x->inputs().end(), [](instruction_ref input) {
            return not input->can_eval();
        });
        if(it == x->inputs().end())
            return nullopt;
        return *it;
    }
    if(x->name() == "where")
        return x->inputs().at(2); // attention data is 3rd arg
    return nullopt;
});

auto where_it = std::find_if(p.begin(), p.end(), [&](instruction_ref x) {
    return x->name() == "where";
});
if(where_it != p.end())
{
    range_start = std::adjacent_find(where_it, p.end(), [&](instruction_ref, instruction_ref x) {
        return x->name() == "dot";
    });
}

EXPECT(all_of(chain_ops, [](auto ins) {
return ins->get_shape().type() == migraphx::shape::float_type;
}));
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests should be written with the expected module:

module m1;
{
    // Create test module
}

run_pass(m1);

module m2;
{
    // Create expected module
}

EXPECT(m1 == m2);

@aditya-dl
Copy link
Copy Markdown
Contributor Author

The found_where gate doesn't fully resolve the CI failure.
test_group_query_attention_decode_small still fails because propagate_precision merges the FP32 converts for the small test shapes (even with where present), allowing fuse_attention to fuse the FP32 chain into an MLIR attention module. rocMLIR then fails on FP32 tosa.reduce_max ("can't trace the reduction output to a kernel result").

The core issue: find_softmax_base_ops intentionally upcasts the softmax chain to FP32, but fuse_attention later fuses it into an MLIR attention module that rocMLIR can't compile with FP32 reduce operations.

What's the recommended way to prevent fuse_attention from fusing when the softmax chain has been upcast to FP32? For example, would a type check in find_attention::apply() be appropriate?

  if(softmax_input->get_shape().type() == shape::float_type)
      return;

I'll also address the code review comments (unfold refactor, test restructure, depth parameter) once the CI issue is resolved.

@pfultz2
Copy link
Copy Markdown
Collaborator

pfultz2 commented Mar 25, 2026

rocMLIR then fails on FP32 tosa.reduce_max ("can't trace the reduction output to a kernel result").

This doesnt look like data type error. It looks like the graph was changed in a way the rocMLIR can no longer trace it back to the gemm.

What's the recommended way to prevent fuse_attention from fusing when the softmax chain has been upcast to FP32?

We dont want to prevent attention fusion when fp32 softmax is used. We need to make sure this works by either fixing the graph we give to rocMLIR or updating the rocMLIR to support this case.

@aditya-dl
Copy link
Copy Markdown
Contributor Author

rocMLIR then fails on FP32 tosa.reduce_max ("can't trace the reduction output to a kernel result").

This doesnt look like data type error. It looks like the graph was changed in a way the rocMLIR can no longer trace it back to the gemm.

What's the recommended way to prevent fuse_attention from fusing when the softmax chain has been upcast to FP32?

We dont want to prevent attention fusion when fp32 softmax is used. We need to make sure this works by either fixing the graph we give to rocMLIR or updating the rocMLIR to support this case.

@pfultz2 can you advice on a way to debug this to identify the exact root cause of this? This doesn't replicate on my local system with RDNA GPUs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants