[WIP] fix(cp): wrap linear attention CP in custom autograd.Function #1692
lilei199908 wants to merge 1 commit into THUDM:main
Conversation
Replace dist.nn.all_gather-based CP handling for Qwen3.5 linear attention
with a custom CPLinearAttnFunction that:
1. Frees all-gathered tensors after forward (saves memory proportional to
cp_size * local_seq_len * hidden_dim)
2. Re-gathers and recomputes forward in backward, ensuring correct
reduce-scatter of all gradients (including dk, dv)
3. All-gathers grad_output from all CP ranks so parameter gradients are
computed from the full loss
The implementation follows the same pattern as TE's AttnFuncWithCPAndKVAllGather
(gather in forward, re-gather + reduce in backward) adapted for linear attention
where the entire sequence must be processed as a unit.
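The gather-free-recompute pattern described above can be sketched as a single-process simulation (all names here are hypothetical, not the PR's actual code): CP ranks are modeled as chunks of one long sequence, and `attn_fn` stands in for the linear-attention kernel that must see the whole sequence at once. Forward gathers, computes, and keeps only the small per-rank shards; backward re-gathers, recomputes under `enable_grad`, applies the (zero-padded, since only the local slice is available here) output gradient, and returns the local input gradient.

```python
import torch

class CPLinearAttnSketch(torch.autograd.Function):
    """Illustrative sketch of the gather/recompute pattern; the real
    CPLinearAttnFunction uses dist collectives across GPUs instead."""

    @staticmethod
    def forward(ctx, x_local, other_shards, rank, attn_fn):
        shards = list(other_shards)
        shards.insert(rank, x_local)               # simulated all-gather
        x_full = torch.cat(shards, dim=0)
        out_full = attn_fn(x_full)                 # full-sequence kernel
        seq = x_local.shape[0]
        # Save only the small shards; x_full / out_full are freed here,
        # which is the memory saving the PR describes.
        ctx.save_for_backward(x_local, *other_shards)
        ctx.rank, ctx.attn_fn, ctx.seq = rank, attn_fn, seq
        return out_full[rank * seq:(rank + 1) * seq].clone()

    @staticmethod
    def backward(ctx, grad_out_local):
        x_local, *other_shards = ctx.saved_tensors
        rank, attn_fn, seq = ctx.rank, ctx.attn_fn, ctx.seq
        with torch.enable_grad():
            x_local = x_local.detach().requires_grad_(True)
            shards = [s.detach() for s in other_shards]
            shards.insert(rank, x_local)
            x_full = torch.cat(shards, dim=0)      # re-gather
            out_full = attn_fn(x_full)             # recompute forward
        # The real implementation all-gathers grad_output from every CP
        # rank; in this one-rank sketch the other slices are zero.
        grad_full = torch.zeros_like(out_full)
        grad_full[rank * seq:(rank + 1) * seq] = grad_out_local
        out_full.backward(grad_full)
        return x_local.grad, None, None, None
```

Because backward rebuilds the graph from the saved shards, nothing proportional to the gathered sequence survives between forward and backward, at the cost of one extra forward recomputation.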
Also overrides Attention.forward in qwen3_5.py so the base class's generic
CP logic (hf_attention.py) is no longer used for this model.
Tested with torchrun --nproc_per_node={2,4,8}: forward output and backward
gradients (hidden + params) match non-CP reference within 5e-5 tolerance.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR replaces the dist.nn.all_gather-based context parallelism (CP) handling for Qwen3.5 linear attention with a custom CPLinearAttnFunction (torch.autograd.Function). The key motivation is that the base class HuggingfaceAttention.forward keeps the full all-gathered tensor in the autograd graph (wasting memory proportional to cp_size * local_seq_len * hidden_dim) and doesn't correctly compute all parameter gradients (dk, dv) during backward. The new approach gathers in forward, frees the gathered tensor, then re-gathers and recomputes in backward — following the same pattern as TE's AttnFuncWithCPAndKVAllGather.
Changes:
- Added `_cp_all_gather_zigzag` and `_cp_slice_zigzag` helper functions and the `CPLinearAttnFunction` custom autograd class for memory-efficient context-parallel linear attention
- Overrode `Attention.forward` in `qwen3_5.py` to use `CPLinearAttnFunction` instead of the base class's generic CP logic
- Added distributed tests verifying the gather/slice roundtrip and forward/backward numerical precision against a non-CP reference
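The zigzag slice/gather pair can be illustrated with a single-process stand-in (the signatures below are assumptions; the PR's helpers operate over a real CP process group with `dist.all_gather`). The sequence is split into `2 * cp_size` equal chunks and rank `r` keeps chunks `r` and `2*cp_size - 1 - r`, the standard zigzag layout that balances causal-attention load; the gather reassembles the original order, so slice-then-gather is an exact roundtrip.

```python
import torch

def cp_slice_zigzag(x_full, rank, cp_size, dim=0):
    """Rank r keeps chunks r and (2*cp_size - 1 - r) out of 2*cp_size
    equal chunks along `dim` (zigzag layout)."""
    chunks = x_full.chunk(2 * cp_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=dim)

def cp_all_gather_zigzag(local_shards, cp_size, dim=0):
    """Single-process stand-in for the all-gather: given every rank's
    zigzag shard, reassemble the original sequence order."""
    chunks = [None] * (2 * cp_size)
    for rank, shard in enumerate(local_shards):
        first, second = shard.chunk(2, dim=dim)
        chunks[rank] = first
        chunks[2 * cp_size - 1 - rank] = second
    return torch.cat(chunks, dim=dim)
```

The roundtrip is bit-exact (pure slicing and concatenation, no arithmetic), which is why the PR's test can check it with equality rather than a tolerance.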
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| slime_plugins/models/qwen3_5.py | Added zigzag CP helpers, CPLinearAttnFunction autograd function, and overrode Attention.forward to bypass the base class's CP logic |
| tests/test_cp_linear_attn_precision.py | New multi-GPU test verifying gather/slice roundtrip exactness and forward/backward precision within 5e-5 tolerance |