Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def _get_max_segments_per_sequence(self):

def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
# TODO(KshitijLakhani): probably add/move this to is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")

Expand All @@ -417,11 +418,59 @@ def _check_configs(self):
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
compute_capability = get_device_compute_capability(0)
cudnn_version = get_cudnn_version()
# D=256 bprop on SM10x (cuDNN FE 1.24 / BE 9.23+) uses the deterministic algorithm path only,
# which rejects dBias, dropout, and ALiBi. It supports vanilla type of softmax only and allows SWA
# together with a causal mask only.
is_sm10x = 100 <= compute_capability < 110
if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256):
if self.head_dim_qk != 256 or self.head_dim_v != 256:
pytest.skip(
"D=256 BWD on Blackwell only supports d_qk == d_v == 256;"
f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}."
)
if cudnn_version < 92300:
pytest.skip(
"D=256 BWD on Blackwell requires cuDNN 9.23 or newer;"
f" got cuDNN {cudnn_version}."
)
# TODO(KshitijLakhani): cuDNN FE can model bias input separately from dBias,
# but TE does not yet plumb whether dBias is requested into the common backend selector.
# Until that distinction is available, the D=256 SM10x gate requires no bias.
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS:
unsupported = (
"post-scale bias in TE's D=256 backend gate; bias-input-only"
" support needs TE to distinguish between bias input and dBias"
)
elif self.dropout_prob != 0.0:
unsupported = "dropout"
elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
unsupported = "non-vanilla softmax"
if unsupported is not None:
pytest.skip(
"D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD"
f" kernel which does not support {unsupported}."
)
if self.window_size is not None and self.window_size != (-1, -1):
if not self.attn_mask_type.is_causal():
pytest.skip(
"D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel"
" which requires window_size=(-1, -1) for non-causal masks."
)
if self.window_size[1] not in (-1, 0):
pytest.skip(
"D=256 BWD on Blackwell only supports right window -1 or 0"
" for causal masks."
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these checks duplicates of the checks we added on the C++ side? Would the call FusedAttnHelper().get_fused_attn_backend() give you the same gating effect?

Copy link
Copy Markdown
Collaborator Author

@KshitijLakhani KshitijLakhani Jun 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if we are just interested in the gating effect, you are right. get_fused_attn_backend() will return NVTE_No_Backend, and then there's a catch-all at the end which basically skips the tests as there is no fused attn backend available.

However, the reason for this to be here is to give a meaningful reason as to why a test is being skipped, rather than a generic "Unsupported inputs combination or device compute capability." message, which does not explain the reason for the skip. Unfortunately, on the JAX attn side we do not log the reason for disabling fused attn in the feature code like we do on the PyTorch side in d_p_a/utils.py. So there is no way for the user to know why the test was skipped. Hence, we need to rely on test code to log this on the JAX side.

I'd suggest we leave this in here for now. And when your PR for generating log messages in the C++ level when selecting the attn backend is ready, I can plumb it through onto the JAX side and then as part of that clean up, get rid of all the skip messages in check_configs()

if get_device_compute_capability(0) >= 100 and self.is_training:
if compute_capability >= 100 and self.is_training:
if FusedAttnHelper.is_non_deterministic_allowed() and (
(self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
or get_cudnn_version() < 90700
or cudnn_version < 90700
):
pytest.skip(
"For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
Expand All @@ -430,7 +479,7 @@ def _check_configs(self):
if not FusedAttnHelper.is_non_deterministic_allowed() and (
self.dropout_prob != 0.0
or self.attn_bias_type != AttnBiasType.NO_BIAS
or get_cudnn_version() < 91801
or cudnn_version < 91801
):
pytest.skip(
"For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
Expand Down Expand Up @@ -1474,6 +1523,36 @@ def test_backward(
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
# D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+).
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED",
),
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.THD_T2HD,
id="4-128-128-16-16-256-256-FP16-SELF-RAGGED_KV_PACKED",
marks=pytest.mark.xfail(
reason="cuDNN 9.23 D=256 BWD currently does not build a THD execution plan.",
strict=True,
),
),
],
)
@pytest.mark.parametrize(
Expand Down
31 changes: 31 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,37 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24),
# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only,
# vanilla softmax only, no bias (including ALiBi), no dropout, and (for non-causal masks) full-window attention only.
# Configurations exercised by the head_dim=256 FusedAttention backward test.
# Positional arguments follow ModelConfig(b, sq, hq, dqk); head_dim_v is left
# unset so that it defaults to head_dim_qk (256).
model_configs_fused_hdim256 = dict(
    # Plain attention without a mask.
    fused_hd256_no_mask=ModelConfig(2, 512, 16, 256),
    # Padding mask.
    fused_hd256_padding=ModelConfig(2, 512, 16, 256, attn_mask_type="padding"),
    # Sliding window: only permitted in combination with a causal mask on the
    # D=256 bprop kernel.
    fused_hd256_causal_swa=ModelConfig(
        2,
        1024,
        16,
        256,
        attn_mask_type="causal",
        window_size=(128, 0),
    ),
    # Grouped-query attention (num_gqa_groups < num_heads).
    fused_hd256_padding_causal_gqa=ModelConfig(
        2,
        1024,
        16,
        256,
        num_gqa_groups=4,
        attn_mask_type="padding_causal",
    ),
)


# Gate 1: the D=256 backward path needs cuDNN 9.23 or newer.
@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.")
# Gate 2: only compute capabilities (10, 0) and (10, 3) are exercised here.
@pytest.mark.skipif(
    device_compute_capability not in ((10, 0), (10, 3)),
    reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.",
)
# One test case per dtype in param_types and per key of model_configs_fused_hdim256;
# the full dict is passed through unchanged as `model_configs`.
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256])
@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys())
def test_dpa_fused_attn_hdim256(dtype, model_configs, model):
    """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell.

    Delegates to ``test_dot_product_attention`` for the configuration named
    ``model`` in ``model_configs``. The test is skipped unless the cuDNN
    version is at least 9.23 and the device compute capability is (10, 0)
    or (10, 3).
    """
    # NOTE(review): the trailing positional flags (False, True, None, False, False)
    # are forwarded as-is; their meaning is defined by test_dot_product_attention,
    # which is not visible here -- confirm against its signature before changing.
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_fa4_mla = {
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
Expand Down
17 changes: 16 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
cudnn_runtime_version >= 91100) ||
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged
(head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
sm_arch_ < 110 && cudnn_runtime_version >= 92300 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
// The FE forces this path onto the deterministic bprop algorithm, which on
// Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only).
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX &&
// Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks.
((window_size_left == -1 && window_size_right == -1) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(window_size_right == -1 || window_size_right == 0))))) &&
// 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
// Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
(!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
Expand Down
Loading