
[diffusion] feat: implement VAE parallel decode for Wan#16510

Closed
Songrui625 wants to merge 22 commits into sgl-project:main from Songrui625:vae-parallel-decode

Conversation

@Songrui625

@Songrui625 Songrui625 commented Jan 5, 2026

Motivation

Resolves #13191

Generating long or high-resolution videos demands more time and a larger VRAM footprint during VAE decoding.

This PR implements VAE parallel decode for Wan, which accelerates decoding and reduces peak VRAM usage during VAE decoding when using multiple GPUs.

Modifications

Basic Idea

The basic idea is that when performing convolutions across multiple GPUs, we follow the procedure below:

  1. Split the latents from the denoising stage along the height dimension.
  2. Perform a halo (ghost cell) exchange via P2P communication, so that rank N and rank N+1 share the rows at their common boundary.
  3. Perform an all-gather operation to assemble the complete output.
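The three steps above can be simulated in a single process to make the data flow concrete. The following is an illustrative sketch (the function name and shapes are mine, not the PR's API): halo rows are borrowed by slicing neighboring shards, whereas the real implementation moves them between ranks with P2P send/recv.

```python
import torch
import torch.nn.functional as F

def parallel_conv2d(x, weight, world_size=4, halo=1):
    # x: (N, C, H, W); split along the height dimension (step 1).
    shards = list(torch.chunk(x, world_size, dim=2))
    outs = []
    for rank, shard in enumerate(shards):
        # Halo exchange (step 2): borrow `halo` boundary rows from the
        # neighbors. Here this is plain slicing; distributed ranks would
        # use P2P send/recv instead.
        top = shards[rank - 1][:, :, -halo:, :] if rank > 0 else None
        bot = shards[rank + 1][:, :, :halo, :] if rank < world_size - 1 else None
        pieces = [p for p in (top, shard, bot) if p is not None]
        padded = torch.cat(pieces, dim=2)
        # Zero-pad only the global top/bottom edges; interior edges
        # already carry real halo data from the neighbors.
        pad_top = halo if rank == 0 else 0
        pad_bot = halo if rank == world_size - 1 else 0
        padded = F.pad(padded, (halo, halo, pad_top, pad_bot))
        outs.append(F.conv2d(padded, weight))
    # All-gather (step 3): concatenate the per-rank outputs along height.
    return torch.cat(outs, dim=2)

# Sanity check against an ordinary padded convolution over the full input.
x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 3, 3)
ref = F.conv2d(F.pad(x, (1, 1, 1, 1)), w)
assert torch.allclose(ref, parallel_conv2d(x, w), atol=1e-4)
```

Because each rank convolves exactly the same rows the full convolution would see, the sharded result matches the reference up to floating-point rounding.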

Implementation Details

  • Implement WanDistCausalConv3d and WanDistConv2d, which perform a halo exchange to share boundary data across GPUs before the actual convolution.
  • When decoding each frame of latents, we first split the latents along the height dimension, proceed with decoding as usual, and finally all-gather the outputs.
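The WanDist* modules need a halo-exchange helper before convolving. A rough sketch using torch.distributed P2P ops might look like the following; the function name matches the reviewed diff, but the transport details and tensor layout are my assumptions, and it degrades to a no-op when no process group is initialized:

```python
import torch
import torch.distributed as dist

def halo_exchange(x: torch.Tensor, height_halo_size: int = 1) -> torch.Tensor:
    """Exchange boundary rows with neighboring ranks along the height dim.

    Hypothetical sketch; the PR's actual implementation may differ.
    x is assumed to be (..., H, W), so height is dim -2.
    """
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return x  # single-process fallback: nothing to exchange
    rank, world = dist.get_rank(), dist.get_world_size()
    h = height_halo_size
    recv_top = torch.empty_like(x[..., :h, :])
    recv_bot = torch.empty_like(x[..., -h:, :])
    ops = []
    if rank > 0:  # exchange boundary rows with the rank above
        ops.append(dist.P2POp(dist.isend, x[..., :h, :].contiguous(), rank - 1))
        ops.append(dist.P2POp(dist.irecv, recv_top, rank - 1))
    if rank < world - 1:  # exchange boundary rows with the rank below
        ops.append(dist.P2POp(dist.isend, x[..., -h:, :].contiguous(), rank + 1))
        ops.append(dist.P2POp(dist.irecv, recv_bot, rank + 1))
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    parts = ([recv_top] if rank > 0 else []) + [x] + ([recv_bot] if rank < world - 1 else [])
    return torch.cat(parts, dim=-2)
```

After the exchange, each shard is `height_halo_size` rows taller on each interior edge, so the subsequent convolution produces exactly the rows that rank owns.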

Generated Videos

  • Wan2.2 720p video on a single H20. VAE decoding time: 26.2795s.
uv run sglang generate --model-path /data00/models/Wan-AI/Wan2.2-T2V-A14B-Diffusers --height 720 --width 1280 --seed 1024 --attention-backend sage_attn --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
sgl_wan22_vae_1gpu.mp4
  • Wan2.2 720p video with VAE parallel decode on 4 * H20. VAE decoding time: 9.1224s.
uv run sglang generate --model-path /data00/models/Wan-AI/Wan2.2-T2V-A14B-Diffusers --height 720 --width 1280 --seed 1024 --num-gpus 4 --ulysses-degree 4 --vae-config.use-parallel-decode --attention-backend sage_attn --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
sgl_wan22_vae_4gpus.mp4

Accuracy Tests

Benchmarking and Profiling

The benchmark baseline generates a 1280x720, 81-frame video from Wan2.1-T2V-14B-Diffusers with a single inference step.

| Number of GPUs | Max VRAM peak (MB) | VAE decode VRAM peak (GB) | VAE decoding time (s) |
| --- | --- | --- | --- |
| 1 * H20 | 51797.58 | 23.4 | 16.8490 |
| 2 * H20 | 39755.49 | 11.79 | 9.6863 |
| 4 * H20 | 35320.15 | 5.9 | 6.1343 |
| 8 * H20 | 35320.15 | 3.1 | 4.7050 |

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@github-actions github-actions Bot added the diffusion SGLang Diffusion label Jan 5, 2026
@Songrui625 Songrui625 changed the title from "[diffusion] feat: implement VAE parallel decode" to "[diffusion] feat: implement VAE parallel decode for Wan" on Jan 5, 2026
```python
    )  # casting needed for mps since amp isn't supported
    return super().forward(x)

def halo_exchange(x: torch.Tensor, height_halo_size: int = 1) -> torch.Tensor:
```
Collaborator


How should we generalize this to other VAEs?

Author


I haven't checked the VAE code of HunyuanVideo yet. Would you mind if I turn this into a common function in a follow-up PR, e.g. when adding VAE parallel decode support for HunyuanVideo?

Comment thread: python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py
@zyksir
Collaborator

zyksir commented Jan 6, 2026

@Songrui625 Please add a unit test to make sure the output of the distributed version equals that of the non-distributed one.

Collaborator

@mickqian mickqian left a comment


Please add a unittest, or I assume the parallel decode should be enabled automatically for multi-gpus?

@Songrui625
Author

@Songrui625 Please add a unit test to make sure the output of the distributed version equals that of the non-distributed one.

Thanks for pointing that out! Working on it.

@Songrui625
Author

Songrui625 commented Jan 7, 2026

Please add a unittest, or I assume the parallel decode should be enabled automatically for multi-gpus?

Use the new option --vae-config.use-parallel-decode to enable.

@mickqian
Collaborator

mickqian commented Jan 7, 2026

Please add a unittest, or I assume the parallel decode should be enabled automatically for multi-gpus?

Use the new option --vae-config.use-parallel-decode to enable.

And should we turn it on by default?

@zcnrex
Contributor

zcnrex commented Jan 8, 2026

Could you help add some sample generated videos?

@triple-mu
Contributor

Hello, would you be willing to first try using my PR to simplify the VAE logic?
The current VAE computation has quite complex caching logic along the temporal dimension. My PR significantly simplifies this caching logic while remaining computationally equivalent to the original implementation.
If this PR can be merged, I believe your PR would also become much simpler.

#15068

@Songrui625
Author

Songrui625 commented Jan 12, 2026

Hello, would you be willing to first try using my PR to simplify the VAE logic? The current VAE computation has quite complex caching logic along the temporal dimension. My PR significantly simplifies this caching logic while remaining computationally equivalent to the original implementation. If this PR can be merged, I believe your PR would also become much simpler.

#15068

Thanks, I will check later.

@Songrui625 Songrui625 closed this Jan 12, 2026
@Songrui625
Author

Songrui625 commented Jan 12, 2026

Please add a unittest, or I assume the parallel decode should be enabled automatically for multi-gpus?

Use the new option --vae-config.use-parallel-decode to enable.

And should we turn it on by default?

Sorry for the late response. I think we should keep it disabled by default until it is proven stable.

@Songrui625 Songrui625 reopened this Jan 12, 2026
@Songrui625
Author

Could you help add some sample generated videos?

OK. The sample video will be provided later along with the unit tests!

@Songrui625
Author

Still handling the case where the height dimension is not divisible by the GPU count; we need to deal with the numerical errors introduced by padding.
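One common way to handle an uneven height (a sketch, not necessarily the approach this PR lands on) is to pad the height so it divides evenly across ranks, then crop the padding from the gathered output:

```python
def padded_split(height: int, world_size: int) -> tuple[int, int]:
    """Return (rows per rank, padding rows) for an even height split.

    Hypothetical helper: pad `height` up to the next multiple of
    `world_size`, so every rank receives an equally sized shard; the
    `pad` rows are cropped again after the all-gather.
    """
    pad = (-height) % world_size
    return (height + pad) // world_size, pad

# e.g. 90 rows across 4 GPUs -> 23 rows per rank, 2 rows of padding
per_rank, pad = padded_split(90, 4)
assert (per_rank, pad) == (23, 2)
```

The subtlety the comment above refers to is that padded rows can leak into the convolution through the halo exchange, so the padding must be placed (and cropped) carefully to avoid perturbing the rows near the global bottom edge.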

@mickqian
Collaborator

@Songrui625 appreciate your continuous work!

@Songrui625 Songrui625 force-pushed the vae-parallel-decode branch 2 times, most recently from 3b7eb4d to 9c2c6cf on February 4, 2026
@Songrui625
Author

@Songrui625 appreciate your continuous work!

Hi Mick, I have pushed the unit test; please review again. To be clear, the distributed version of the convolution may have some rounding error compared to the non-distributed one, so I set atol and rtol to different values depending on the data type.

As the rounding errors accumulate, the distributed version of WanDecoder3d may show a larger difference from the non-distributed one. I think this is acceptable, so I set atol and rtol to 5e-2 when the data type is torch.bfloat16. CC @zyksir

@Songrui625
Author

This PR is ready for review again.

@Songrui625
Author

I had pushed comprehensive unit tests. It's disappointing that this PR was left hanging and eventually discarded by the reviewers. Even more frustrating is the feeling that my contributions might be overlooked because I'm not an official member of the team.

@Songrui625 Songrui625 closed this Feb 10, 2026
@BBuf
Collaborator

BBuf commented Feb 11, 2026

I had pushed comprehensive unit tests. It's disappointing that this PR was left hanging and eventually discarded by the reviewers. Even more frustrating is the feeling that my contributions might be overlooked because I'm not an official member of the team.

Hi, I'm sorry about this. The situation is that we're currently working on a technical report and have incorporated the changes from this PR, which is currently blocking a version release. So, we merged that PR while giving you proper credit.

It wasn't because you're not part of the Diffusion team (you are actually). You can see that we've separately acknowledged you in this blog post: lm-sys/lm-sys.github.io#310.

8d5ff62d-766c-49a4-a077-416885c3c5a0

Regarding your tests, we can resubmit a new PR and merge it into the main branch. Thank you very much for your contribution. Once again, I apologize for any inconvenience caused.


Labels

diffusion SGLang Diffusion


Development

Successfully merging this pull request may close these issues.

diffusion, parallelism: VAE Decode Parallel

6 participants