diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index bd0ac41974..735d051a70 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -1817,11 +1817,19 @@ def get_model(dtype, config):
 @pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
 @pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+@pytest.mark.parametrize("deterministic", [True, False])
 def test_mha_fp8_vs_f16(
-    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
+    dtype,
+    model,
+    qkv_format,
+    input_layernorm,
+    fp8_dpa_bwd,
+    RoPE,
+    is_training,
+    scaling_mode,
+    deterministic,
 ):
     """Test MultiHeadAttention module in FP8"""
-    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
 
     config = model_configs_fp8_vs_f16[model]
@@ -1850,7 +1858,7 @@ def test_mha_fp8_vs_f16(
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
-        deterministic=_deterministic,
+        deterministic=deterministic,
     )
     flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported_fp8 < 1:
@@ -1862,7 +1870,7 @@ def test_mha_fp8_vs_f16(
         qkv_dtype=dtype,
         qkv_layout=qkv_format.replace("hd", "h3d"),
         is_training=is_training,
-        deterministic=_deterministic,
+        deterministic=deterministic,
     )
     _, fused_attn_supported_f16, _ = available_backends
     if not fused_attn_supported_f16:
@@ -2063,7 +2071,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
 @pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
+@pytest.mark.parametrize("deterministic", [True, False])
+def test_dpa_fp8_vs_f16(
+    dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode, deterministic
+):
     """Test DotProductAttention module in FP8"""
 
     config = model_configs_fp8_vs_f16[model]
@@ -2078,7 +2089,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     # config.dropout_p = 0.1
 
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
-    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
 
     # Test backend availability
@@ -2104,7 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
-        deterministic=_deterministic,
+        deterministic=deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported < 1:
@@ -2115,7 +2125,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
-        deterministic=_deterministic,
+        deterministic=deterministic,
    )
     _, fused_attn_supported, _ = available_backends
     if not fused_attn_supported:
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 4f8367aac7..b5679280c6 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -770,10 +770,10 @@ void nvte_fused_attn_bwd_qkvpacked(
     Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride);
 
     fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout,
-                       bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO,
-                       input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view,
-                       input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream,
-                       handle);
+                       bias_type, attn_mask_type, deterministic, &Q_view, &K_view, &V_view, input_O,
+                       input_dO, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view,
+                       &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace,
+                       stream, handle);
 #else
     NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
 #endif
@@ -1087,10 +1087,10 @@ void nvte_fused_attn_bwd_kvpacked(
     Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride);
 
     fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
-                       qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O,
-                       input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view,
-                       &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace,
-                       stream, handle);
+                       qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, &K_view,
+                       &V_view, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
+                       output_dQ, &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv,
+                       input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
 #endif
@@ -1323,9 +1323,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
-                       qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
-                       input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
-                       output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
+                       qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K,
+                       input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP,
+                       output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
                        input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index f886ec77f4..d885d23a85 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1978,13 +1978,13 @@ void fused_attn_fp8_fwd_impl_v1(
 void fused_attn_fp8_bwd_impl_v1(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor,
     float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
-    void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV,
-    void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO,
-    void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS,
-    void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
-    void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
-    void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
+    NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV,
+    void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK,
+    void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
+    void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP,
+    void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK,
+    void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK,
+    void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
     void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
     cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
     cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
@@ -1999,6 +1999,7 @@ void fused_attn_fp8_bwd_impl_v1(
   bool is_dropout = (dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  const auto cudnn_runtime_version = cudnnGetVersion();
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2037,7 +2038,7 @@ void fused_attn_fp8_bwd_impl_v1(
                              0,
                              0,
                              true,
-                             false,
+                             deterministic,
                              qkv_tensor_type,
                              o_tensor_type,
                              do_tensor_type,
@@ -2209,6 +2210,10 @@ void fused_attn_fp8_bwd_impl_v1(
     // }
     // }
 
+    if (cudnn_runtime_version >= 91900) {
+      sdpa_backward_options.set_deterministic_algorithm(deterministic);
+    }
+
     if (is_padding) {
       seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
                                     .set_name("seq_q")
@@ -2512,11 +2517,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
 void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                         size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
                         float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                        NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q,
-                        const Tensor* input_K, const Tensor* input_V, const Tensor* input_O,
-                        const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv,
-                        const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ,
-                        const Tensor* output_dK, const Tensor* output_dV,
+                        NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
+                        const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V,
+                        const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
+                        const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
+                        const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV,
                         const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv,
                         const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
                         cudnnHandle_t handle) {
@@ -2574,11 +2579,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
   if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
     fused_attn::fused_attn_fp8_bwd_impl_v1(
         batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
-        p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
-        devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
-        devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
-        devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
-        devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
+        p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV,
+        devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ,
+        devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS,
+        devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV,
+        devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
         devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
         get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
         workspace->data.dptr, &workspace_size, stream, handle);
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h
index a1a932fdf5..225e700eff 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.h
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h
@@ -28,11 +28,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
 void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                         size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
                         float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                        NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q,
-                        const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
-                        const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv,
-                        const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ,
-                        const Tensor *output_dK, const Tensor *output_dV,
+                        NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic,
+                        const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+                        const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M,
+                        const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP,
+                        const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV,
                         const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
                         const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
                         cudnnHandle_t handle);
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 56e6f093d1..135a1b354c 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -1064,8 +1064,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             )
             use_fused_attention = False
             fused_attention_backend = None
-        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
-            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
+        if (
+            fused_attention_backend == FusedAttnBackend["FP8"]
+            and is_training
+            and device_compute_capability < (10, 0)
+        ):
+            logger.debug(
+                "Disabling FusedAttention for determinism reasons with FP8 on arch < sm100"
+            )
             use_fused_attention = False
             fused_attention_backend = None
         if (
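For reviewers who want to exercise the new path end to end, below is a minimal, hypothetical usage sketch; it is not part of the patch. It assumes an sm100+ GPU, a cuDNN build new enough for the `set_deterministic_algorithm` path added here, and that setting `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` is how the `deterministic` flag reaches backend selection; the FP8 recipe settings may need adjusting for your setup.

```python
# Hypothetical sketch (not part of this patch): request deterministic FP8 fused attention.
import os

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# "0" => only deterministic algorithms; with this patch the flag is forwarded
# down to the FP8 cuDNN backward graph instead of disabling FusedAttention.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64).cuda()

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# fp8_dpa=True enables FP8 dot-product attention under fp8_autocast.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_dpa=True)):
    out = dpa(q, k, v)
out.backward(torch.ones_like(out))  # backward should now be reproducible run-to-run
```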