diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..af118f3a85 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -394,6 +394,7 @@ def _get_max_segments_per_sequence(self): def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available + # TDOD(KshitijLakhani): probably add/move this to is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -417,11 +418,59 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + compute_capability = get_device_compute_capability(0) + cudnn_version = get_cudnn_version() + # D=256 bprop on SM10x (cuDNN FE 1.24 / BE 9.23+) uses the deterministic algorithm path only, + # which rejects dBias, dropout, and ALiBi. It supports vanilla type of softmax only and allows SWA + # together with a causal mask only. + is_sm10x = 100 <= compute_capability < 110 + if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256): + if self.head_dim_qk != 256 or self.head_dim_v != 256: + pytest.skip( + "D=256 BWD on Blackwell only supports d_qk == d_v == 256;" + f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}." + ) + if cudnn_version < 92300: + pytest.skip( + "D=256 BWD on Blackwell requires cuDNN 9.23 or newer;" + f" got cuDNN {cudnn_version}." + ) + # TODO(KshitijLakhani): cuDNN FE can model bias input separately from dBias, + # but TE does not yet plumb whether dBias is requested into the common backend selector. + # Until that distinction is available, the D=256 SM10x gate requires no bias. + unsupported = None + if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: + unsupported = "pre-scale bias" + elif self.attn_bias_type != AttnBiasType.NO_BIAS: + unsupported = ( + "post-scale bias in TE's D=256 backend gate; bias-input-only" + " support needs TE to distinguish between bias input and dBias" + ) + elif self.dropout_prob != 0.0: + unsupported = "dropout" + elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: + unsupported = "non-vanilla softmax" + if unsupported is not None: + pytest.skip( + "D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD" + f" kernel which does not support {unsupported}." + ) + if self.window_size is not None and self.window_size != (-1, -1): + if not self.attn_mask_type.is_causal(): + pytest.skip( + "D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel" + " which requires window_size=(-1, -1) for non-causal masks." + ) + if self.window_size[1] not in (-1, 0): + pytest.skip( + "D=256 BWD on Blackwell only supports right window -1 or 0" + " for causal masks." + ) - if get_device_compute_capability(0) >= 100 and self.is_training: + if compute_capability >= 100 and self.is_training: if FusedAttnHelper.is_non_deterministic_allowed() and ( (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) - or get_cudnn_version() < 90700 + or cudnn_version < 90700 ): pytest.skip( "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" @@ -430,7 +479,7 @@ def _check_configs(self): if not FusedAttnHelper.is_non_deterministic_allowed() and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudnn_version() < 91801 + or cudnn_version < 91801 ): pytest.skip( "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" @@ -1474,6 +1523,36 @@ def test_backward( QKVLayout.THD_THD_THD, id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", ), + # D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel + # (cuDNN FE 1.24 / BE 9.23+). + pytest.param( + 4, + 128, + 128, + 16, + 16, + 256, + 256, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 256, + 256, + jnp.float16, + QKVLayout.THD_T2HD, + id="4-128-128-16-16-256-256-FP16-SELF-RAGGED_KV_PACKED", + marks=pytest.mark.xfail( + reason="cuDNN 9.23 D=256 BWD currently does not build a THD execution plan.", + strict=True, + ), + ), ], ) @pytest.mark.parametrize( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 401bd6f01d..59134974d4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -377,6 +377,37 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) +# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), +# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, +# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. +model_configs_fused_hdim256 = { + # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) + "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), + "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), + # SWA is allowed only together with a causal mask on the D=256 bprop kernel. + "fused_hd256_causal_swa": ModelConfig( + 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) + ), + # GQA variant (num_gqa_groups < num_heads). + "fused_hd256_padding_causal_gqa": ModelConfig( + 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") +@pytest.mark.skipif( + device_compute_capability not in ((10, 0), (10, 3)), + reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", +) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) +@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) +def test_dpa_fused_attn_hdim256(dtype, model_configs, model): + """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) + + model_configs_fa4_mla = { # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..e23e0dd17c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -327,7 +327,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && + cudnn_runtime_version >= 91100) || + // 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged + (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && + sm_arch_ < 110 && cudnn_runtime_version >= 92300 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && + // The FE forces this path onto the deterministic bprop algorithm, which on + // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && + // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. + ((window_size_left == -1 && window_size_right == -1) || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + (window_size_right == -1 || window_size_right == 0))))) && // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&