Merged

41 commits
21d9e11
Adding attention bias support in MHA
urpetkov-amd Mar 17, 2026
37a5668
Adding tests for bias and attention bias
urpetkov-amd Mar 17, 2026
1373214
Clang format
urpetkov-amd Mar 17, 2026
374b6a6
Licencing
urpetkov-amd Mar 17, 2026
a16f79c
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 17, 2026
f922d41
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 18, 2026
940735f
Adding more tests
urpetkov-amd Mar 20, 2026
bf9e7a5
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 20, 2026
b66963c
Adding verify test
urpetkov-amd Mar 20, 2026
94b461a
Adding verify tests
urpetkov-amd Mar 20, 2026
7c59780
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 20, 2026
1932ae6
Clang format
urpetkov-amd Mar 20, 2026
29a1d32
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 23, 2026
77e32db
Merge branch 'develop' into multihead_attention_bias
urpetkov-amd Mar 24, 2026
e003bb7
Adding past key value support
urpetkov-amd Mar 24, 2026
2f89cbb
Additional changes
urpetkov-amd Mar 24, 2026
101bbef
Fixing empty inputs
urpetkov-amd Mar 24, 2026
1881473
Adding parse tests
urpetkov-amd Mar 24, 2026
b8deb4b
Deleting matmulbnb4
urpetkov-amd Mar 25, 2026
45d7e98
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Mar 25, 2026
769e724
Resolving one missing conflict
urpetkov-amd Mar 25, 2026
8c75f0b
Deleting past state test
urpetkov-amd Mar 25, 2026
5d48baa
Clang format
urpetkov-amd Mar 25, 2026
c8902a2
Fix nested failure
urpetkov-amd Mar 25, 2026
bea768c
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Mar 25, 2026
abfe5ab
Delete redundant checks
urpetkov-amd Mar 26, 2026
cadbed6
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Mar 26, 2026
42da9c4
Fixing missing check for attention bias
urpetkov-amd Mar 26, 2026
b1cea9c
Adding past kv tests
urpetkov-amd Mar 26, 2026
72ee022
Clang format
urpetkov-amd Mar 26, 2026
3706e40
Adding verify tests
urpetkov-amd Mar 26, 2026
e01fe99
Adding tests
urpetkov-amd Mar 26, 2026
3ca3e4a
Fix clang tidy issue
urpetkov-amd Mar 26, 2026
f7fd2ea
Deleting unnecessary changes
urpetkov-amd Mar 26, 2026
71f2370
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Mar 27, 2026
af9f587
Adding optional instruction refs
urpetkov-amd Mar 27, 2026
1e9bd99
Adding blurb box in verify test
urpetkov-amd Mar 27, 2026
b524d17
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Apr 3, 2026
54c12d7
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Apr 7, 2026
5bb583d
Merge branch 'develop' into multihead_past_kv_seq
urpetkov-amd Apr 8, 2026
8f9d417
Merge branch 'develop' into multihead_past_kv_seq
TedThemistokleous Apr 9, 2026
169 changes: 159 additions & 10 deletions src/onnx/parse_multi_head_attention.cpp
@@ -215,6 +215,10 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
{
if(args.size() > 4)
{
// Skip validation if the mask is empty (optional input not provided)
if(args.at(4)->get_shape().elements() == 0)
return;

const auto key_pad_lens = args.at(4)->get_shape().lens();
const auto key_pad_len_size = key_pad_lens.size();
const auto key_pad_type = args.at(4)->get_shape().type();
@@ -330,20 +334,109 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
}
}

void check_past_key(const std::vector<instruction_ref>& args,
const multi_head_attention_parameters& params) const
{
if(args.size() <= 6)
return;

// Skip validation if past_key is empty (optional input not provided)
if(args.at(6)->get_shape().elements() == 0)
return;

const auto past_key_lens = args.at(6)->get_shape().lens();
if(past_key_lens.size() != 4)
MIGRAPHX_THROW("MultiHeadAttention: past_key must be 4D shape");

if(past_key_lens[0] != params.batch_size)
MIGRAPHX_THROW("MultiHeadAttention: past_key first dimension must be batch_size");

if(past_key_lens[1] != params.num_heads)
MIGRAPHX_THROW("MultiHeadAttention: past_key second dimension must be num_heads");

if(past_key_lens[3] != params.head_size)
MIGRAPHX_THROW("MultiHeadAttention: past_key fourth dimension must be head_size");
}

void check_past_value(const std::vector<instruction_ref>& args,
const multi_head_attention_parameters& params) const
{
if(args.size() <= 7)
return;

// Skip validation if past_value is empty (optional input not provided)
if(args.at(7)->get_shape().elements() == 0)
return;

const auto past_value_lens = args.at(7)->get_shape().lens();
if(past_value_lens.size() != 4)
MIGRAPHX_THROW("MultiHeadAttention: past_value must be 4D shape");

if(past_value_lens[0] != params.batch_size)
MIGRAPHX_THROW("MultiHeadAttention: past_value first dimension must be batch_size");

if(past_value_lens[1] != params.num_heads)
MIGRAPHX_THROW("MultiHeadAttention: past_value second dimension must be num_heads");

if(past_value_lens[3] != params.head_size_v)
MIGRAPHX_THROW("MultiHeadAttention: past_value fourth dimension must be head_size_v");
}

void check_past_key_value_match(const std::vector<instruction_ref>& args) const
{
if(args.size() <= 7)
return;

// Skip if either past_key or past_value is empty
if(args.at(6)->get_shape().elements() == 0 or args.at(7)->get_shape().elements() == 0)
return;

const auto past_key_lens = args.at(6)->get_shape().lens();
const auto past_value_lens = args.at(7)->get_shape().lens();
if(past_value_lens[2] != past_key_lens[2])
MIGRAPHX_THROW("MultiHeadAttention: past_key and past_value must have "
"matching past_sequence_length");
}

void check_past_sequence_length(const std::vector<instruction_ref>& args) const
{
if(args.size() <= 8)
return;

// Skip validation if past_sequence_length is empty
if(args.at(8)->get_shape().elements() == 0)
return;

const auto past_seq_len_type = args.at(8)->get_shape().type();
if(past_seq_len_type != shape::int32_type)
MIGRAPHX_THROW("MultiHeadAttention: past_sequence_length must be a int32 tensor");
}

void check_past_inputs(const std::vector<instruction_ref>& args,
const multi_head_attention_parameters& params) const
{
check_past_key(args, params);
check_past_value(args, params);
check_past_key_value_match(args);
check_past_sequence_length(args);
}

void check_inputs(const std::vector<instruction_ref>& args,
multi_head_attention_parameters& params) const
{
if(args.empty() or args.size() > 6)
if(args.empty() or args.size() > 9)
MIGRAPHX_THROW(
"MultiHeadAttention: Wrong number of inputs. Only 'query', 'key', "
"'value', bias, key_padding_mask and attention_bias inputs are supported.");
"'value', bias, key_padding_mask, attention_bias, past_key, past_value, and "
"past_sequence_length inputs are supported.");

// Order matters here. Most parameters defined by input query, key, value parameters
// This must be used first to extract hidden size, batch, etc
check_query_dim(args, params);
check_bias(args, params);
check_key_padding_mask(args, params);
check_attention_bias(args, params);
check_past_inputs(args, params);
}

std::tuple<instruction_ref, instruction_ref, instruction_ref>
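
For reference, a minimal sketch (not part of the diff) of the layouts these past-state checks accept, assuming example dimensions batch_size = 2, num_heads = 8, head_size = head_size_v = 64, a past sequence length of 16, and half-precision data; the dimension order follows the ONNX Runtime MultiHeadAttention contrib-op convention:

// past_key / past_value: {batch_size, num_heads, past_sequence_length, head_size}
migraphx::shape past_key_shape{migraphx::shape::half_type, {2, 8, 16, 64}};
migraphx::shape past_value_shape{migraphx::shape::half_type, {2, 8, 16, 64}};
// past_sequence_length must be an int32 tensor; one entry per batch in this sketch
migraphx::shape past_seq_len_shape{migraphx::shape::int32_type, {2}};
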
@@ -620,10 +713,10 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
return params;
}

instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
auto params = handle_attributes(info, parser);
check_inputs(args, params);
@@ -654,6 +747,49 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
value = info.add_instruction(make_op("transpose", {{"permutation", perm}}), value);
}

// Handle past_key and past_value concatenation using concat_past_present
std::optional<instruction_ref> present_key;
std::optional<instruction_ref> present_value;
if(args.size() > 7)
{
auto past_key = args[6];
auto past_value = args[7];

// Only use concat_past_present if past states are non-empty
if(past_key->get_shape().elements() > 0 and past_value->get_shape().elements() > 0)
{
// If past_sequence_length is provided (input 8), use it, otherwise use batch-wise
// zeros
instruction_ref seqlens_k;
if(args.size() > 8 and args[8]->get_shape().elements() > 0)
{
seqlens_k = args[8];
}
else
{
std::vector<int32_t> zeros(params.batch_size, 0);
seqlens_k = info.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type,
{static_cast<size_t>(params.batch_size)}},
zeros});
}

std::vector<instruction_ref> concat_k_inputs{key, seqlens_k, past_key};
std::vector<instruction_ref> concat_v_inputs{value, seqlens_k, past_value};

// Use concat_past_present operator for efficient KV cache concatenation
present_key = info.add_instruction(
make_op("concat_past_present", {{"kv_num_heads", params.num_heads}}),
concat_k_inputs);
present_value = info.add_instruction(
make_op("concat_past_present", {{"kv_num_heads", params.num_heads}}),
concat_v_inputs);

key = present_key.value();
value = present_value.value();
}
}
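
// Illustrative note (not part of the original diff): per the ONNX Runtime
// MultiHeadAttention description, present_key/present_value are expected to carry
// the concatenated cache. For example, a past_key of shape {2, 8, 16, 64} combined
// with a key covering a 4-token step should produce a present_key spanning
// 16 + 4 = 20 positions; the exact output shape depends on how
// concat_past_present sizes the cache buffer.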

// Set attention mask and bias when detected on input
std::optional<instruction_ref> attn_mask;
if(args.size() > 4)
@@ -688,17 +824,30 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
result = info.add_common_op("add", result, attn_mask.value());
}

result = info.add_common_op("mul", result, scale_literal);
result = info.add_instruction(make_op("softmax", {{"axis", -1}}), result);
result = info.add_instruction(make_op("dot"), result, value);
result = info.add_common_op("mul", result, scale_literal);
auto qk_output = info.add_instruction(make_op("softmax", {{"axis", -1}}), result);
result = info.add_instruction(make_op("dot"), qk_output, value);
result = info.add_instruction(make_op("transpose", {{"permutation", perm}}), result);
result = info.add_instruction(
make_op(
"reshape",
{{"dims", {params.batch_size, params.q_sequence_length, params.hidden_size_v}}}),
result);

return result;
// Return outputs based on what's available: present key, present value and qk are optional
std::vector<instruction_ref> outputs = {result};

// Add present_key and present_value if past states were provided and non-empty
if(present_key.has_value() and present_value.has_value())
{
outputs.push_back(present_key.value());
outputs.push_back(present_value.value());
}

// Note: QK output could be added here if needed
// outputs.push_back(qk_output);

return outputs;
}
};
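
A minimal usage sketch, assuming the internal MIGraphX C++ API is available and using a hypothetical model file name: an ONNX model whose MultiHeadAttention node supplies past_key/past_value should now parse into a program whose final instructions include the two concat_past_present ops and the three outputs (attention output, present_key, present_value).

#include <iostream>
#include <migraphx/onnx.hpp>
#include <migraphx/program.hpp>

int main()
{
    // "mha_with_past_kv.onnx" is a placeholder for a model containing a
    // MultiHeadAttention node with query, key, value, past_key and past_value inputs
    migraphx::program prog = migraphx::parse_onnx("mha_with_past_kv.onnx");
    std::cout << prog << std::endl; // inspect the parsed instructions
    return 0;
}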
