From 2bd2f1d7af5fdda46e27373d4bcb25ccbe02ebc2 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Fri, 23 Jan 2026 08:16:31 +0000 Subject: [PATCH 1/5] Migrate Decoder to NNX --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/types.py | 1 + src/MaxText/layers/models.py | 22 +- src/MaxText/layers/multi_token_prediction.py | 22 +- src/MaxText/layers/nnx_decoders.py | 813 +++++++++++++++++++ 5 files changed, 843 insertions(+), 16 deletions(-) create mode 100644 src/MaxText/layers/nnx_decoders.py diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 69bec57e89..2f11b3de74 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -1008,6 +1008,7 @@ subslice_shape: "" # NNX enable_nnx: false +pure_nnx_decoder: false ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 9458fb2fb2..8d2c76c806 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -725,6 +725,7 @@ class HardwareAndMesh(BaseModel): enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.") optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") + pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") class LayoutAndSharding(BaseModel): diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 2d84eda09a..f9cf4438fd 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -31,6 +31,7 @@ from MaxText import max_utils from MaxText.layers import nnx_wrappers from MaxText.layers.decoders import Decoder +from MaxText.layers.nnx_decoders import NNXDecoder, decoder_as_linen from MaxText.layers.embeddings import Embed, embed_as_linen from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant @@ -86,7 +87,13 @@ def setup(self): ) self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None - self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + if cfg.pure_nnx_decoder: + self.decoder = decoder_as_linen( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0) + ) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. 
@@ -334,9 +341,11 @@ def __init__( ) self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None + if cfg.pure_nnx_decoder: + self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) @@ -362,13 +371,6 @@ def __init__( else: dummy_attention_metadata = None - self.decoder.lazy_init( - shared_embedding=self.token_embedder, - decoder_input_tokens=dummy_decoder_input_tokens, - decoder_positions=dummy_decoder_positions, - attention_metadata=dummy_attention_metadata, - ) - # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. diff --git a/src/MaxText/layers/multi_token_prediction.py b/src/MaxText/layers/multi_token_prediction.py index a3201de36e..96c94e94db 100644 --- a/src/MaxText/layers/multi_token_prediction.py +++ b/src/MaxText/layers/multi_token_prediction.py @@ -110,12 +110,22 @@ def __init__( rngs=rngs, ) # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically. - mtp_transformer_layer = transformer_layer_module( - config=cfg, - mesh=mesh, - model_mode=MODEL_MODE_TRAIN, - name=f"mtp_{k}_transformer_layer", - ) + if cfg.pure_nnx_decoder: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + rngs=rngs, + ) + else: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + ) + self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs) # ToNNX requires explicit initialization with sample inputs for proper parameter setup. diff --git a/src/MaxText/layers/nnx_decoders.py b/src/MaxText/layers/nnx_decoders.py new file mode 100644 index 0000000000..2a472bbdc4 --- /dev/null +++ b/src/MaxText/layers/nnx_decoders.py @@ -0,0 +1,813 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Module for decoder layers""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +from typing import Any +import functools +import inspect + +import jax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh + +from flax import linen as nn +from flax import nnx +from flax.nnx import wrappers as nnx_wrappers + +from MaxText.configs.types import PositionalEmbedding +from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT +from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from MaxText import max_logging +from MaxText.sharding import create_sharding +from MaxText.inference import page_manager +from MaxText.layers import linears +from MaxText.layers import initializers +from MaxText.layers import quantizations +from MaxText import maxtext_utils +from MaxText import multimodal_utils +from MaxText import sharding +from MaxText.layers.attentions import Attention +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.embeddings import Embed, attend_on_embedding +from MaxText.layers.quantizations import AqtQuantization as Quant + +# Import specific layer definitions (assuming these files exist) +from MaxText.layers import ( + deepseek, + deepseek_batchsplit, + gemma, + gemma2, + gemma3, + gpt3, + gpt_oss, + llama2, + llama4, + mistral, + mixtral, + qwen3, + simple_layer, +) + + +class NNXDecoderLayer(nnx.Module): + """ + Transformer decoder layer converted to NNX. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + name: str = "decoder_layer", + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + + cfg = self.config + + self.pre_self_attention_norm = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + self.self_attention = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + model_mode=model_mode, + ) + + self.mlp = linears.MLPBlock( + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + model_mode=model_mode, + config=cfg, + quant=self.quant, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | 
page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") + else: + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.pre_self_attention_norm(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + + mlp_lnx = self.mlp(lnx, deterministic=deterministic) + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + + next_layer_addition = mlp_lnx + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + + if cfg.record_internal_nn_metrics: + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) + self.sow( + "intermediates", + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output, kv_cache + + +class NNXDecoder(nnx.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + decoder_block_classes = self.get_decoder_layers() + + self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + ) + + if config.trainable_position_size > 0: + self.position_embedder = Embed( + num_embeddings=config.trainable_position_size, + num_features=config.emb_dim, + dtype=config.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=config, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) + + if not config.logits_via_embedding: + self.logits_dense = linears.DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=config.vocab_size, + weight_dtype=config.weight_dtype, + dtype=jnp.float32 if config.logits_dot_in_fp32 else 
config.dtype,
+          kernel_axes=("embed", "vocab"),
+          shard_mode=config.shard_mode,
+          matmul_precision=self.config.matmul_precision,
+          parameter_memory_host_offload=config.parameter_memory_host_offload,
+          rngs=rngs,
+      )
+
+    self.scanned_layers = None
+    self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK
+    self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3
+
+    if self.config.scan_layers:
+      if self.is_deepseek:
+        assert len(decoder_block_classes) == 2
+        dense_cls, moe_cls = decoder_block_classes
+
+        num_dense = config.first_num_dense_layers
+        self.dense_stack = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs)
+
+        num_moe = config.num_decoder_layers - config.first_num_dense_layers
+        self.moe_stack = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
+      elif self.is_gemma3:
+        attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
+        scan_length = config.num_decoder_layers // attention_pattern_length
+        num_remaining_layers = config.num_decoder_layers % attention_pattern_length
+        rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
+
+        RemattedGemma3Block = gemma3.Gemma3ScannableBlock
+
+        if scan_length > 0:
+          self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs)
+        self.layers_remainder = RemattedGemma3Block(
+            config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
+        )  # pytype: disable=wrong-keyword-args
+      else:
+        layer_cls = decoder_block_classes[0]
+        num_layers = config.num_decoder_layers
+        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs)
+    else:
+      self.layers = nnx.List([])
+      if self.is_deepseek:
+        # DeepSeek provides (dense, moe) classes; unpack them here as well, since the
+        # scan branch above (where they were previously unpacked) is skipped when
+        # scan_layers=False.
+        dense_cls, moe_cls = decoder_block_classes
+        for i in range(config.first_num_dense_layers):
+          self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
+        for i in range(config.num_decoder_layers - config.first_num_dense_layers):
+          self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
+      else:
+        layer_cls = decoder_block_classes[0]
+        for i in range(config.num_decoder_layers):
+          self._create_and_register_layer(layer_cls, rngs, "layers", i)
+
+  def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
+    attr_name = f"{base_name}_{i}"
+    layer = self._create_single_layer(layer_cls, rngs)
+    setattr(self, attr_name, layer)
+    self.layers.append(layer)
+
+  def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
+    """Helper to create a single layer (Linen or NNX)."""
+    if issubclass(decoder_layer_class, nnx.Module):
+      return decoder_layer_class(
+          config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs, **kwargs
+      )
+    else:
+      layer_linen = decoder_layer_class(
+          config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, **kwargs
+      )
+      return nnx_wrappers.ToNNX(layer_linen, rngs=rngs)
+
+  def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs):
+    """Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
+
+    def create_layer_fn(rng):
+      layer = decoder_layer_class(
+          config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs
+      )
+
+      return layer
+
+    # Workaround for Deepseek MTP test failure.
+    # TODO: Handle this properly.
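+    # nnx.split_rngs forks the Rngs into `length` per-layer streams so that each
+    # vmapped layer below is initialized with a distinct key. The bare except is a
+    # stopgap (see TODO above): it tolerates Rngs that are already split or abstract,
+    # e.g. under nnx.eval_shape, but it also hides real failures.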
+    try:
+      nnx.split_rngs(rngs, splits=length)
+    except:  # pylint: disable=bare-except
+      pass
+
+    layers_vmapped = nnx.vmap(
+        create_layer_fn,
+        in_axes=0,
+        out_axes=0,
+        axis_name="layers",
+        transform_metadata={nnx.PARTITION_NAME: "layers"},
+    )(rngs)
+
+    return layers_vmapped
+
+  def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
+    """Runs the layer stack with jax.lax.scan over the stacked params/state."""
+    policy = self.get_remat_policy()
+    prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
+    graphdef, params, state = nnx.split(
+        layers, nnx.Param, ...
+    )  # state: the mutable state we carry (KV cache, RNGs, etc.)
+
+    layer_cls = layers.__class__  # Access the underlying class
+    sig = inspect.signature(layer_cls.__call__)
+
+    # Filter kwargs to only include keys that exist in the layer's signature
+    valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
+
+    def layer_fn(carry, scanned_vars):
+      # Unpack the variables sliced out for this layer
+      current_params, current_state = scanned_vars
+
+      # Merge using the sliced state
+      layer = nnx.merge(graphdef, current_params, current_state)
+
+      # Run the layer, passing only the kwargs its call signature accepts
+      layer_out = layer(carry, *args, **valid_kwargs)
+
+      new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
+
+      # Extract the updated mutable state so scan can stack it back up
+      new_current_state = nnx.state(layer)
+      return new_carry, new_current_state
+
+    layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
+
+    final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
+    nnx.update(layers, scanned_state)
+
+    return final_carry, None
+
+  def get_decoder_layers(self):
+    """Retrieves decoder layer classes based on config using a dictionary lookup."""
+    cfg = self.config
+
+    def get_scannable(normal_cls, scannable_cls):
+      return [scannable_cls] if cfg.scan_layers else [normal_cls]
+
+    def get_deepseek():
+      if cfg.use_batch_split_schedule:
+        return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
+      return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
+
+    layer_map = {
+        DecoderBlockType.DEFAULT: [NNXDecoderLayer],
+        DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
+        DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
+        DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
+        DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
+        DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
+        DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
+        DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
+        DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
+        DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
+        DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
+        DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
+        DecoderBlockType.DEEPSEEK: get_deepseek(),
+        DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
+        DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
+        DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
+    }
+
+    if cfg.decoder_block not in layer_map:
+      raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
+
+    return layer_map[cfg.decoder_block]
+
+  def minimal_policy(self, with_context=False):
+    """Helper for creating minimal checkpoint policies."""
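+    # These names correspond to checkpoint_name(...) tags emitted inside the attention
+    # and MLP layers; save_only_these_names keeps just the tagged tensors and
+    # rematerializes everything else during the backward pass.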
+ names = [ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ] + if with_context: + names.append("context") + return jax.checkpoint_policies.save_only_these_names(*names) + + def get_remat_policy(self): + """Get remat policy for jax.checkpoint.""" + policy = None + cfg = self.config + if cfg.remat_policy != "none": + if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): + if cfg.remat_policy == "minimal_flash": + max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") + policy = self.minimal_policy(with_context=True) + elif cfg.remat_policy == "minimal": + policy = self.minimal_policy() + elif cfg.remat_policy == "save_dot_with_context_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "context", + "out_proj", + ) + elif cfg.remat_policy == "save_dot_except_mlpwi": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", + ) + elif cfg.remat_policy == "save_dot_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + ) + elif cfg.remat_policy == "save_qkv_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + ) + elif cfg.remat_policy == "qkv_proj_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "custom": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=cfg.tensors_on_device, + names_which_can_be_offloaded=cfg.tensors_to_offload, + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "save_out_proj": + policy = jax.checkpoint_policies.save_only_these_names("out_proj") + else: + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" + policy = None + return policy + + def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): + """get normalization layer (return type inherits from nn.Module)""" + if self.config.decoder_block in ( + DecoderBlockType.DEFAULT, + DecoderBlockType.LLAMA2, + DecoderBlockType.MISTRAL, + DecoderBlockType.MIXTRAL, + DecoderBlockType.DEEPSEEK, + DecoderBlockType.GEMMA, + DecoderBlockType.GEMMA2, + DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN3, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.QWEN3_NEXT, + DecoderBlockType.GPT_OSS, + DecoderBlockType.SIMPLE, + DecoderBlockType.SIMPLE_MLP, + DecoderBlockType.LLAMA4, + ): + return functools.partial(RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + elif self.config.decoder_block == DecoderBlockType.GPT3: + return functools.partial( + gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs + ) 
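+    # GPT-3 is the only decoder family here that uses LayerNorm with a bias term
+    # rather than RMSNorm.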
+ else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def _apply_embedding( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings=None, + bidirectional_mask=None, + image_masks=None, + audio_embeddings=None, + audio_masks=None, + ): + """Applies token and positional embeddings to the input tokens.""" + cfg = self.config + + y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + + if image_embeddings is not None and cfg.use_multimodal: + if cfg.model_name in [ + "gemma3-4b", + "gemma3-12b", + "gemma3-27b", + "llama4-17b-16e", + "llama4-17b-128e", + "qwen3-omni-30b-a3b", + ]: + y = multimodal_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=image_embeddings, + mask=bidirectional_mask, + token_masks=image_masks, + ) + # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed + else: + raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + + if audio_embeddings is not None and cfg.use_audio: + if cfg.model_name in ["qwen3-omni-30b-a3b"]: + y = multimodal_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=audio_embeddings, + mask=audio_masks, + token_masks=None, + ) + else: + raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if cfg.use_untrainable_positional_embedding: + y = self.positional_embedding(y, decoder_positions) + + if cfg.trainable_position_size > 0 and self.position_embedder: + y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) + + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + """Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + else: + norm_out_sharding = None + + y = self.decoder_norm(y, norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) # NNX call + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + ) + + if cfg.logits_via_embedding: + if isinstance(shared_embedding, nnx.Module): + embedding_table = shared_embedding.embedding.value + else: + embedding_table = shared_embedding.variables["params"]["embedding"] + if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): + embedding_table = embedding_table.unbox() + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = logits / cfg.final_logits_soft_cap + logits = jnp.tanh(logits) * cfg.final_logits_soft_cap + else: + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + + return logits + + def __call__( + self, + shared_embedding: Any, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + 
deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + ): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + + y = self._apply_embedding( + shared_embedding, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + ) + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + + layer_kwargs = { + "previous_chunk": previous_chunk, + "page_state": page_state, + "slot": slot, + "attention_metadata": attention_metadata, + } + + if cfg.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs["bidirectional_mask"] = bidirectional_mask + + if cfg.scan_layers: + if self.is_deepseek: + y, _ = self._apply_layers_sequentially( + self.dense_stack, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, _ = self._apply_layers_sequentially(self.moe_stack, y, *layer_args, length=num_moe, **layer_kwargs) + elif self.is_gemma3: + y = self._apply_gemma3_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + y, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=cfg.num_decoder_layers, **layer_kwargs) + else: + for i, layer in enumerate(self.layers): + kv_cache = kv_caches[i] if kv_caches is not None else None + + out = layer(y, *layer_args, kv_cache=kv_cache, **layer_kwargs) + + if isinstance(out, tuple): + y, kv_cache_out = out + else: + y = out + kv_cache_out = None + + if kv_caches is not None: + kv_caches[i] = kv_cache_out + + assert isinstance(y, jax.Array) + hidden_state = y + + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches + + def _apply_gemma3_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ): + """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" + + cfg = self.config + + # Define the repeating pattern length and calculate how many full blocks to scan + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + + layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} + + # Apply the main scan over the full blocks + if scan_length > 0: + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + y, _ = self.layers(y, *broadcast_args, **layer_call_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions + y, _ = self.layers_remainder( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + 
**layer_call_kwargs, + ) + return y + + +def decoder_as_linen( + config: Config, + mesh: Mesh, + rngs: nnx.Rngs, + model_mode: str, + quant: None | Quant = None, +): + """Creates a Decoder module.""" + module = nnx_wrappers.to_linen( + NNXDecoder, + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant, + name="decoder", + abstract_init=False, + metadata_fn=initializers.variable_to_logically_partitioned, + ) + return module From 1cf06f56593224c1366186f1fc89cfa262ed9c54 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 27 Jan 2026 16:56:10 +0000 Subject: [PATCH 2/5] Add NNX Decoder unit tests and update to new NNX variable API --- src/MaxText/compare_linen_nnx_tree.py | 460 ++++++++++++++++++++++ tests/unit/decoder_tree_structure_test.py | 403 +++++++++++++++++++ 2 files changed, 863 insertions(+) create mode 100644 src/MaxText/compare_linen_nnx_tree.py create mode 100644 tests/unit/decoder_tree_structure_test.py diff --git a/src/MaxText/compare_linen_nnx_tree.py b/src/MaxText/compare_linen_nnx_tree.py new file mode 100644 index 0000000000..be13505258 --- /dev/null +++ b/src/MaxText/compare_linen_nnx_tree.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python3 +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare Linen and NNX model tree structures for MaxText. + +This script creates abstract models (without actual checkpoints) for both +Linen and NNX implementations and compares their parameter tree structures. 
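+
+Both models are constructed abstractly via jax.eval_shape / nnx.eval_shape, so no
+parameters are materialized and no checkpoint or accelerator memory is required.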
+
+Usage:
+  python src/MaxText/compare_linen_nnx_tree.py [--model gemma2-2b]
+"""
+
+import sys
+import os
+import argparse
+
+# Set up paths: this script lives in src/MaxText/, so that directory is the
+# MaxText package dir and the parent src/ directory goes on the import path.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+MAXTEXT_SRC_DIR = SCRIPT_DIR
+
+# Set environment variable before importing MaxText
+os.environ["MAXTEXT_PKG_DIR"] = MAXTEXT_SRC_DIR
+
+# Add MaxText to path
+sys.path.insert(0, os.path.dirname(SCRIPT_DIR))
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh
+from flax import linen as nn
+from flax import nnx
+
+from MaxText import pyconfig
+from MaxText import maxtext_utils
+from MaxText.layers import models
+from MaxText.layers import quantizations
+from MaxText.common_types import MODEL_MODE_TRAIN
+
+# Use our computed path
+MAXTEXT_PKG_DIR = MAXTEXT_SRC_DIR
+
+
+def get_tree_paths(pytree, prefix=""):
+  """Recursively extract all paths from a pytree."""
+  paths = []
+
+  if isinstance(pytree, dict):
+    for key, value in pytree.items():
+      new_prefix = f"{prefix}/{key}" if prefix else key
+      paths.extend(get_tree_paths(value, new_prefix))
+  elif isinstance(pytree, (list, tuple)):
+    for i, value in enumerate(pytree):
+      new_prefix = f"{prefix}[{i}]"
+      paths.extend(get_tree_paths(value, new_prefix))
+  elif hasattr(pytree, "__dict__"):
+    # Handle nnx.VariableState or similar objects
+    for key, value in vars(pytree).items():
+      if not key.startswith("_"):
+        new_prefix = f"{prefix}.{key}" if prefix else key
+        paths.extend(get_tree_paths(value, new_prefix))
+  else:
+    # Leaf node
+    if hasattr(pytree, "shape"):
+      paths.append((prefix, pytree.shape, str(pytree.dtype)))
+    else:
+      paths.append((prefix, type(pytree).__name__, ""))
+
+  return paths
+
+
+def extract_linen_paths(vars_dict, prefix=""):
+  """Extract paths from Linen variables dict using JAX tree utilities."""
+  paths = []
+
+  # Use jax.tree_util to properly flatten the pytree
+  leaves_with_paths = jax.tree_util.tree_leaves_with_path(vars_dict)
+
+  for path_parts, leaf in leaves_with_paths:
+    # Convert path parts to string path
+    path_str = ""
+    for part in path_parts:
+      if hasattr(part, "key"):
+        # DictKey or similar
+        if path_str:
+          path_str += "/" + str(part.key)
+        else:
+          path_str = str(part.key)
+      elif hasattr(part, "idx"):
+        # SequenceKey (list/tuple index)
+        path_str += f"[{part.idx}]"
+      elif isinstance(part, str):
+        if path_str:
+          path_str += "/" + part
+        else:
+          path_str = part
+      else:
+        if path_str:
+          path_str += "/" + str(part)
+        else:
+          path_str = str(part)
+
+    # Get shape info from leaf
+    if hasattr(leaf, "shape"):
+      paths.append((path_str, leaf.shape, str(leaf.dtype)))
+    else:
+      paths.append((path_str, type(leaf).__name__, ""))
+
+  return paths
+
+
+def extract_nnx_paths(state, prefix=""):
+  """Extract paths from NNX state using JAX tree utilities."""
+  paths = []
+
+  # Use jax.tree_util to properly flatten the NNX state
+  leaves_with_paths = jax.tree_util.tree_leaves_with_path(state)
+
+  for path_parts, leaf in leaves_with_paths:
+    # Convert path parts to string path
+    path_str = ""
+    for part in path_parts:
+      if hasattr(part, "key"):
+        # DictKey or similar
+        if path_str:
+          path_str += "/" + str(part.key)
+        else:
+          path_str = str(part.key)
+      elif hasattr(part, "idx"):
+        # SequenceKey (list/tuple index)
+        path_str += f"[{part.idx}]"
+      elif isinstance(part, str):
+        if path_str:
+          path_str += "/" + part
+        else:
+          path_str = part
+      else:
+        if path_str:
+          path_str += "/" + str(part)
+        else:
+          path_str = str(part)
+
+    # Get shape info from leaf
+    if hasattr(leaf, "shape"):
+      paths.append((path_str,
leaf.shape, str(leaf.dtype))) + elif hasattr(leaf, "value") and hasattr(leaf.value, "shape"): + paths.append((path_str, leaf.value.shape, str(leaf.value.dtype))) + else: + paths.append((path_str, type(leaf).__name__, "")) + + return paths + + +def normalize_path(path, is_linen=False): + """Normalize a path for comparison. + + Linen format: params/params/decoder/layers/0/mlp/wi_0/kernel + NNX format: decoder/layers/0/mlp/wi_0/kernel + + This removes the double 'params' prefix from Linen paths and handles + other minor differences. + """ + # Remove leading 'params/params' from Linen paths + if is_linen and path.startswith("params/params/"): + path = path[len("params/params/") :] + elif is_linen and path.startswith("params/"): + path = path[len("params/") :] + + return path + + +def create_linen_model_abstract(cfg, mesh): + """Create a Linen model and get its abstract parameter structure. + + Uses pure_nnx_decoder=False to get the Linen Decoder parameters. + """ + print("\n" + "=" * 60) + print("Creating Linen model...") + print("=" * 60) + + # Force pure_nnx_decoder=False for Linen model to use Linen Decoder + # We rely on the config being set correctly externally + + quant = quantizations.configure_quantization(cfg) + model = models.transformer_as_linen(config=cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + + # Create dummy inputs + batch_size = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + + rng = jax.random.PRNGKey(0) + dummy_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + dummy_positions = jnp.stack([jnp.arange(seq_len, dtype=jnp.int32) for _ in range(batch_size)]) + dummy_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + # Use eval_shape to get abstract structure without allocating memory + def init_fn(): + return model.init( + {"params": rng, "aqt": rng, "dropout": rng}, + dummy_tokens, + dummy_positions, + dummy_segment_ids, + enable_dropout=False, + ) + + with mesh: + with nn.logical_axis_rules(cfg.logical_axis_rules): + abstract_vars = jax.eval_shape(init_fn) + + return abstract_vars + + +def create_nnx_model_abstract(cfg, mesh): + """Create an NNX model and get its abstract parameter structure. + + Uses pure_nnx_decoder=True to get the NNX Decoder parameters. + The NNX Transformer class with pure_nnx_decoder=True uses NNXDecoder. 
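+
+  Note: the model is constructed under nnx.eval_shape, so the returned state
+  contains abstract shape/dtype information rather than materialized parameters.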
+ """ + print("\n" + "=" * 60) + print("Creating NNX model...") + print("=" * 60) + + quant = quantizations.configure_quantization(cfg) + + def create_model(): + # Create rngs inside the function to avoid trace context issues + rng = jax.random.PRNGKey(0) + params_rng, dropout_rng = jax.random.split(rng) + rngs = nnx.Rngs(params=params_rng, dropout=dropout_rng) + return models.Transformer(cfg, mesh, quant=quant, rngs=rngs, model_mode=MODEL_MODE_TRAIN) + + with mesh: + with nn.logical_axis_rules(cfg.logical_axis_rules): + abstract_model = nnx.eval_shape(create_model) + + # Extract state from abstract model + _, abstract_state = nnx.split(abstract_model) + + return abstract_state + + +def is_rng_path(path): + """Check if a path is RNG-related.""" + return "/rngs/" in path or path.startswith("rngs/") + + +def compare_tree_structures(linen_vars, nnx_state, hide_rngs=True): + """Compare the tree structures of Linen and NNX models.""" + print("\n" + "=" * 60) + print("Comparing tree structures...") + if hide_rngs: + print("(RNG paths are hidden, use --show-rngs to include them)") + print("=" * 60) + + # Extract paths from both + linen_paths = extract_linen_paths(linen_vars) + nnx_paths = extract_nnx_paths(nnx_state) + + # Filter out RNG paths if requested + if hide_rngs: + linen_paths = [(p, s, d) for p, s, d in linen_paths if not is_rng_path(p)] + nnx_paths = [(p, s, d) for p, s, d in nnx_paths if not is_rng_path(p)] + + print(f"\nLinen total paths: {len(linen_paths)}") + print(f"NNX total paths: {len(nnx_paths)}") + + # Normalize paths for comparison + linen_normalized = {} + for path, shape, dtype in linen_paths: + norm_path = normalize_path(path, is_linen=True) + linen_normalized[norm_path] = (path, shape, dtype) + + nnx_normalized = {} + for path, shape, dtype in nnx_paths: + # Don't normalize NNX paths - compare them directly + # (The previous bug was replacing "/value" which removed the value projection layer name) + norm_path = path + nnx_normalized[norm_path] = (path, shape, dtype) + + # Find matches and mismatches + linen_only = set(linen_normalized.keys()) - set(nnx_normalized.keys()) + nnx_only = set(nnx_normalized.keys()) - set(linen_normalized.keys()) + common = set(linen_normalized.keys()) & set(nnx_normalized.keys()) + + print(f"\nPaths in both: {len(common)}") + print(f"Paths only in Linen: {len(linen_only)}") + print(f"Paths only in NNX: {len(nnx_only)}") + + # Check for shape mismatches in common paths + shape_mismatches = [] + for path in common: + linen_shape = linen_normalized[path][1] + nnx_shape = nnx_normalized[path][1] + if linen_shape != nnx_shape: + shape_mismatches.append((path, linen_shape, nnx_shape)) + + if shape_mismatches: + print(f"\nShape mismatches: {len(shape_mismatches)}") + for path, linen_shape, nnx_shape in shape_mismatches: + print(f" {path}: Linen={linen_shape}, NNX={nnx_shape}") + else: + print("\nNo shape mismatches in common paths!") + + return linen_normalized, nnx_normalized, linen_only, nnx_only, common + + +def print_tree_structure(paths_dict, name, max_depth=3): + """Print a hierarchical view of the tree structure.""" + print(f"\n{'='*60}") + print(f"{name} Tree Structure (depth {max_depth}):") + print("=" * 60) + + # Build a tree representation + tree = {} + for norm_path, (_, shape, dtype) in paths_dict.items(): + parts = norm_path.split("/") + current = tree + for i, part in enumerate(parts[:-1]): + if i >= max_depth: + break + if part not in current: + current[part] = {} + current = current[part] + if len(parts) <= max_depth: + leaf_name = 
parts[-1] if parts else "root" + if isinstance(shape, tuple): + current[leaf_name] = f"{shape} {dtype}" + else: + current[leaf_name] = f"({shape})" + + def print_tree(d, indent=0): + for key, value in sorted(d.items()): + if isinstance(value, dict): + print(" " * indent + f"{key}/") + print_tree(value, indent + 1) + else: + print(" " * indent + f"{key}: {value}") + + print_tree(tree) + + +def main(): + parser = argparse.ArgumentParser(description="Compare Linen and NNX model tree structures") + parser.add_argument("--model", type=str, default="gemma2-2b", help="Model config to use (e.g., gemma2-2b, llama2-7b)") + parser.add_argument("--depth", type=int, default=4, help="Max depth for tree structure printout") + parser.add_argument("--verbose", action="store_true", help="Print detailed path information") + parser.add_argument("--show-rngs", action="store_true", help="Show RNG-related paths (hidden by default)") + args = parser.parse_args() + + # Initialize config + model_config = os.path.join(MAXTEXT_PKG_DIR, "configs", "models", f"{args.model}.yml") + if not os.path.exists(model_config): + print(f"Model config not found: {model_config}") + print("Available models:") + models_dir = os.path.join(MAXTEXT_PKG_DIR, "configs", "models") + for f in os.listdir(models_dir): + if f.endswith(".yml"): + print(f" {f[:-4]}") + return 1 + + print(f"Using model config: {args.model}") + + # Create config for Linen model (uses Linen Decoder) + cfg_linen = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + model_name=args.model, + per_device_batch_size=1.0, + run_name="tree_compare", + enable_checkpointing=False, + max_target_length=32, # Small for faster abstract model creation + attention="dot_product", + pure_nnx_decoder=False, # Use Linen Decoder for Linen model + ) + + # Create config for NNX model (uses NNX Decoder) + cfg_nnx = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + model_name=args.model, + per_device_batch_size=1.0, + run_name="tree_compare", + enable_checkpointing=False, + max_target_length=32, # Small for faster abstract model creation + attention="dot_product", + pure_nnx_decoder=True, # Use NNX Decoder for NNX model + ) + + # Create mesh (same for both) + devices_array = maxtext_utils.create_device_mesh(cfg_linen) + mesh = Mesh(devices_array, cfg_linen.mesh_axes) + + print(f"\nModel: {args.model}") + print(f"emb_dim: {cfg_linen.emb_dim}") + print(f"num_decoder_layers: {cfg_linen.num_decoder_layers}") + print(f"num_query_heads: {cfg_linen.num_query_heads}") + print(f"num_kv_heads: {cfg_linen.num_kv_heads}") + print(f"vocab_size: {cfg_linen.vocab_size}") + + # Create abstract models with their respective configs + linen_vars = create_linen_model_abstract(cfg_linen, mesh) + nnx_state = create_nnx_model_abstract(cfg_nnx, mesh) + + # Compare structures + hide_rngs = not args.show_rngs + linen_normalized, nnx_normalized, linen_only, nnx_only, common = compare_tree_structures( + linen_vars, nnx_state, hide_rngs=hide_rngs + ) + + # Print tree structures + print_tree_structure(linen_normalized, "Linen (normalized)", max_depth=args.depth) + print_tree_structure(nnx_normalized, "NNX (normalized)", max_depth=args.depth) + + # Print differences + if linen_only: + print(f"\n{'='*60}") + print("Paths ONLY in Linen:") + print("=" * 60) + for path in sorted(linen_only): + _, shape, dtype = linen_normalized[path] + print(f" {path}: {shape} {dtype}") + + if nnx_only: + print(f"\n{'='*60}") + print("Paths ONLY in NNX:") 
+ print("=" * 60) + for path in sorted(nnx_only): + _, shape, dtype = nnx_normalized[path] + print(f" {path}: {shape} {dtype}") + + # Summary + print(f"\n{'='*60}") + print("SUMMARY") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Total Linen paths: {len(linen_normalized)}") + print(f"Total NNX paths: {len(nnx_normalized)}") + print(f"Common paths: {len(common)}") + print(f"Linen-only paths: {len(linen_only)}") + print(f"NNX-only paths: {len(nnx_only)}") + + if len(linen_only) == 0 and len(nnx_only) == 0: + print("\n✓ All paths match between Linen and NNX!") + else: + print("\n✗ There are differences between Linen and NNX paths") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/decoder_tree_structure_test.py b/tests/unit/decoder_tree_structure_test.py new file mode 100644 index 0000000000..4a6894026f --- /dev/null +++ b/tests/unit/decoder_tree_structure_test.py @@ -0,0 +1,403 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for verifying Linen and NNX decoder tree structure parity. + +This module tests that all supported models have identical parameter tree +structures between their Linen and NNX implementations. +""" + +import logging +import os +import sys +import unittest + +# Suppress verbose logging from MaxText modules +logging.getLogger("MaxText").setLevel(logging.WARNING) +logging.getLogger("jax").setLevel(logging.WARNING) + +# Suppress TF/XLA C++ warnings +os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import linen as nn +from flax import nnx +import pytest + +from MaxText import maxtext_utils +from MaxText import pyconfig +from MaxText.common_types import MODEL_MODE_TRAIN +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.layers import models +from MaxText.layers import quantizations + +# Get the actual MaxText package directory path by using the models module location +# MAXTEXT_PKG_DIR from globals.py returns a relative path, so we derive the absolute path +_MAXTEXT_PKG_DIR_ABS = os.path.dirname(os.path.abspath(models.__file__)) +_MAXTEXT_PKG_DIR_ABS = os.path.dirname(_MAXTEXT_PKG_DIR_ABS) # Go up from layers/ to MaxText/ + + +# All supported models for tree structure verification +SUPPORTED_MODELS = [ + # LLaMA 2 family + "llama2-7b", + "llama2-13b", + "llama2-70b", + # LLaMA 3 family + "llama3-8b", + "llama3-70b", + # LLaMA 3.1 family + "llama3.1-8b", + "llama3.1-70b", + "llama3.1-405b", + # LLaMA 3.3 family + "llama3.3-70b", + # Mistral family + "mistral-7b", + # Mixtral family + "mixtral-8x7b", + "mixtral-8x22b", + # DeepSeek 2 family + "deepseek2-16b", + "deepseek2-236b", + # DeepSeek 3 family + "deepseek3-671b", + "deepseek3-671b-2dfsdp", + "deepseek3-test", + "deepseek3-tiny", + # Kimi family + "kimi-k2-1t", + # Gemma family + "gemma-7b", + "gemma-2b", + # Gemma 2 family + "gemma2-2b", + "gemma2-9b", + "gemma2-27b", + # Gemma 3 family + "gemma3-4b", + "gemma3-12b", + 
"gemma3-27b", + # Qwen 3 family + "qwen3-0.6b", + "qwen3-4b", + "qwen3-4b-thinking-2507", + "qwen3-8b", + "qwen3-14b", + "qwen3-32b", + "qwen3-235b-a22b", + "qwen3-30b-a3b", + "qwen3-480b-a35b", + "qwen3-next-80b-a3b", + "qwen3-omni-30b-a3b", + # GPT-3 family + "gpt3-175b", + "gpt3-22b", + "gpt3-6b", + "gpt3-52k", + # GPT-OSS family + "gpt-oss-20b", + "gpt-oss-120b", + # LLaMA 4 family + "llama4-17b-16e", + "llama4-17b-128e", +] + + +def is_rng_path(path: str) -> bool: + """Check if a path is RNG-related.""" + return "/rngs/" in path or path.startswith("rngs/") + + +def extract_linen_paths(vars_dict): + """Extract paths from Linen variables dict using JAX tree utilities.""" + paths = [] + leaves_with_paths = jax.tree_util.tree_leaves_with_path(vars_dict) + + for path_parts, leaf in leaves_with_paths: + path_str = "" + for part in path_parts: + if hasattr(part, "key"): + if path_str: + path_str += "/" + str(part.key) + else: + path_str = str(part.key) + elif hasattr(part, "idx"): + path_str += f"[{part.idx}]" + elif isinstance(part, str): + if path_str: + path_str += "/" + part + else: + path_str = part + else: + if path_str: + path_str += "/" + str(part) + else: + path_str = str(part) + + if hasattr(leaf, "shape"): + paths.append((path_str, leaf.shape, str(leaf.dtype))) + else: + paths.append((path_str, type(leaf).__name__, "")) + + return paths + + +def extract_nnx_paths(state): + """Extract paths from NNX state using JAX tree utilities.""" + paths = [] + leaves_with_paths = jax.tree_util.tree_leaves_with_path(state) + + for path_parts, leaf in leaves_with_paths: + path_str = "" + for part in path_parts: + if hasattr(part, "key"): + if path_str: + path_str += "/" + str(part.key) + else: + path_str = str(part.key) + elif hasattr(part, "idx"): + path_str += f"[{part.idx}]" + elif isinstance(part, str): + if path_str: + path_str += "/" + part + else: + path_str = part + else: + if path_str: + path_str += "/" + str(part) + else: + path_str = str(part) + + if hasattr(leaf, "shape"): + paths.append((path_str, leaf.shape, str(leaf.dtype))) + elif hasattr(leaf, "value") and hasattr(leaf.value, "shape"): + paths.append((path_str, leaf.value.shape, str(leaf.value.dtype))) + else: + paths.append((path_str, type(leaf).__name__, "")) + + return paths + + +def normalize_path(path: str, is_linen: bool = False) -> str: + """Normalize a path for comparison. + + Linen format: params/params/decoder/layers/0/mlp/wi_0/kernel + NNX format: decoder/layers/0/mlp/wi_0/kernel + + This removes the double 'params' prefix from Linen paths. 
+ """ + if is_linen and path.startswith("params/params/"): + path = path[len("params/params/") :] + elif is_linen and path.startswith("params/"): + path = path[len("params/") :] + return path + + +def create_linen_model_abstract(cfg, mesh): + """Create a Linen model and get its abstract parameter structure.""" + quant = quantizations.configure_quantization(cfg) + model = models.transformer_as_linen(config=cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + + batch_size = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + + rng = jax.random.PRNGKey(0) + dummy_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + dummy_positions = jnp.stack([jnp.arange(seq_len, dtype=jnp.int32) for _ in range(batch_size)]) + dummy_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + def init_fn(): + return model.init( + {"params": rng, "aqt": rng, "dropout": rng}, + dummy_tokens, + dummy_positions, + dummy_segment_ids, + enable_dropout=False, + ) + + with mesh: + with nn.logical_axis_rules(cfg.logical_axis_rules): + abstract_vars = jax.eval_shape(init_fn) + + return abstract_vars + + +def create_nnx_model_abstract(cfg, mesh): + """Create an NNX model and get its abstract parameter structure.""" + quant = quantizations.configure_quantization(cfg) + + def create_model(): + rng = jax.random.PRNGKey(0) + params_rng, dropout_rng = jax.random.split(rng) + rngs = nnx.Rngs(params=params_rng, dropout=dropout_rng) + return models.Transformer(cfg, mesh, quant=quant, rngs=rngs, model_mode=MODEL_MODE_TRAIN) + + with mesh: + with nn.logical_axis_rules(cfg.logical_axis_rules): + abstract_model = nnx.eval_shape(create_model) + + _, abstract_state = nnx.split(abstract_model) + return abstract_state + + +def compare_tree_structures(linen_vars, nnx_state, hide_rngs: bool = True): + """Compare the tree structures of Linen and NNX models. 
+ + Args: + linen_vars: Linen model variables (from model.init) + nnx_state: NNX model state (from nnx.split) + hide_rngs: If True, filter out RNG-related paths from comparison + + Returns: + Tuple of (linen_only, nnx_only, shape_mismatches) where: + - linen_only: Set of paths only in Linen + - nnx_only: Set of paths only in NNX + - shape_mismatches: List of (path, linen_shape, nnx_shape) tuples + """ + linen_paths = extract_linen_paths(linen_vars) + nnx_paths = extract_nnx_paths(nnx_state) + + if hide_rngs: + linen_paths = [(p, s, d) for p, s, d in linen_paths if not is_rng_path(p)] + nnx_paths = [(p, s, d) for p, s, d in nnx_paths if not is_rng_path(p)] + + # Normalize paths for comparison + linen_normalized = {} + for path, shape, dtype in linen_paths: + norm_path = normalize_path(path, is_linen=True) + linen_normalized[norm_path] = (path, shape, dtype) + + nnx_normalized = {} + for path, shape, dtype in nnx_paths: + norm_path = path + nnx_normalized[norm_path] = (path, shape, dtype) + + # Find matches and mismatches + linen_only = set(linen_normalized.keys()) - set(nnx_normalized.keys()) + nnx_only = set(nnx_normalized.keys()) - set(linen_normalized.keys()) + common = set(linen_normalized.keys()) & set(nnx_normalized.keys()) + + # Check for shape mismatches in common paths + shape_mismatches = [] + for path in common: + linen_shape = linen_normalized[path][1] + nnx_shape = nnx_normalized[path][1] + if linen_shape != nnx_shape: + shape_mismatches.append((path, linen_shape, nnx_shape)) + + return linen_only, nnx_only, shape_mismatches + + +class TestDecoderTreeStructure(unittest.TestCase): + """Test that Linen and NNX decoders have identical tree structures.""" + + def _check_model_config_exists(self, model_name: str) -> bool: + """Check if a model config file exists.""" + model_config = os.path.join(_MAXTEXT_PKG_DIR_ABS, "configs", "models", f"{model_name}.yml") + return os.path.exists(model_config) + + def _create_configs_and_mesh(self, model_name: str): + """Create Linen and NNX configs and mesh for a model.""" + # Create config for Linen model (uses Linen Decoder) + cfg_linen = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + model_name=model_name, + per_device_batch_size=1.0, + run_name="tree_compare_test", + enable_checkpointing=False, + max_target_length=32, + attention="dot_product", + pure_nnx_decoder=False, + ) + + # Create config for NNX model (uses NNX Decoder) + cfg_nnx = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + model_name=model_name, + per_device_batch_size=1.0, + run_name="tree_compare_test", + enable_checkpointing=False, + max_target_length=32, + attention="dot_product", + pure_nnx_decoder=True, + ) + + devices_array = maxtext_utils.create_device_mesh(cfg_linen) + mesh = Mesh(devices_array, cfg_linen.mesh_axes) + + return cfg_linen, cfg_nnx, mesh + + def verify_tree_structure_match(self, model_name: str): + """Verify that Linen and NNX tree structures match for a model.""" + if not self._check_model_config_exists(model_name): + self.skipTest(f"Model config not found: {model_name}") + + cfg_linen, cfg_nnx, mesh = self._create_configs_and_mesh(model_name) + + # Create abstract models + linen_vars = create_linen_model_abstract(cfg_linen, mesh) + nnx_state = create_nnx_model_abstract(cfg_nnx, mesh) + + # Compare structures (hide RNG paths by default) + linen_only, nnx_only, shape_mismatches = compare_tree_structures(linen_vars, nnx_state, hide_rngs=True) + + # Build error message 
+    error_messages = []
+
+    if linen_only:
+      error_messages.append(f"Paths only in Linen ({len(linen_only)}):")
+      for path in sorted(linen_only):
+        error_messages.append(f" {path}")
+
+    if nnx_only:
+      error_messages.append(f"Paths only in NNX ({len(nnx_only)}):")
+      for path in sorted(nnx_only):
+        error_messages.append(f" {path}")
+
+    if shape_mismatches:
+      error_messages.append(f"Shape mismatches ({len(shape_mismatches)}):")
+      for path, linen_shape, nnx_shape in shape_mismatches:
+        error_messages.append(f" {path}: Linen={linen_shape}, NNX={nnx_shape}")
+
+    # Assert no differences
+    self.assertEqual(
+        len(linen_only),
+        0,
+        f"Model {model_name}: Found paths only in Linen\n" + "\n".join(error_messages),
+    )
+    self.assertEqual(
+        len(nnx_only),
+        0,
+        f"Model {model_name}: Found paths only in NNX\n" + "\n".join(error_messages),
+    )
+    self.assertEqual(
+        len(shape_mismatches),
+        0,
+        f"Model {model_name}: Found shape mismatches\n" + "\n".join(error_messages),
+    )
+
+
+# Parametrize the structure check over each supported model.
+@pytest.mark.parametrize("model_name", SUPPORTED_MODELS)
+def test_linen_nnx_tree_structure_match(model_name):
+  """Test that Linen and NNX tree structures match for a model."""
+  test_instance = TestDecoderTreeStructure()
+  test_instance.verify_tree_structure_match(model_name)
+
+
+if __name__ == "__main__":
+  pytest.main([__file__, "-v"])

From a6fee98cb5a44b25a6d5d9ccf6a73c3ef64010bc Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Thu, 29 Jan 2026 17:26:21 +0000
Subject: [PATCH 3/5] refactor code

---
 tests/unit/decoder_tree_structure_test.py | 30 ++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/tests/unit/decoder_tree_structure_test.py b/tests/unit/decoder_tree_structure_test.py
index 4a6894026f..84ba8d0509 100644
--- a/tests/unit/decoder_tree_structure_test.py
+++ b/tests/unit/decoder_tree_structure_test.py
@@ -207,6 +207,31 @@ def normalize_path(path: str, is_linen: bool = False) -> str:
   return path
 
 
+def transpose_nnx_shape_for_scanned_layers(path: str, nnx_shape: tuple) -> tuple:
+  """Transpose an NNX shape for scanned layers to match Linen's axis ordering.
+
+  When scan_layers=True:
+  - NNX with nnx.vmap puts the layer dimension at axis 0
+  - Linen with nn.scan puts the layer dimension at axis 1
+
+  For paths containing 'layers' with 2+ dimensions, we swap axes 0 and 1.
+  Example: NNX (32, 4096) -> (4096, 32) to match Linen.
+
+  Args:
+    path: The parameter path string
+    nnx_shape: The NNX parameter shape tuple
+
+  Returns:
+    The transposed shape if applicable, otherwise the original shape
+  """
+  # Only transpose layer parameters with 2+ dimensions
+  if "layers" in path and isinstance(nnx_shape, tuple) and len(nnx_shape) >= 2:
+    # Swap axes 0 and 1: (0, 1, 2, ...) -> (1, 0, 2, ...)
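+    # Illustrative only (hypothetical shapes): an NNX-stacked kernel of shape
+    # (num_layers, emb, mlp) would be reported by Linen's nn.scan as
+    # (emb, num_layers, mlp); only the first two axes swap, and any trailing
+    # axes are untouched.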
+    transposed = (nnx_shape[1], nnx_shape[0]) + nnx_shape[2:]
+    return transposed
+  return nnx_shape
+
+
 def create_linen_model_abstract(cfg, mesh):
   """Create a Linen model and get its abstract parameter structure."""
   quant = quantizations.configure_quantization(cfg)
@@ -296,7 +321,10 @@ def compare_tree_structures(linen_vars, nnx_state, hide_rngs: bool = True):
   for path in common:
     linen_shape = linen_normalized[path][1]
     nnx_shape = nnx_normalized[path][1]
-    if linen_shape != nnx_shape:
+    # Apply the transpose for scanned layers (NNX vmap puts the layer dim at
+    # axis 0, Linen scan puts it at axis 1)
+    nnx_shape_normalized = transpose_nnx_shape_for_scanned_layers(path, nnx_shape)
+    if linen_shape != nnx_shape_normalized:
       shape_mismatches.append((path, linen_shape, nnx_shape))
 
   return linen_only, nnx_only, shape_mismatches

From 161c44f680c56fafb434f827882eb43632104040 Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Thu, 29 Jan 2026 17:44:05 +0000
Subject: [PATCH 4/5] refactor code

---
 src/MaxText/compare_linen_nnx_tree.py     | 2 +-
 tests/unit/decoder_tree_structure_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/MaxText/compare_linen_nnx_tree.py b/src/MaxText/compare_linen_nnx_tree.py
index be13505258..7acc00f3ed 100644
--- a/src/MaxText/compare_linen_nnx_tree.py
+++ b/src/MaxText/compare_linen_nnx_tree.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/unit/decoder_tree_structure_test.py b/tests/unit/decoder_tree_structure_test.py
index 84ba8d0509..0a9b5a3452 100644
--- a/tests/unit/decoder_tree_structure_test.py
+++ b/tests/unit/decoder_tree_structure_test.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2025 Google LLC
+# Copyright 2023-2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

From e4dcc6485455425f3136b6b921b068b695f8ede3 Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Fri, 30 Jan 2026 01:38:33 +0000
Subject: [PATCH 5/5] refactor code

---
 tests/unit/test_raw_params_simulation.py | 136 +++++++++++++++++++++++
 1 file changed, 136 insertions(+)
 create mode 100644 tests/unit/test_raw_params_simulation.py

diff --git a/tests/unit/test_raw_params_simulation.py b/tests/unit/test_raw_params_simulation.py
new file mode 100644
index 0000000000..bd1a45b420
--- /dev/null
+++ b/tests/unit/test_raw_params_simulation.py
@@ -0,0 +1,136 @@
+# Copyright 2023-2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests demonstrating how to simulate raw_params loading with Orbax.
+
+raw_params is returned by load_state_if_possible() when loading from a
+parameter-only checkpoint (via load_params_from_path). It is used to merge
+the loaded params into a freshly initialized training state.
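+
+The tests below simulate that flow in miniature: they use small in-memory
+pytrees and temporary directories, so they run on a single host without a
+real MaxText checkpoint.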
+""" + +import tempfile +import jax +import jax.numpy as jnp +from etils import epath +import orbax.checkpoint as ocp +import pytest + + +class TestRawParamsSimulation: + """Tests for simulating raw_params checkpoint loading.""" + + def test_save_and_load_params_only_checkpoint(self): + """Test saving and loading a params-only checkpoint (raw_params pattern).""" + # Create mock params (model params pytree) + params = { + "decoder": { + "layers": { + "mlp": {"kernel": jnp.ones((16, 32)), "bias": jnp.zeros((32,))}, + "attention": {"query": jnp.zeros((8, 16)), "key": jnp.ones((8, 16))}, + } + } + } + + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = epath.Path(tmpdir) / "params_ckpt" + + # Save params-only checkpoint (mimics save_params_to_path) + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(ckpt_path, {"params": params}, force=True) + + # Load params (mimics load_params_from_path -> raw_params) + abstract_params = jax.tree.map( + lambda x: ocp.utils.to_shape_dtype_struct(x, x.dtype), + params, + ) + + restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_params) + restored = ckptr.restore( + ckpt_path, + item={"params": abstract_params}, + restore_args={"params": restore_args}, + ) + raw_params = restored["params"] + + # Verify restored params match original + assert raw_params["decoder"]["layers"]["mlp"]["kernel"].shape == (16, 32) + assert raw_params["decoder"]["layers"]["mlp"]["bias"].shape == (32,) + assert jnp.allclose(raw_params["decoder"]["layers"]["mlp"]["kernel"], jnp.ones((16, 32))) + assert jnp.allclose(raw_params["decoder"]["layers"]["attention"]["query"], jnp.zeros((8, 16))) + + def test_merge_raw_params_into_state(self): + """Test merging raw_params into a fresh state (like maxtext_utils.py:1037-1038).""" + # Simulate fresh initialized params + fresh_params = { + "decoder": { + "layers": { + "mlp": {"kernel": jnp.zeros((16, 32))}, # Fresh init (zeros) + } + } + } + + # Simulate raw_params loaded from checkpoint + raw_params = { + "decoder": { + "layers": { + "mlp": {"kernel": jnp.ones((16, 32))}, # Loaded (ones) + } + } + } + + # Merge: state = state.replace(params=raw_params) + # In practice this replaces the params in a TrainState dataclass + merged_params = raw_params + + # Verify merge replaced fresh params with loaded params + assert jnp.allclose(merged_params["decoder"]["layers"]["mlp"]["kernel"], jnp.ones((16, 32))) + + def test_load_params_with_sharding(self): + """Test loading params with explicit sharding (single device case).""" + params = {"layer": {"weights": jnp.ones((8, 8))}} + + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = epath.Path(tmpdir) / "sharded_params_ckpt" + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(ckpt_path, {"params": params}, force=True) + + # Create abstract params with sharding info + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + pspec = jax.sharding.PartitionSpec() # Replicated + sharding = jax.sharding.NamedSharding(mesh, pspec) + + def create_restore_args(x): + return ocp.type_handlers.ArrayRestoreArgs(sharding=sharding) + + abstract_params = jax.tree.map( + lambda x: ocp.utils.to_shape_dtype_struct(x, x.dtype), + params, + ) + restore_args = jax.tree.map(create_restore_args, abstract_params) + + restored = ckptr.restore( + ckpt_path, + item={"params": abstract_params}, + restore_args={"params": restore_args}, + ) + raw_params = restored["params"] + + assert raw_params["layer"]["weights"].shape == (8, 8) + assert jnp.allclose(raw_params["layer"]["weights"], jnp.ones((8, 8))) + + +if 
+  pytest.main([__file__, "-v"])
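+
+# For reference, a minimal sketch of the real merge step these tests simulate,
+# assuming a flax.training.train_state.TrainState-style container (MaxText's
+# actual state class may differ; `model`, `optimizer`, and `fresh_params` are
+# hypothetical names):
+#
+#   from flax.training import train_state
+#   state = train_state.TrainState.create(apply_fn=model.apply, params=fresh_params, tx=optimizer)
+#   if raw_params is not None:
+#     state = state.replace(params=raw_params)  # overwrite the freshly initialized params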