Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
460 changes: 460 additions & 0 deletions src/MaxText/compare_linen_nnx_tree.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ class HardwareAndMesh(BaseModel):
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")


class LayoutAndSharding(BaseModel):
Expand Down
22 changes: 12 additions & 10 deletions src/MaxText/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from MaxText import max_utils
from MaxText.layers import nnx_wrappers
from MaxText.layers.decoders import Decoder
from MaxText.layers.nnx_decoders import NNXDecoder, decoder_as_linen
from MaxText.layers.embeddings import Embed, embed_as_linen
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen
from MaxText.layers.quantizations import AqtQuantization as Quant
Expand Down Expand Up @@ -86,7 +87,13 @@ def setup(self):
)
self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
if cfg.pure_nnx_decoder:
self.decoder = decoder_as_linen(
config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0)
)
else:
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)

# If MTP is enabled via config, set up the MTP block.
if self.config.mtp_num_layers > 0:
# Get the list of layer blueprints for the current model.
Expand Down Expand Up @@ -334,9 +341,11 @@ def __init__(
)
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None
if cfg.pure_nnx_decoder:
self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
else:
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)

decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
self.hidden_states = None

batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
Expand All @@ -362,13 +371,6 @@ def __init__(
else:
dummy_attention_metadata = None

self.decoder.lazy_init(
shared_embedding=self.token_embedder,
decoder_input_tokens=dummy_decoder_input_tokens,
decoder_positions=dummy_decoder_positions,
attention_metadata=dummy_attention_metadata,
)

# If MTP is enabled via config, set up the MTP block.
if self.config.mtp_num_layers > 0:
# Get the list of layer blueprints for the current model.
Expand Down
22 changes: 16 additions & 6 deletions src/MaxText/layers/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,22 @@ def __init__(
rngs=rngs,
)
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
mtp_transformer_layer = transformer_layer_module(
config=cfg,
mesh=mesh,
model_mode=MODEL_MODE_TRAIN,
name=f"mtp_{k}_transformer_layer",
)
if cfg.pure_nnx_decoder:
mtp_transformer_layer = transformer_layer_module(
config=cfg,
mesh=mesh,
model_mode=MODEL_MODE_TRAIN,
name=f"mtp_{k}_transformer_layer",
rngs=rngs,
)
else:
mtp_transformer_layer = transformer_layer_module(
config=cfg,
mesh=mesh,
model_mode=MODEL_MODE_TRAIN,
name=f"mtp_{k}_transformer_layer",
)

self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)

# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
Expand Down
Loading
Loading