Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 172 additions & 38 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .wan_pipeline import WanPipeline, transformer_forward_pass
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
from ...models.wan.transformers.transformer_wan import WanModel
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
Expand All @@ -21,6 +21,7 @@
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
import numpy as np
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler


Expand All @@ -32,7 +33,7 @@ def __init__(
config: HyperParameters,
low_noise_transformer: Optional[WanModel],
high_noise_transformer: Optional[WanModel],
**kwargs
**kwargs,
):
super().__init__(config=config, **kwargs)
self.low_noise_transformer = low_noise_transformer
Expand Down Expand Up @@ -109,7 +110,15 @@ def __call__(
prompt_embeds: jax.Array = None,
negative_prompt_embeds: jax.Array = None,
vae_only: bool = False,
use_cfg_cache: bool = False,
):
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
)

latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
prompt,
negative_prompt,
Expand Down Expand Up @@ -138,6 +147,8 @@ def __call__(
num_inference_steps=num_inference_steps,
scheduler=self.scheduler,
scheduler_state=scheduler_state,
use_cfg_cache=use_cfg_cache,
height=height,
)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
Expand Down Expand Up @@ -172,51 +183,174 @@ def run_inference_2_2(
num_inference_steps: int,
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state,
use_cfg_cache: bool = False,
height: int = 480,
):
"""Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.

Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
- High-noise phase (t >= boundary): always full CFG — short phase, critical
for establishing video structure.
- Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
steps, FFT frequency-domain compensation on cache steps (batch×1).
- Boundary transition: mandatory full CFG step to populate cache for the
low-noise transformer.
- FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
"""
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
if do_classifier_free_guidance:
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

def low_noise_branch(operands):
latents, timestep, prompt_embeds = operands
return transformer_forward_pass(
low_noise_graphdef,
low_noise_state,
low_noise_rest,
latents,
timestep,
prompt_embeds,
do_classifier_free_guidance,
guidance_scale_low,
)
bsz = latents.shape[0]

def high_noise_branch(operands):
latents, timestep, prompt_embeds = operands
return transformer_forward_pass(
high_noise_graphdef,
high_noise_state,
high_noise_rest,
latents,
timestep,
prompt_embeds,
do_classifier_free_guidance,
guidance_scale_high,
# ── CFG cache path ──
if use_cfg_cache and do_classifier_free_guidance:
# Get timesteps as numpy for Python-level scheduling decisions
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

# Resolution-dependent CFG cache config — adapted for Wan 2.2.
if height >= 720:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we also need to check width >= 720?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now we don't need to, because I'm planning for the configuration to use standard video-size definitions, each of which implies an associated width. For example, a 720p video has a height of 720 and a width of 1280, a 480p video has a height of 480 and a width of 854, and so on.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wan supports both landscape and portrait. What is the implication if someone sets it to portrait mode?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing will really change at this point; the only difference is that cfg_cache_end_step = int(num_inference_steps * 0.9), which makes very little difference visually.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I'll make a note of this and address it in follow-up TODO PRs to make the scheduling more configurable.

cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = int(num_inference_steps * 0.9)
cfg_cache_alpha = 0.2
else:
cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = num_inference_steps - 1
cfg_cache_alpha = 0.2

# Pre-split embeds once
prompt_cond_embeds = prompt_embeds
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

# Determine the first low-noise step (boundary transition).
# In Wan 2.2 the boundary IS the structural→detail transition, so
# all low-noise cache steps should emphasise high-frequency correction.
first_low_step = next(
(s for s in range(num_inference_steps) if not step_uses_high[s]),
num_inference_steps,
)
t0_step = first_low_step # all cache steps get high-freq boost

# Pre-compute cache schedule and phase-dependent weights.
first_full_in_low_seen = False
step_is_cache = []
step_w1w2 = []
for s in range(num_inference_steps):
if step_uses_high[s]:
# Never cache high-noise transformer steps
step_is_cache.append(False)
else:
is_cache = (
first_full_in_low_seen
and s >= cfg_cache_start_step
and s < cfg_cache_end_step
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
)
step_is_cache.append(is_cache)
if not is_cache:
first_full_in_low_seen = True

# Phase-dependent weights: w = 1 + α·I(condition)
if s < t0_step:
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq
else:
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq

# Cache tensors (on-device JAX arrays, initialised to None).
cached_noise_cond = None
cached_noise_uncond = None

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
is_cache_step = step_is_cache[step]

# Select transformer and guidance scale based on precomputed schedule
if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low

if is_cache_step:
# ── Cache step: cond-only forward + FFT frequency compensation ──
w1, w2 = step_w1w2[step]
timestep = jnp.broadcast_to(t, bsz)
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
graphdef,
state,
rest,
latents,
timestep,
prompt_cond_embeds,
cached_noise_cond,
cached_noise_uncond,
guidance_scale=guidance_scale,
w1=jnp.float32(w1),
w2=jnp.float32(w2),
)
else:
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latents_doubled,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
)

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents

# ── Original non-cache path ──
# Uses same Python-level if/else transformer selection as the cache path
# so both paths compile to identical XLA graphs (critical for bfloat16
# reproducibility in the PSNR comparison).
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

prompt_embeds_combined = (
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
)

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
if do_classifier_free_guidance:
latents = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, latents.shape[0])

use_high_noise = jnp.greater_equal(t, boundary)
if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low

# Selects the model based on the current timestep:
# - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise).
# - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise).
noise_pred, latents = jax.lax.cond(
use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds)
)
if do_classifier_free_guidance:
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latents_doubled,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
)
else:
timestep = jnp.broadcast_to(t, bsz)
noise_pred, latents = transformer_forward_pass(
graphdef,
state,
rest,
latents,
timestep,
prompt_embeds,
do_classifier_free_guidance,
guidance_scale,
)

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents
Loading
Loading