Conversation

Commits:
- …_attn_pad_bw_seqs (Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>)
- pre-commit.ci auto-fix (for more information, see https://pre-commit.ci)
- …ransformerEngine into flash_attn_pad_bw_seqs (Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>)
Greptile Summary

Enables the FlashAttention 3 backend for the THD format with padding between sequences (`pad_between_seqs=True`).

Key Changes: see the "Changes" list in the description below.

Verification Needed: check that FA3 with `seqused_q`/`seqused_k` leaves padded positions untouched (see the review comment below). Confidence Score: 4/5

Important Files Changed: `dpa/dot_product_attention.py`, `backends.py`, `dpa/utils.py`, `tests/pytorch/attention/test_attention.py`
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant DotProductAttention
    participant Utils as get_attention_backend
    participant FlashAttention as FlashAttention Backend
    participant FlashAttn3 as flash_attn_3 Library
    User->>DotProductAttention: forward(q, k, v, pad_between_seqs=True, cu_seqlens_q_padded, cu_seqlens_kv_padded)
    DotProductAttention->>Utils: get_attention_backend(qkv_format="thd", pad_between_seqs=True)
    alt FlashAttention 3 installed
        Utils-->>DotProductAttention: use_flash_attention=True
    else FlashAttention 2 only
        Utils-->>DotProductAttention: use_flash_attention=False
    end
    DotProductAttention->>FlashAttention: forward(q, k, v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True)
    FlashAttention->>FlashAttention: Use cu_seqlens_q_padded instead of cu_seqlens_q
    FlashAttention->>FlashAttention: Calculate seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    FlashAttention->>FlashAttention: Calculate seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    FlashAttention->>FlashAttn3: flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q_padded, cu_seqlens_kv_padded, seqused_q, seqused_k)
    FlashAttn3-->>FlashAttention: output (with padding handled internally)
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output
```
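For intuition, here is a small standalone sketch (not code from the PR; the tensor values are made up) of how the per-sequence token counts passed as `seqused_q` are derived from `cu_seqlens_q`, and how they differ from the padded offsets `cu_seqlens_q_padded` in a THD layout with padding between sequences:

```python
import torch

# Illustrative THD batch: 3 sequences of 5, 3 and 7 real tokens, with each
# sequence slot padded out to a multiple of 4 tokens (values are made up).
cu_seqlens_q = torch.tensor([0, 5, 8, 15], dtype=torch.int32)          # offsets of real tokens
cu_seqlens_q_padded = torch.tensor([0, 8, 12, 20], dtype=torch.int32)  # offsets including padding

# Number of real tokens per sequence; this is what the PR computes and passes
# to flash_attn_3 as `seqused_q` (and analogously `seqused_k`).
seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
print(seqused_q)  # tensor([5, 3, 7], dtype=torch.int32)

# FA3 is then called with the *padded* offsets for addressing tokens in memory
# and with seqused_q/seqused_k to limit attention to the real tokens, roughly:
#   flash_attn_varlen_func(q, k, v,
#                          cu_seqlens_q=cu_seqlens_q_padded,
#                          cu_seqlens_k=cu_seqlens_kv_padded,
#                          seqused_q=seqused_q, seqused_k=seqused_k, ...)
```

With only FlashAttention 2 installed there is no `seqused_q`/`seqused_k` path, which is why `get_attention_backend` returns `use_flash_attention=False` in the diagram above.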
Reviewed hunk in `backends.py`:

```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
```
style: verify that flash_attn_3 with `seqused_q`/`seqused_k` truly avoids writing to padding positions. The related issue #2391 mentions that "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for the THD format). If flash_attn_3 doesn't zero these positions internally, the output may contain garbage values at padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
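If flash_attn_3 turns out not to zero these positions internally, one possible fallback (an untested sketch with a hypothetical helper name, not code from this PR) is to zero the padding slots of the THD output on the Python side using the two cumulative-length tensors, mirroring what the comment says FusedAttention does in C++:

```python
import torch

def zero_pad_between_seqs(out: torch.Tensor,
                          cu_seqlens: torch.Tensor,
                          cu_seqlens_padded: torch.Tensor) -> torch.Tensor:
    """Zero the padding slots of a THD tensor of shape [total_tokens_padded, h, d].

    For sequence i, tokens in [cu_seqlens_padded[i] + used_i, cu_seqlens_padded[i + 1])
    are padding, where used_i = cu_seqlens[i + 1] - cu_seqlens[i].
    """
    seqused = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    starts = cu_seqlens_padded[:-1].tolist()
    ends = cu_seqlens_padded[1:].tolist()
    for used, start, end in zip(seqused, starts, ends):
        out[start + used : end] = 0
    return out
```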
Description
With the THD format and FlashAttention 3, TE should support `pad_between_seqs=True`. (Right now this works for non-CP cases; enabling it with context parallelism (CP) is in progress.)

Fixes #2399
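For context, a rough caller-side sketch of the new path. The shapes, mask type, and keyword names are my assumptions about TE's `DotProductAttention` API, and `pad_between_seqs` is the argument this PR plumbs through; treat it as illustrative, not verified:

```python
import torch
import transformer_engine.pytorch as te

# Illustrative THD inputs: 2 sequences of 6 and 10 real tokens,
# padded in memory to 8 and 12 tokens respectively.
cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32, device="cuda")
cu_seqlens_padded = torch.tensor([0, 8, 20], dtype=torch.int32, device="cuda")
total_tokens = int(cu_seqlens_padded[-1])

num_heads, head_dim = 16, 64
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

dpa = te.DotProductAttention(num_heads, head_dim, qkv_format="thd")
out = dpa(
    q, k, v,
    qkv_format="thd",
    attn_mask_type="padding_causal",  # THD inputs typically use a padding mask type
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    cu_seqlens_q_padded=cu_seqlens_padded,
    cu_seqlens_kv_padded=cu_seqlens_padded,
    max_seqlen_q=10,
    max_seqlen_kv=10,
    pad_between_seqs=True,  # assumed argument enabled for the FA3 backend by this PR
)
```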
Type of change
Changes
Please list the changes introduced in this PR:
- `dpa/dot_product_attention.py::DotProductAttention->forward`: plumb through `pad_between_seqs`, `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
- `backends.py::FlashAttention->forward`: use `pad_between_seqs`, `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`; compute `seqused_q`/`seqused_k` before calling `flash_attn_varlen_func` from `flash_attn_3`
- `dpa/utils.py::get_attention_backend`: switch `use_flash_attention` to `True` when `flash_attn_3` is installed (see the sketch after the checklist heading below)
- `tests/pytorch/attention/test_attention.py::_run_dot_product_attention`: run FlashAttention for THD with `pad_between_seqs=True`

Checklist:
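A hedged sketch of the backend-gating idea from the `dpa/utils.py` bullet above (not TE's actual `get_attention_backend` code; the function and parameter names are illustrative):

```python
def select_flash_attention(qkv_format: str,
                           pad_between_seqs: bool,
                           flash_attn_3_is_installed: bool) -> bool:
    """Return True if the FlashAttention backend may be used for this input."""
    if qkv_format == "thd" and pad_between_seqs:
        # Only FA3 accepts seqused_q/seqused_k alongside padded cu_seqlens offsets,
        # so an FA2-only install has to fall back to a different backend here.
        return flash_attn_3_is_installed
    return True
```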