
[WIP] fix(cp): wrap linear attention CP in custom autograd.Function#1692

Closed
lilei199908 wants to merge 1 commit into THUDM:main from lilei199908:fix/cp-linear-attn-autograd

Conversation

@lilei199908 (Collaborator)

Replace the dist.nn.all_gather-based CP handling for Qwen3.5 linear attention with a custom CPLinearAttnFunction that:

  1. Frees all-gathered tensors after forward (saves memory proportional to cp_size * local_seq_len * hidden_dim)
  2. Re-gathers and recomputes forward in backward, ensuring correct reduce-scatter of all gradients (including dk, dv)
  3. All-gathers grad_output from all CP ranks so parameter gradients are computed from the full loss

The implementation follows the same pattern as TE's AttnFuncWithCPAndKVAllGather (gather in forward, re-gather + reduce in backward) adapted for linear attention where the entire sequence must be processed as a unit.
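
The gather-in-forward / free / recompute-in-backward pattern can be sketched in a single-process form. This is an illustrative reconstruction, not the actual slime or TE code: `linear_attn`, the class name, and all shapes are placeholders, and the distributed collectives (all_gather of q/k/v and grad_output, reduce-scatter of dk/dv) are noted in comments rather than performed.

```python
import torch

def linear_attn(q, k, v):
    # Toy stand-in for the full-sequence linear attention kernel.
    return (q @ k.transpose(-1, -2)) @ v

class RecomputeLinearAttnFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # In the real CP version, q/k/v would be all-gathered across ranks
        # here before running attention over the full sequence.
        out = linear_attn(q, k, v)
        # Save only the (local) inputs; the gathered activations are freed
        # instead of being kept alive in the autograd graph.
        ctx.save_for_backward(q, k, v)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # Re-gather and recompute the forward pass, then backprop through
        # the recomputed graph. In the real CP version, grad_out would also
        # be all-gathered, and dk/dv reduce-scattered back to each rank.
        with torch.enable_grad():
            q_, k_, v_ = (t.detach().requires_grad_(True) for t in (q, k, v))
            out = linear_attn(q_, k_, v_)
            dq, dk, dv = torch.autograd.grad(out, (q_, k_, v_), grad_out)
        return dq, dk, dv
```

Because backward recomputes from saved inputs, the gradients it produces match ordinary autograd through `linear_attn`, which is the property the PR's precision test checks against a non-CP reference.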

Also overrides Attention.forward in qwen3_5.py so the base class's generic CP logic (hf_attention.py) is no longer used for this model.

Tested with torchrun --nproc_per_node={2,4,8}: forward output and backward gradients (hidden + params) match non-CP reference within 5e-5 tolerance.


Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings (March 9, 2026, 06:15)
Contributor

Copilot AI left a comment

Pull request overview

This PR replaces the dist.nn.all_gather-based context parallelism (CP) handling for Qwen3.5 linear attention with a custom CPLinearAttnFunction (torch.autograd.Function). The key motivation is that the base class HuggingfaceAttention.forward keeps the full all-gathered tensor in the autograd graph (wasting memory proportional to cp_size * local_seq_len * hidden_dim) and doesn't correctly compute all parameter gradients (dk, dv) during backward. The new approach gathers in forward, frees the gathered tensor, then re-gathers and recomputes in backward — following the same pattern as TE's AttnFuncWithCPAndKVAllGather.

Changes:

  • Added _cp_all_gather_zigzag and _cp_slice_zigzag helper functions and CPLinearAttnFunction custom autograd class for memory-efficient context-parallel linear attention
  • Overrode Attention.forward in qwen3_5.py to use CPLinearAttnFunction instead of the base class's generic CP logic
  • Added distributed tests verifying gather/slice roundtrip and forward/backward numerical precision against a non-CP reference
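
The zigzag layout that `_cp_slice_zigzag` and `_cp_all_gather_zigzag` imply can be sketched without torch.distributed by treating a list of per-rank shards as the result of an all_gather. The function bodies below are an illustrative reconstruction of the standard zigzag scheme (each rank holds chunks r and 2*cp_size-1-r of the sequence), not the PR's actual implementation:

```python
import torch

def cp_slice_zigzag(x, rank, cp_size, dim=0):
    # Split the sequence into 2*cp_size chunks; rank r keeps chunks
    # r and (2*cp_size - 1 - r), which balances causal-attention load.
    chunks = x.chunk(2 * cp_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=dim)

def cp_all_gather_zigzag(shards, cp_size, dim=0):
    # `shards` simulates the output of dist.all_gather: one zigzag shard
    # per rank. Undo the zigzag to restore the original sequence order.
    halves = [s.chunk(2, dim=dim) for s in shards]
    front = [halves[r][0] for r in range(cp_size)]
    back = [halves[r][1] for r in reversed(range(cp_size))]
    return torch.cat(front + back, dim=dim)
```

Slicing on every rank and then gathering is an exact roundtrip, which is presumably what the PR's gather/slice roundtrip test verifies.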

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

Files changed:

  • slime_plugins/models/qwen3_5.py: Added zigzag CP helpers, the CPLinearAttnFunction autograd function, and an Attention.forward override that bypasses the base class's CP logic
  • tests/test_cp_linear_attn_precision.py: New multi-GPU test verifying gather/slice roundtrip exactness and forward/backward precision within 5e-5 tolerance
