diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 4a973045..533235c6 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -81,7 +81,8 @@ flash_block_sizes: {
 # "block_kv_dkv" : 2048,
 # "block_kv_dkv_compute" : 2048,
 # "block_q_dq" : 3024,
-# "block_kv_dq" : 2048
+# "block_kv_dq" : 2048,
+# "use_fused_bwd_kernel": False,
 # }
 # GroupNorm groups
 norm_num_groups: 32
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 6638e0f8..e5347797 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -489,6 +489,11 @@ def get_precision(config):
     retval = jax.lax.Precision.HIGHEST
   return retval
 
+def value_or_none(flash_block_sizes, key):
+  if key in flash_block_sizes:
+    return flash_block_sizes[key]
+  else:
+    return None
 
 def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
@@ -501,8 +506,9 @@ def get_flash_block_sizes(config):
         block_q_dkv=config.flash_block_sizes["block_q_dkv"],
         block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
         block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=config.flash_block_sizes["block_q_dq"],
-        block_kv_dq=config.flash_block_sizes["block_kv_dq"],
+        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
+        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
     )
   return flash_block_sizes
 
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 5df5f334..b9892c5f 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    residual_checkpoint_name: str | None = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -213,9 +214,22 @@ def _tpu_flash_attention(
   )
 
   def wrap_flash_attention(query, key, value):
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
+    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
@@ -237,6 +251,7 @@ def wrap_flash_attention(query, key, value):
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
         save_residuals=True if attention_kernel == "ring" else False,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
 
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    residual_checkpoint_name: str | None = None,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.residual_checkpoint_name = residual_checkpoint_name
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        residual_checkpoint_name=self.residual_checkpoint_name,
     )
 
 
@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py
index 086223f8..99c514df 100644
--- a/src/maxdiffusion/models/gradient_checkpoint.py
+++ b/src/maxdiffusion/models/gradient_checkpoint.py
@@ -41,6 +41,7 @@ class GradientCheckpointType(Enum):
   MATMUL_WITHOUT_BATCH = auto()
   OFFLOAD_MATMUL_WITHOUT_BATCH = auto()
   CUSTOM = auto()
+  HIDDEN_STATE_WITH_OFFLOAD = auto()
 
   @classmethod
   def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
@@ -76,6 +77,13 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
             offload_dst="pinned_host",
         )
         return policy
+      case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
+        return jax.checkpoint_policies.save_and_offload_only_these_names(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 02c7a565..46dd7ca5 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -21,6 +21,7 @@
 from jax.sharding import PartitionSpec
 from jax.ad_checkpoint import checkpoint_name
 from flax import nnx
+import flax.linen as nn
 import numpy as np
 from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin, get_activation
@@ -282,6 +283,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        residual_checkpoint_name='self_attn',
     )
 
     # 1. Cross-attention
@@ -300,6 +302,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        residual_checkpoint_name='cross_attn',
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -335,6 +338,7 @@ def __call__(
         (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
     )
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    hidden_states = checkpoint_name(hidden_states, "hidden_states")
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
 
     # 1. Self-attention
@@ -514,6 +518,7 @@ def __call__(
      deterministic: bool = True,
      rngs: nnx.Rngs = None,
   ) -> Union[jax.Array, Dict[str, jax.Array]]:
+    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
     batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
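Note (editorial, not part of the change): once `use_fused_bwd_kernel` is set, the `block_q_dq` / `block_kv_dq` entries can simply be left out of `flash_block_sizes`, because the new `value_or_none` helper maps missing keys to `None` before they reach `splash_attention_kernel.BlockSizes`. A minimal sketch of that behavior, using placeholder block sizes rather than tuned values:

```python
# Sketch only: mirrors how get_flash_block_sizes() reads the config dict.
# The block sizes below are placeholders, not recommendations.
flash_block_sizes = {
    "block_q": 3024,
    "block_kv_compute": 2048,
    "block_kv": 2048,
    "block_q_dkv": 3024,
    "block_kv_dkv": 2048,
    "block_kv_dkv_compute": 2048,
    "use_fused_bwd_kernel": True,  # fused backward kernel: no dq block sizes needed
}

def value_or_none(flash_block_sizes, key):
  # Same behavior as the helper added to max_utils.py (equivalent to dict.get).
  if key in flash_block_sizes:
    return flash_block_sizes[key]
  else:
    return None

print(value_or_none(flash_block_sizes, "block_q_dq"))            # None
print(value_or_none(flash_block_sizes, "use_fused_bwd_kernel"))  # True
```

This is also why `wrap_flash_attention` branches on `use_fused_bwd_kernel` when picking padding block sizes: in the fused case the dq block sizes are `None` and must not be fed to `max()`.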
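Likewise, a self-contained sketch (plain JAX; the toy function, shapes, and names are assumptions, not maxdiffusion code) of how the new `HIDDEN_STATE_WITH_OFFLOAD` policy and the `checkpoint_name` tags added in `transformer_wan.py` and `attention_flax.py` fit together: residuals tagged `hidden_states`, `self_attn`, or `cross_attn` are offloaded to pinned host memory during rematerialization instead of being saved on device or recomputed.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# The same policy HIDDEN_STATE_WITH_OFFLOAD builds in gradient_checkpoint.py.
# Offloading to "pinned_host" requires a backend with host memory support (TPU/GPU).
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden_states", "self_attn", "cross_attn"],
    offload_src="device",
    offload_dst="pinned_host",
)

@partial(jax.checkpoint, policy=policy)
def toy_block(x, w):
  # Tag the intermediate the way WanTransformerBlock tags its hidden states;
  # under the policy above this residual is offloaded rather than recomputed.
  h = checkpoint_name(jnp.tanh(x @ w), "hidden_states")
  return h @ w

x = jnp.ones((8, 128))
w = jnp.ones((128, 128))
grads = jax.jit(jax.grad(lambda w: toy_block(x, w).sum()))(w)
```

In the actual model, selecting this policy presumably goes through the existing `GradientCheckpointType.from_str(...).to_jax_policy()` path shown in the diff; the only new wiring is the enum value, the policy case, and the residual tags.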