Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ flash_block_sizes: {
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048
# "block_kv_dq" : 2048,
# "use_fused_bwd_kernel": False,
# }
# GroupNorm groups
norm_num_groups: 32
Expand Down
10 changes: 8 additions & 2 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ def get_precision(config):
retval = jax.lax.Precision.HIGHEST
return retval

def value_or_none(flash_block_sizes, key):
if key in flash_block_sizes:
return flash_block_sizes[key]
else:
return None

def get_flash_block_sizes(config):
"""Create custom flash attention BlockSizes."""
Expand All @@ -501,8 +506,9 @@ def get_flash_block_sizes(config):
block_q_dkv=config.flash_block_sizes["block_q_dkv"],
block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
block_q_dq=config.flash_block_sizes["block_q_dq"],
block_kv_dq=config.flash_block_sizes["block_kv_dq"],
block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
)
return flash_block_sizes

Expand Down
29 changes: 25 additions & 4 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _tpu_flash_attention(
flash_block_sizes: BlockSizes,
dtype: jnp.dtype = jnp.float32,
attention_kernel: str = "flash",
residual_checkpoint_name: str | None = None,
) -> jax.Array:
"""TPU Flash Attention"""

Expand Down Expand Up @@ -213,9 +214,22 @@ def _tpu_flash_attention(
)
def wrap_flash_attention(query, key, value):

query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
uses_fused_kernel = block_sizes.use_fused_bwd_kernel
block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
if uses_fused_kernel:
block_q_sizes += (block_sizes.block_q_dkv,)
block_kv_sizes += (block_sizes.block_kv_dkv,)
else:
block_q_sizes += (block_sizes.block_q_dq,)
block_kv_sizes += (block_sizes.block_kv_dq,)

block_q = max(*block_q_sizes)
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)

block_kv = max(*block_kv_sizes)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
value, _, _ = _pad_data_for_flash(value, heads, block_kv)

mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
Expand All @@ -237,6 +251,7 @@ def wrap_flash_attention(query, key, value):
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
block_sizes=block_sizes,
save_residuals=True if attention_kernel == "ring" else False,
residual_checkpoint_name=residual_checkpoint_name,
)
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

Expand Down Expand Up @@ -419,6 +434,7 @@ def _apply_attention(
axis_names_kv: AxisNames,
flash_block_sizes: BlockSizes,
dpa_layer: Callable,
residual_checkpoint_name: str | None = None,
):
"""Routes to different attention kernels."""
_check_attention_inputs(query, key, value)
Expand All @@ -439,7 +455,7 @@ def _apply_attention(
)
elif attention_kernel == "flash":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
Expand Down Expand Up @@ -574,6 +590,7 @@ def __init__(
flash_block_sizes: BlockSizes = None,
dtype: DType = jnp.float32,
quant: Quant = None,
residual_checkpoint_name: str | None = None,
):
self.dpa_layer = None
if attention_kernel == "cudnn_flash_te":
Expand All @@ -593,6 +610,7 @@ def __init__(
self.flash_block_sizes = flash_block_sizes
self.dtype = dtype
self.quant = quant
self.residual_checkpoint_name = residual_checkpoint_name

def apply_attention(self, query: Array, key: Array, value: Array):
return _apply_attention(
Expand All @@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
axis_names_kv=self.axis_names_kv,
flash_block_sizes=self.flash_block_sizes,
dpa_layer=self.dpa_layer,
residual_checkpoint_name=self.residual_checkpoint_name,
)


Expand Down Expand Up @@ -701,6 +720,7 @@ def __init__(
precision: jax.lax.Precision = None,
qkv_bias: bool = False,
quant: Quant = None,
residual_checkpoint_name: str | None = None,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
Expand Down Expand Up @@ -730,6 +750,7 @@ def __init__(
flash_block_sizes=flash_block_sizes,
dtype=dtype,
quant=quant,
residual_checkpoint_name=residual_checkpoint_name,
)
# None axes corresponds to the stacked weights across all blocks
# because of the use of nnx.vmap and nnx.scan.
Expand Down
8 changes: 8 additions & 0 deletions src/maxdiffusion/models/gradient_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class GradientCheckpointType(Enum):
MATMUL_WITHOUT_BATCH = auto()
OFFLOAD_MATMUL_WITHOUT_BATCH = auto()
CUSTOM = auto()
HIDDEN_STATE_WITH_OFFLOAD = auto()

@classmethod
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
Expand Down Expand Up @@ -76,6 +77,13 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
offload_dst="pinned_host",
)
return policy
case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
return jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=[],
names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
offload_src="device",
offload_dst="pinned_host",
)
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

Expand Down
5 changes: 5 additions & 0 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax.sharding import PartitionSpec
from jax.ad_checkpoint import checkpoint_name
from flax import nnx
import flax.linen as nn
import numpy as np
from .... import common_types
from ...modeling_flax_utils import FlaxModelMixin, get_activation
Expand Down Expand Up @@ -282,6 +283,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
residual_checkpoint_name='self_attn',
)

# 1. Cross-attention
Expand All @@ -300,6 +302,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
residual_checkpoint_name='cross_attn',
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
Expand Down Expand Up @@ -335,6 +338,7 @@ def __call__(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

# 1. Self-attention
Expand Down Expand Up @@ -514,6 +518,7 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs = None,
) -> Union[jax.Array, Dict[str, jax.Array]]:
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
batch_size, _, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
Expand Down
Loading