[diffusion] feat: implement VAE parallel decode for Wan #16510
Songrui625 wants to merge 22 commits into sgl-project:main
Conversation
```python
        )  # casting needed for mps since amp isn't supported
        return super().forward(x)
```

```python
def halo_exchange(x: torch.Tensor, height_halo_size: int = 1) -> torch.Tensor:
```
How should we generalize this to other VAEs?
I haven't checked the VAE code of HunyuanVideo. Would you mind if I make it a common function in the next PR, e.g. when supporting VAE parallel decode for HunyuanVideo?
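(For readers following the thread: a height-dimension halo exchange of the shape discussed here could look roughly like the sketch below. This is a minimal sketch, assuming `torch.distributed` is already initialized and the input is a `[..., H, W]` shard split along height; it is illustrative, not the PR's actual implementation.)

```python
import torch
import torch.distributed as dist


def halo_exchange(x: torch.Tensor, height_halo_size: int = 1) -> torch.Tensor:
    """Exchange boundary rows along H with adjacent ranks (illustrative sketch)."""
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Our topmost rows go up to rank - 1; our bottommost rows go down to rank + 1.
    send_up = x[..., :height_halo_size, :].contiguous()
    send_down = x[..., -height_halo_size:, :].contiguous()
    recv_from_up = torch.empty_like(send_up)
    recv_from_down = torch.empty_like(send_down)

    ops = []
    if rank > 0:
        ops.append(dist.P2POp(dist.isend, send_up, rank - 1))
        ops.append(dist.P2POp(dist.irecv, recv_from_up, rank - 1))
    if rank < world_size - 1:
        ops.append(dist.P2POp(dist.isend, send_down, rank + 1))
        ops.append(dist.P2POp(dist.irecv, recv_from_down, rank + 1))
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()

    # Stitch the received halo rows onto the local shard so a subsequent
    # convolution sees the same height context as in the unsplit tensor.
    parts = ([recv_from_up] if rank > 0 else []) + [x] + (
        [recv_from_down] if rank < world_size - 1 else []
    )
    return torch.cat(parts, dim=-2)
```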
@Songrui625 please add a unit test to make sure the output of the distributed version is equal to the non-distributed one.
mickqian left a comment:
Please add a unit test. Also, I assume the parallel decode should be enabled automatically for multi-GPU runs?
Thanks for pointing that out! Working on it.
Use the new option `--vae-config.use-parallel-decode`.
And should we turn it on by default?
Could you help add some sample generated videos?
Hello, would you be willing to first try using my PR to simplify the VAE logic?
Thanks, I will check later.
Sorry for the late response. I think we could keep it disabled by default until it is proven stable.
OK. The sample video will be provided later along with the unit tests!
Force-pushed from 5fbbac8 to 2a880d8
Force-pushed from 2a880d8 to ceb0195
Still handling the padding problem when the height dimension is not divisible by the GPU count; the errors caused by padding need to be dealt with.
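(One common way to handle a non-divisible height, sketched under the assumption of a 4D `[N, C, H, W]` tensor: pad the height up to a multiple of the GPU count before splitting, then crop the padded rows off after the decoded shards are gathered. The helper below is hypothetical, not necessarily how this PR will solve it.)

```python
import torch
import torch.nn.functional as F


def pad_and_split_height(x: torch.Tensor, rank: int, world_size: int):
    """Pad H to a multiple of world_size; return this rank's shard and the pad size."""
    pad = (-x.shape[-2]) % world_size  # rows needed to reach the next multiple
    if pad:
        # Replicate-pad the bottom of the height dimension so the extra rows
        # carry plausible content instead of zeros.
        x = F.pad(x, (0, 0, 0, pad), mode="replicate")
    shard = x.chunk(world_size, dim=-2)[rank]
    # The caller keeps `pad` and crops the bottom rows off after gathering.
    return shard, pad
```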
@Songrui625 appreciate your continuous work!
Force-pushed from 3b7eb4d to 9c2c6cf
Hi mick, I have pushed the unit test. Please review again. To be clear, the distributed version of the convolution may have some rounding errors compared to the non-distributed one, so I set the tolerance accordingly. With the rounding errors accumulating, the distributed version of WanDecoder3d may show a larger difference from the non-distributed one. I think this is acceptable, so the tolerance for the decoder test is set more loosely.
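(The tolerance-based comparison described here might look roughly like the following. The function name and the rtol/atol values are placeholders, not the actual values used in the unit test.)

```python
import torch


def assert_parallel_matches_single(out_parallel: torch.Tensor,
                                   out_single: torch.Tensor,
                                   rtol: float = 1e-3,
                                   atol: float = 1e-3) -> None:
    # A single distributed conv introduces only tiny rounding error, but a full
    # decoder stacks many of them, so the end-to-end check needs a looser
    # tolerance than a per-layer check would.
    torch.testing.assert_close(out_parallel, out_single, rtol=rtol, atol=atol)
```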
Signed-off-by: Songrui625 <songrui625@gmail.com>
Force-pushed from 9c2c6cf to 49de9c0
This PR is ready for review again.
I have pushed comprehensive unit tests. It's disappointing that this PR was left hanging and eventually discarded by the reviewers. Even more frustrating is the feeling that my contributions might be overlooked because I'm not an official member of the team.
Hi, I'm sorry about this. The situation is that we're currently working on a technical report and have incorporated the changes from this PR, which is currently blocking a version release. So, we merged that PR while giving you proper credit. It wasn't because you're not part of the Diffusion team (you are actually). You can see that we've separately acknowledged you in this blog post: lm-sys/lm-sys.github.io#310.
Regarding your tests, we can resubmit a new PR and merge it into the main branch. Thank you very much for your contribution. Once again, I apologize for any inconvenience caused.

Motivation
Resolves #13191
Generating long or high-resolution videos demands more time and a larger VRAM footprint during VAE decoding.
This pull request implements VAE parallel decode for Wan, which accelerates decoding and reduces peak VRAM usage during VAE decoding when using multiple GPUs.
Modifications
Basic Idea
The basic idea is that, when doing convolutions across multiple GPUs, we follow the procedure below:
1. Split the input along the height dimension across GPUs.
2. Perform a halo exchange of the boundary rows between Rank N and Rank N+1, so each rank has the overlapping context its convolution kernel needs.
3. Run the convolution locally on each rank.
Implementation Detail
Added WanDistCausalConv3d and WanDistConv2d, which perform halo exchange to share data across GPUs before the actual convolution (a rough sketch of this pattern follows).
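The sketch below reuses the `halo_exchange` shape from the review thread above. The class name and the padding handling are illustrative; the PR's real WanDistConv2d/WanDistCausalConv3d will differ in detail:

```python
import torch
import torch.nn as nn


class DistHeightSplitConv2d(nn.Conv2d):
    """Illustrative: a Conv2d whose input is sharded along H across ranks."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fetch kernel_size[0] // 2 boundary rows from each neighboring rank so
        # the kernel sees the same height context as on the unsplit tensor.
        halo = self.kernel_size[0] // 2
        if halo > 0:
            x = halo_exchange(x, height_halo_size=halo)  # from the sketch above
        # Then convolve the locally held shard as usual. A real implementation
        # must also suppress the conv's own height padding at interior shard
        # boundaries, which this sketch glosses over.
        return super().forward(x)
```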
Generated Videos
```bash
uv run sglang generate --model-path /data00/models/Wan-AI/Wan2.2-T2V-A14B-Diffusers --height 720 --width 1280 --seed 1024 --attention-backend sage_attn --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
```
sgl_wan22_vae_1gpu.mp4
```bash
uv run sglang generate --model-path /data00/models/Wan-AI/Wan2.2-T2V-A14B-Diffusers --height 720 --width 1280 --seed 1024 --num-gpus 4 --ulysses-degree 4 --vae-config.use-parallel-decode --attention-backend sage_attn --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
```
sgl_wan22_vae_4gpus.mp4
Accuracy Tests
Benchmarking and Profiling
The baseline of the benchmark is to generate a 1280*720, 81-frame video from the model Wan2.1-T2V-14B-Diffusers with a single inference step.
Checklist
Review Process
Re-run CI with the bot commands (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.