diff --git a/.github/workflows/run_pathways_tests.yml b/.github/workflows/run_pathways_tests.yml
index 6ff4b7b355..08ab9eab32 100644
--- a/.github/workflows/run_pathways_tests.yml
+++ b/.github/workflows/run_pathways_tests.yml
@@ -98,7 +98,7 @@ jobs:
FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
fi
export MAXTEXT_REPO_ROOT=$(pwd)
- export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
+ export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml
index eae8b0f55c..7ad07d1c17 100644
--- a/.github/workflows/run_tests_against_package.yml
+++ b/.github/workflows/run_tests_against_package.yml
@@ -108,7 +108,7 @@ jobs:
fi
# TODO: Use package data for testing and remove the env vars
export MAXTEXT_REPO_ROOT=$(pwd)
- export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
+ export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
# omit this libtpu init args for gpu tests
diff --git a/.vscode/launch.json b/.vscode/launch.json
index c0d04607f2..76a2e505ec 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -15,7 +15,7 @@
"dataset_path=gs://test-maxtext-dataset",
"model_name=llama2-7b",
"load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_",
- "tokenizer_path=src/MaxText/assets/tokenizer.llama2",
+ "tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
"per_device_batch_size=8",
"max_prefill_predict_length=8",
"max_target_length=20",
@@ -70,7 +70,7 @@
"args": [
"src/MaxText/configs/base.yml",
"model_name=llama2-7b",
- "tokenizer_path=src/MaxText/assets/tokenizer.llama2",
+ "tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
"weight_dtype=bfloat16",
"scan_layers=false",
"attention=dot_product",
diff --git a/benchmarks/globals.py b/benchmarks/globals.py
index ba3a625b72..ab23984a7a 100644
--- a/benchmarks/globals.py
+++ b/benchmarks/globals.py
@@ -25,7 +25,7 @@
r if os.path.isdir(os.path.join(r := os.path.dirname(os.path.dirname(__file__)), ".git")) else MAXTEXT_PKG_DIR,
)
-# This is the assets root: with "tokenizer.gemma3"; &etc.
-MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_PKG_DIR, "assets"))
+# This is the assets root: with the "tokenizers/" subdirectory, etc.
+MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"))
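+# e.g., benchmark configs resolve tokenizer files beneath this root:
+# os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2")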
__all__ = ["MAXTEXT_ASSETS_ROOT", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT"]
diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py
index d1e8f28fff..4950c8f57b 100644
--- a/benchmarks/maxtext_trillium_model_configs.py
+++ b/benchmarks/maxtext_trillium_model_configs.py
@@ -544,7 +544,7 @@
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1280,7 +1280,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1336,7 +1336,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1517,7 +1517,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.25,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
},
xla_flags=(
xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1552,7 +1552,7 @@
"sparse_matmul": False,
"capacity_factor": 1.25,
"quantization": "int8",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
},
xla_flags=(
xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1593,7 +1593,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.25,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
"dtype": "bfloat16",
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
@@ -1634,7 +1634,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.0,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
"dtype": "bfloat16",
"opt_type": "sgd",
"weight_dtype": "bfloat16",
@@ -1667,7 +1667,7 @@
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 2048,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1700,7 +1700,7 @@
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 2048,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1739,7 +1739,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
- "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+ "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1779,7 +1779,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
- "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+ "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1819,7 +1819,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
- "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+ "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1868,7 +1868,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
"packing": False,
},
xla_flags=(
@@ -1933,7 +1933,7 @@
"sa_use_fused_bwd_kernel": True,
"sparse_matmul": False,
"capacity_factor": 1.5,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
"dtype": "bfloat16",
"weight_dtype": "bfloat16",
"opt_type": "sgd",
diff --git a/benchmarks/maxtext_v5e_model_configs.py b/benchmarks/maxtext_v5e_model_configs.py
index 445cdf0abc..1e977f533c 100644
--- a/benchmarks/maxtext_v5e_model_configs.py
+++ b/benchmarks/maxtext_v5e_model_configs.py
@@ -149,7 +149,7 @@
"remat_policy": "save_qkv_proj",
"max_target_length": 2048,
"use_iota_embed": True,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
@@ -171,7 +171,7 @@
"remat_policy": "qkv_proj_offloaded",
"max_target_length": 2048,
"use_iota_embed": True,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
diff --git a/benchmarks/maxtext_v5p_model_configs.py b/benchmarks/maxtext_v5p_model_configs.py
index ae9563fd10..f228b0f7fc 100644
--- a/benchmarks/maxtext_v5p_model_configs.py
+++ b/benchmarks/maxtext_v5p_model_configs.py
@@ -227,7 +227,7 @@
"remat_policy": "minimal",
"max_target_length": 4096,
"use_iota_embed": True,
- "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+ "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
diff --git a/docs/guides/data_input_pipeline/data_input_tfds.md b/docs/guides/data_input_pipeline/data_input_tfds.md
index a2df3dcae0..03c38e9838 100644
--- a/docs/guides/data_input_pipeline/data_input_tfds.md
+++ b/docs/guides/data_input_pipeline/data_input_tfds.md
@@ -16,5 +16,5 @@ eval_interval: 10000
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
# TFDS input pipeline only supports tokenizer in spm format
-tokenizer_path: 'src/MaxText/assets/tokenizer.llama2'
+tokenizer_path: 'src/maxtext/assets/tokenizers/tokenizer.llama2'
```
diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md
index e845bd1ffe..bcf7a35b75 100644
--- a/docs/tutorials/posttraining/multimodal.md
+++ b/docs/tutorials/posttraining/multimodal.md
@@ -1,20 +1,21 @@
-
-
# Multimodal support
This document provides a guide to use the multimodal functionalities in MaxText including:
+
- **Checkpoint Conversion**: Convert a MaxText-compatible orbax checkpoint from HuggingFace.
- **Multimodal Decode**: Inference with text+images as input.
- **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.
We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support:
-| Models | Input Modalities | Output Modalities |
-| :---- | :---- | :---- |
-| - Gemma3-4B/12B/27B
- Llama4-Scout/Maverick | Text, images | Text |
+
+| Models | Input Modalities | Output Modalities |
+| :--------------------------------------------- | :--------------- | :---------------- |
+| - Gemma3-4B/12B/27B
- Llama4-Scout/Maverick | Text, images | Text |
## Introduction
-Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:
+Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:
+
- **Data Preprocessing**: We apply modality-specific preprocessing steps to prepare the raw input data (e.g., image resizing and normalization), transforming them into a format which neural networks can understand.
- **Modality-Specific Encoders**: Modality-specific encoders will transform the preprocessed data into high-dimensional representations (e.g., vision transformers for images).
- **Projection and Merge**: Projection layers will map these modality-specific embeddings into the shared embedding space of the language model, usually aligned with the dimension of text embeddings. These projected embeddings are then merged with text token embeddings, allowing the unified model to process and reason over multiple modalities simultaneously within a single coherent framework.
@@ -22,12 +23,12 @@ Multimodal Large Language Models (LLMs) extend traditional text-only models by i

*Figure 1: Overview of multimodal dataflow in MaxText.*
-
## Checkpoint Conversion
Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md)).
Install pytorch:
+
```
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```
@@ -58,7 +59,9 @@ python -m MaxText.utils.ckpt_scripts.llama4_ckpt_unscanned \
```
## Multimodal Decode
+
MaxText supports multimodal decoding, allowing you to input text with multiple images to get a text output. To use this feature, you need three main settings:
+
- `use_multimodal=True`: Initializes the multimodal preprocessing steps and network components.
- `prompt`: Specifies the position of image placeholder tokens in your input. If you don't manually place them, MaxText will automatically append the required placeholder (e.g., `<start_of_image>` for Gemma3, `<|image|>` for Llama4). The exact placeholder is listed under the `image_placeholder` field in each model's configuration file.
- `image_path`: The path(s) to the image file(s) MaxText will load and process.
@@ -73,7 +76,7 @@ python -m MaxText.decode \
MaxText/configs/base.yml \
model_name=gemma3-4b \
hf_access_token=$HF_ACCESS_TOKEN \
- tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
+ tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \
load_parameters_path=$MAXTEXT_CKPT_GCS_PATH/0/items \
per_device_batch_size=1 \
run_name=ht_test \
@@ -89,6 +92,7 @@ python -m MaxText.decode \
```
The decoding results will look like this:
+
```
Input `<start_of_turn>user
Describe image
@@ -123,7 +127,6 @@ Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically
Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality:
-
```shell
export UNSCANNED_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step
python -m MaxText.sft_trainer \
@@ -148,14 +151,16 @@ python -m MaxText.sft_trainer \
```
## Other Recommendations
+
- **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend to estimate the combined sequence length from your full input and then add a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules:
- - For text tokens, a good estimate is:
-
- $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
- - For Gemma3, each image is resized to 896*896 and contributes 256 tokens:
-
- $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
- - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
-
- $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
+ - For text tokens, a good estimate is:
+
+ $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
+
+ - For Gemma3, each image is resized to 896\*896 and contributes 256 tokens:
+
+ $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
+
+ - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
+
+ $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
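+
+As a rough worked example (assuming a hypothetical 100-word prompt with two Gemma3 images): $\text{Text Tokens} \approx 1.3 \times 100 = 130$ and $\text{Total Tokens} \approx 130 + 2 \times 256 = 642$, so setting `max_prefill_predict_length=1024` leaves a comfortable buffer.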
diff --git a/end_to_end/gpu/a3/test_gemma3_logits.sh b/end_to_end/gpu/a3/test_gemma3_logits.sh
index f92eb2e00a..e5ee235c6c 100644
--- a/end_to_end/gpu/a3/test_gemma3_logits.sh
+++ b/end_to_end/gpu/a3/test_gemma3_logits.sh
@@ -44,5 +44,5 @@ python3 -m MaxText.utils.ckpt_scripts.convert_gemma3_chkpt --base_model_path ${C
export UNSCANNED_CKPT_PATH=gs://runner-maxtext-logs/unscanned_chkpt_2025-04-16-00-01/checkpoints/0/items
export NVTE_FUSED_ATTN=1
# # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
diff --git a/end_to_end/gpu/mixtral/test_8x7b.sh b/end_to_end/gpu/mixtral/test_8x7b.sh
index ece8f5f600..d665e8a0c6 100644
--- a/end_to_end/gpu/mixtral/test_8x7b.sh
+++ b/end_to_end/gpu/mixtral/test_8x7b.sh
@@ -31,7 +31,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
attention=cudnn_flash_te capacity_factor=1.25 dtype=bfloat16 \
enable_checkpointing=false ici_expert_parallelism=-1 ici_fsdp_parallelism=1 \
max_target_length=1024 megablox=False per_device_batch_size=1 \
- reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 \
+ reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 \
weight_dtype=bfloat16 sparse_matmul=False packing=False
echo "Finished pre-training"
@@ -43,7 +43,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
attention=cudnn_flash_te capacity_factor=1.25 dtype=bfloat16 \
ici_expert_parallelism=-1 ici_fsdp_parallelism=1 \
max_target_length=1024 megablox=False per_device_batch_size=1 \
- reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 \
+ reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 \
weight_dtype=bfloat16 sparse_matmul=False packing=False
echo "Finished fine-tuning"
@@ -55,7 +55,7 @@ echo "Finished fine-tuning"
# ici_expert_parallelism=8 ici_fsdp_parallelism=1 max_prefill_predict_length=11 \
# max_target_length=24 megablox=False per_device_batch_size=1 \
# prompt='"[INST] I love to [/INST]"' scan_layers=false \
-# tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1
+# tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1
# echo "Finished decoding"
diff --git a/end_to_end/tpu/gemma/2b/test_gemma.sh b/end_to_end/tpu/gemma/2b/test_gemma.sh
index a8b68795ab..8e10a322b5 100644
--- a/end_to_end/tpu/gemma/2b/test_gemma.sh
+++ b/end_to_end/tpu/gemma/2b/test_gemma.sh
@@ -39,17 +39,16 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to"
# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5
-
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5
# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b
# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
# So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
@@ -58,10 +57,10 @@ export PARAM_RUN_NAME=param_chkpt_${idx}
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true
# Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"
# We also test whether the forward pass logits match the golden logits for Gemma-2b
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.015
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.015
# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B.
diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
index 7979a28af6..261e4fb6ab 100644
--- a/end_to_end/tpu/gemma/7b/2_test_gemma.sh
+++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
@@ -39,17 +39,16 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
export ASYNC_CHECKPOINTING=True # True so that the jax distributed system is initialized
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to"
# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b checkpoint_period=5
-
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b checkpoint_period=5
# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b
# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
# So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
@@ -58,7 +57,7 @@ export PARAM_RUN_NAME=param_chkpt
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true
# Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B.
diff --git a/end_to_end/tpu/gemma2/27b/2_test_gemma.sh b/end_to_end/tpu/gemma2/27b/2_test_gemma.sh
index 9f9d6a1ba5..c4429ddc60 100644
--- a/end_to_end/tpu/gemma2/27b/2_test_gemma.sh
+++ b/end_to_end/tpu/gemma2/27b/2_test_gemma.sh
@@ -39,11 +39,11 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to"
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to"
# We also test whether the forward pass logits match the golden logits for Gemma2-27b
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_27b per_device_batch_size=1 model_name=gemma2-27b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_27b per_device_batch_size=1 model_name=gemma2-27b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15
\ No newline at end of file
diff --git a/end_to_end/tpu/gemma2/2b/test_gemma2.sh b/end_to_end/tpu/gemma2/2b/test_gemma2.sh
index ba4e45530b..af86ea247b 100644
--- a/end_to_end/tpu/gemma2/2b/test_gemma2.sh
+++ b/end_to_end/tpu/gemma2/2b/test_gemma2.sh
@@ -41,17 +41,17 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to"
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to"
# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma2-2b checkpoint_period=5
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma2-2b checkpoint_period=5
# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma2-2b
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma2-2b
# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
# So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
@@ -60,8 +60,8 @@ export PARAM_RUN_NAME=param_chkpt_${idx}
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma2-2b' force_unroll=true
# Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to"
# We also test whether the forward pass logits match the golden logits for Gemma2-2b
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_2b per_device_batch_size=1 model_name=gemma2-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_2b per_device_batch_size=1 model_name=gemma2-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0
\ No newline at end of file
diff --git a/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh b/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh
index 67f02a417f..0701ef2c46 100644
--- a/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh
+++ b/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh
@@ -15,7 +15,7 @@ set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma2-2b'
export MODEL_VARIATION='2b'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.gemma'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma'
# Installing torch for deps in forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
diff --git a/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh b/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh
index d7f9f521c3..32469bfe68 100644
--- a/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh
+++ b/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh
@@ -18,7 +18,7 @@ idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma2-2b'
export MODEL_VARIATION='2b'
HF_GOLDEN_MODEL='google/gemma-2-2b'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.gemma'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma'
# Installing torch for deps in forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
diff --git a/end_to_end/tpu/gemma2/9b/2_test_gemma.sh b/end_to_end/tpu/gemma2/9b/2_test_gemma.sh
index dfd2c54b50..1db075e973 100644
--- a/end_to_end/tpu/gemma2/9b/2_test_gemma.sh
+++ b/end_to_end/tpu/gemma2/9b/2_test_gemma.sh
@@ -40,11 +40,11 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to"
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to"
# We also test whether the forward pass logits match the golden logits for Gemma2-9b
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_9b per_device_batch_size=1 model_name=gemma2-9b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15
\ No newline at end of file
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_9b per_device_batch_size=1 model_name=gemma2-9b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15
\ No newline at end of file
diff --git a/end_to_end/tpu/gemma3/12b/test_gemma3.sh b/end_to_end/tpu/gemma3/12b/test_gemma3.sh
index 10a4e7372e..b9500dbd63 100644
--- a/end_to_end/tpu/gemma3/12b/test_gemma3.sh
+++ b/end_to_end/tpu/gemma3/12b/test_gemma3.sh
@@ -43,10 +43,10 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
# To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
# Finetune using the scanned, converted checkpoint by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint.
# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-12b/2025-03-19-21-16/scanned/0/items
@@ -55,4 +55,4 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
# We also run pre-training; this is similar to the finetuning command, except we don't pass any checkpoint directory to load_parameters_path
PRETRAIN_RUN_NAME=runner_pretrain_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
diff --git a/end_to_end/tpu/gemma3/27b/test_gemma3.sh b/end_to_end/tpu/gemma3/27b/test_gemma3.sh
index f3ddf8e74a..388522be6b 100644
--- a/end_to_end/tpu/gemma3/27b/test_gemma3.sh
+++ b/end_to_end/tpu/gemma3/27b/test_gemma3.sh
@@ -43,16 +43,16 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it
# We run decoding on the `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint, for efficient decoding. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
# To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
# Finetune using the scanned, converted checkpoint by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint.
# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-27b/2025-03-20-00-12/scanned/0/items
FINETUNE_RUN_NAME=runner_finetune_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03
# We also run pre-training; this is similar to the finetuning command, except we don't pass any checkpoint directory to load_parameters_path
PRETRAIN_RUN_NAME=runner_pretrain_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
\ No newline at end of file
diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3.sh b/end_to_end/tpu/gemma3/4b/test_gemma3.sh
index f8da8ce5da..80007dad9a 100644
--- a/end_to_end/tpu/gemma3/4b/test_gemma3.sh
+++ b/end_to_end/tpu/gemma3/4b/test_gemma3.sh
@@ -43,16 +43,16 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it
# We run decoding on the `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint, for efficient decoding. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
# To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
# Finetune using the scanned, converted checkpoint by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint.
# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items
FINETUNE_RUN_NAME=runner_finetune_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03
# We also run pre-training; this is similar to the finetuning command, except we don't pass any checkpoint directory to load_parameters_path
PRETRAIN_RUN_NAME=runner_pretrain_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh b/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh
index 7f6a33faa5..e8ce5ecd31 100644
--- a/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh
+++ b/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh
@@ -17,7 +17,7 @@ MODEL_NAME='gemma3-4b'
export MODEL_VARIATION='4b'
HF_TOKEN='' # Important!!! Save your HF access token here
HF_GOLDEN_MODEL='google/gemma-3-4b-pt'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.gemma3'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3'
# To convert the multimodal model, make sure `use_multimodal` is set to true
USE_MULTIMODAL=true
SCAN_LAYERS=false
diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh b/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh
index b6b17ba7b3..a1d4fa727d 100644
--- a/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh
+++ b/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh
@@ -14,7 +14,7 @@ set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma3-4b'
export MODEL_VARIATION='4b'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.gemma3'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3'
# To convert the multimodal model, make sure `use_multimodal` is set to true
USE_MULTIMODAL=false
diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh b/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh
index cdc570a745..c654c3d672 100644
--- a/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh
+++ b/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh
@@ -18,7 +18,7 @@ idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma3-4b'
export MODEL_VARIATION='4b'
HF_GOLDEN_MODEL='google/gemma-3-4b-it'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.gemma3'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3'
# To convert the multimodal model, make sure `use_multimodal` is set to true
USE_MULTIMODAL=false
diff --git a/end_to_end/tpu/gemma3/Run_Gemma3.md b/end_to_end/tpu/gemma3/Run_Gemma3.md
index 8265afd07f..055ae353b2 100644
--- a/end_to_end/tpu/gemma3/Run_Gemma3.md
+++ b/end_to_end/tpu/gemma3/Run_Gemma3.md
@@ -25,7 +25,7 @@ We provide examples for checkpoint conversion and decoding/training/finetuning G
You can train from scratch to generate a new checkpoint. One example command for pretraining the Gemma3-4B model is as follows:
```sh
-python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_pretrain_gemma3_4b steps=10 enable_checkpointing=false sharding_tolerance=0.03
+python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_pretrain_gemma3_4b steps=10 enable_checkpointing=false sharding_tolerance=0.03
```
## Checkpoint Conversion
@@ -35,12 +35,12 @@ To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle]
After the conversion, you will have a MaxText-compatible checkpoint, which allows you to fine-tune on different datasets. One example command to fine-tune a Gemma3-4B model is as follows:
```
-python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_finetune_gemma3_4b steps=10 enable_checkpointing=true sharding_tolerance=0.03
+python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_finetune_gemma3_4b steps=10 enable_checkpointing=true sharding_tolerance=0.03
```
## Decoding
One example of using a converted checkpoint to decode with the prompt "I love to":
```
-python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to"
+python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to"
```
\ No newline at end of file
diff --git a/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh b/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh
index e4a951beb9..357fb06823 100644
--- a/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh
+++ b/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh
@@ -39,17 +39,16 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
# We run decoding on the `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint, for efficient decoding. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."
# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned, converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."
# Alternatively, we can skip straight to finetuning by using the scanned, converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5
-
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5
# We also run pre-training; this is similar to the finetuning command, except we don't pass any checkpoint directory to load parameters from
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}
# Note that the finetune run checkpoint generates the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters.
# So, we can use `src/MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
@@ -58,4 +57,4 @@ export PARAM_RUN_NAME=param_chkpt
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true
# Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
diff --git a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
index 72ea1c212d..e7ea05b8f5 100644
--- a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
+++ b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
@@ -44,17 +44,16 @@ export ASYNC_CHECKPOINTING=true # True so that jax distributed system is initial
# We run decoding on the `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint, for efficient decoding. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned, converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
# Alternatively, we can skip straight to finetuning by using the scanned, converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} checkpoint_period=5
-
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} checkpoint_period=5
# We also run pre-training; this is similar to the finetuning command, except we don't pass any checkpoint directory to load parameters from
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}
# Note that the finetune run checkpoint generates the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters.
# So, we can use `src/MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
@@ -63,7 +62,7 @@ export PARAM_RUN_NAME=param_chkpt
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true
# Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
# We also test whether the forward pass logits match the golden logits for Llama2-70b
python3 -m tests.utils.forward_pass_logit_checker --atol=0.2 --rtol=0.2 "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=${ASYNC_CHECKPOINTING}
diff --git a/end_to_end/tpu/mistral/7b/test_mistral-7b.sh b/end_to_end/tpu/mistral/7b/test_mistral-7b.sh
index 94fcae05f4..1c0428b251 100644
--- a/end_to_end/tpu/mistral/7b/test_mistral-7b.sh
+++ b/end_to_end/tpu/mistral/7b/test_mistral-7b.sh
@@ -40,7 +40,7 @@ echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN
export DATASET_PATH=gs://maxtext-dataset
# Run decoding with converted ckpt - matmul implementation
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False
# Test whether the forward pass logits match the golden logits - matmul implementation
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4
\ No newline at end of file
diff --git a/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
index bd2bb9a0b7..6206ab2908 100644
--- a/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
+++ b/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
@@ -20,7 +20,7 @@ if [ -z "${BASE_OUTPUT_PATH}" ]; then
fi
export DATASET_PATH=gs://maxtext-dataset
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v3
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v3
# Run pre-training without load_parameters_path - megablox implementation
python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
diff --git a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
index 7535b8e337..6bd1334583 100644
--- a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
+++ b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -35,21 +35,21 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/item
# Run decoding with converted ckpt - matmul implementation
# TODO(ranran): add decoding test for megablox implementation
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false
# Run decoding with converted ckpt - dropping implementation
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25
# Test whether the forward pass logits match the golden logits - matmul implementation
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=11 dtype=float32 megablox=False sparse_matmul=False scan_layers=false --token_size=4 --max_kl_div=3e-3
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=11 dtype=float32 megablox=False sparse_matmul=False scan_layers=false --token_size=4 --max_kl_div=3e-3
# Rather than repeating duplicate tests here, an MoE unit test verifies that outputs match across the matmul, megablox, and ragged_dot implementations at https://github.com/AI-Hypercomputer/maxtext/blob/5c4090b8d5713a1a25cab85df89b0ec9c9862635/MaxText/tests/unit/moe_test.py#L338-L411
# Run pre-training - megablox implementation
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16
# Run pre-training - matmul implementation
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False
# Run pre-training - dropping implementation
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False capacity_factor=1
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False capacity_factor=1
\ No newline at end of file
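For quick reference, the three pre-training runs above differ only in how the MoE layer is executed; the flag sets below are a sketch collected from those commands, not an exhaustive matrix:
```sh
# Sketch: MoE implementation flag sets exercised by the commands above.
MEGABLOX_FLAGS=""                                                       # default megablox kernel
MATMUL_FLAGS="megablox=False sparse_matmul=False"                       # dense matmul path
DROPPING_FLAGS="megablox=False sparse_matmul=False capacity_factor=1"   # capacity-based dropping
```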
diff --git a/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh b/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh
index a1a1fdc9e1..3ba50596a8 100644
--- a/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh
+++ b/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh
@@ -42,7 +42,7 @@ echo "Against original HF model: ${HF_MODEL_PATH}"
# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
tokenizer_type=huggingface \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
megablox=False \
sparse_matmul=False \
load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
diff --git a/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh b/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh
index 96ad0a5160..8906ddab23 100644
--- a/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh
+++ b/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh
@@ -42,7 +42,7 @@ echo "Against original HF model: ${HF_MODEL_PATH}"
# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
tokenizer_type=huggingface \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
megablox=False \
sparse_matmul=False \
load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
diff --git a/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh b/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh
index 5e66667aab..f8e06a5aec 100644
--- a/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh
+++ b/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh
@@ -42,7 +42,7 @@ echo "Against original HF model: ${HF_MODEL_PATH}"
# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
tokenizer_type=huggingface \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
megablox=False \
sparse_matmul=False \
load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
diff --git a/end_to_end/tpu/qwen/moe/run_qwen_moe.md b/end_to_end/tpu/qwen/moe/run_qwen_moe.md
index 9bdcbd2ab7..d5f3bcb83c 100644
--- a/end_to_end/tpu/qwen/moe/run_qwen_moe.md
+++ b/end_to_end/tpu/qwen/moe/run_qwen_moe.md
@@ -55,7 +55,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml\
max_target_length=8192\
ici_fsdp_parallelism=256\
tokenizer_type=huggingface\
- tokenizer_path=src/MaxText/assets/qwen3-tokenizer
+ tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer
```
@@ -70,7 +70,7 @@ To generate text with a trained model, use the `decode` command. The command bel
python3 -m MaxText.decode src/MaxText/configs/base.yml\
load_parameters_path=gs://your-gcs-bucket/qwen3_maxtext_ckpt/0/items\
tokenizer_type=huggingface\
- tokenizer_path=src/MaxText/assets/qwen3-tokenizer\
+ tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer\
prompt="Today is a beautiful day to"\
model_name=\
per_device_batch_size=1\
diff --git a/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh b/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
index 79e4650749..fab873a4dd 100644
--- a/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
+++ b/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
@@ -42,7 +42,7 @@ echo "Against original HF model: ${HF_MODEL_PATH}"
# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
tokenizer_type=huggingface \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
megablox=False \
sparse_matmul=False \
load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
diff --git a/end_to_end/tpu/qwen/next/run_qwen3_next.md b/end_to_end/tpu/qwen/next/run_qwen3_next.md
index 61612c3b31..a12a1bbbdf 100644
--- a/end_to_end/tpu/qwen/next/run_qwen3_next.md
+++ b/end_to_end/tpu/qwen/next/run_qwen3_next.md
@@ -47,7 +47,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \
max_target_length=8192 \
ici_fsdp_parallelism=256 \
tokenizer_type=huggingface \
- tokenizer_path=src/MaxText/assets/qwen3-tokenizer
+ tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer
```
diff --git a/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh b/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh
index cef32e0321..9c36b34c9c 100644
--- a/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh
+++ b/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh
@@ -44,7 +44,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA
# We also test whether the forward pass logits match the original HF model
# To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
load_parameters_path=${CKPT_PATH} \
model_name=${MODEL_NAME} \
scan_layers=false \
diff --git a/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh b/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh
index ff2a34396c..b2ec20858b 100644
--- a/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh
+++ b/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh
@@ -17,7 +17,7 @@ idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='qwen3-4b'
export MODEL_VARIATION='4b'
HF_GOLDEN_MODEL='Qwen/Qwen3-4B'
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/qwen3-tokenizer'
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/qwen3-tokenizer'
# Install torch for the dependencies of forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
diff --git a/end_to_end/tpu/test_decode_load_quantized_ckpt.sh b/end_to_end/tpu/test_decode_load_quantized_ckpt.sh
index f0b03047c4..8abfc5b480 100644
--- a/end_to_end/tpu/test_decode_load_quantized_ckpt.sh
+++ b/end_to_end/tpu/test_decode_load_quantized_ckpt.sh
@@ -24,7 +24,7 @@ else
cmd=''
fi
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2
export MAX_PREFILL_PREDICT_LENGTH=128
export MAX_TARGET_LENGTH=256
export MODEL_NAME=${model}
diff --git a/end_to_end/tpu/test_decode_save_quantized_ckpt.sh b/end_to_end/tpu/test_decode_save_quantized_ckpt.sh
index a6afc59c9b..5cc91c8807 100644
--- a/end_to_end/tpu/test_decode_save_quantized_ckpt.sh
+++ b/end_to_end/tpu/test_decode_save_quantized_ckpt.sh
@@ -31,7 +31,7 @@ if [ "$model" = "llama2-70b" ]; then
fi
export MODEL_NAME=${model}
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2
export LOAD_PARAMETERS_PATH=gs://inference-benchmarks/models/${MODEL_NAME}-chat/${checkpoint_ts}/param-only-decode-ckpt-maxtext/checkpoints/0/items
export MAX_PREFILL_PREDICT_LENGTH=128
export MAX_TARGET_LENGTH=256
diff --git a/end_to_end/tpu/test_dpo.sh b/end_to_end/tpu/test_dpo.sh
index fdd6a523ac..1fdd9c17ea 100644
--- a/end_to_end/tpu/test_dpo.sh
+++ b/end_to_end/tpu/test_dpo.sh
@@ -9,7 +9,7 @@ export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sor
LOGS="gs://maxtext-external/logs"
# tfds pipeline
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma \
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \
run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \
load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \
per_device_batch_size=0.5 allow_split_physical_axes=True \
@@ -18,7 +18,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
# grain pipeline
mkdir -p /tmp/anthropic_rlhf || true
gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma \
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \
run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \
load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \
dataset_type=grain grain_worker_count=16 \
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index e06ec42033..7d97495c98 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -526,7 +526,7 @@ num_vocab_tiling: 1
# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
-tokenizer_path: "src/MaxText/assets/tokenizer.llama2"
+tokenizer_path: "src/maxtext/assets/tokenizers/tokenizer.llama2"
# tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken
# grain pipeline supports tokenizer_type: sentencepiece, huggingface
# hf pipeline only supports huggingface type, and will ignore tokenizer_type flag
diff --git a/src/MaxText/configs/models/gpu/mixtral_8x7b.yml b/src/MaxText/configs/models/gpu/mixtral_8x7b.yml
index 5a08ffd38c..5fa58f066f 100644
--- a/src/MaxText/configs/models/gpu/mixtral_8x7b.yml
+++ b/src/MaxText/configs/models/gpu/mixtral_8x7b.yml
@@ -30,7 +30,7 @@ reuse_example_batch: 1
enable_checkpointing: False
megablox: False
scan_layers: False
-tokenizer_path: "/deps/src/MaxText/assets/tokenizer.mistral-v1"
+tokenizer_path: "/deps/src/maxtext/assets/tokenizers/tokenizer.mistral-v1"
profiler: "nsys"
capacity_factor: 1.0
max_segments_per_seq: 32
diff --git a/src/MaxText/configs/models/mixtral-8x22b.yml b/src/MaxText/configs/models/mixtral-8x22b.yml
index 31a2fdeacd..0d040bf48a 100644
--- a/src/MaxText/configs/models/mixtral-8x22b.yml
+++ b/src/MaxText/configs/models/mixtral-8x22b.yml
@@ -13,7 +13,7 @@
# limitations under the License.
# model config for mixtral-8x22b
-# tokenizer_path is assets/tokenizer.mistral-v3
+# tokenizer_path is assets/tokenizers/tokenizer.mistral-v3
base_emb_dim: 6144
base_num_query_heads: 48
diff --git a/src/MaxText/configs/models/mixtral-8x7b.yml b/src/MaxText/configs/models/mixtral-8x7b.yml
index c45031b5f8..91a7ab50bc 100644
--- a/src/MaxText/configs/models/mixtral-8x7b.yml
+++ b/src/MaxText/configs/models/mixtral-8x7b.yml
@@ -13,7 +13,7 @@
# limitations under the License.
# model config for mixtral-8x7b
-# tokenizer_path is assets/tokenizer.mistral-v1
+# tokenizer_path is assets/tokenizers/tokenizer.mistral-v1
base_emb_dim: 4096
base_num_query_heads: 32
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index c1f109eccc..97344f2534 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -870,7 +870,7 @@ class Tokenizer(BaseModel):
vocab_size: int = Field(32_000, description="The size of the vocabulary.")
tokenizer_path: PathStr = Field(
- os.path.join("assets", "tokenizer.llama2"),
+ os.path.join("assets", "tokenizers", "tokenizer.llama2"),
description="Path to the tokenizer model file.",
)
tokenizer_type: TokenizerType = Field(TokenizerType.SENTENCEPIECE, description="The type of tokenizer.")
@@ -1831,8 +1831,8 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
filter(
os.path.exists,
(
- os.path.join(MAXTEXT_ASSETS_ROOT, os.path.basename(tokenizer_path)),
- os.path.join(MAXTEXT_ASSETS_ROOT, tokenizer_path),
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", os.path.basename(tokenizer_path)),
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", tokenizer_path),
),
),
tokenizer_path,
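For review context, a self-contained sketch of the lookup order this hunk establishes in `set_derived_and_validate_values`, assuming the enclosing (elided) code is a `next(...)` that takes the first existing candidate and otherwise keeps the raw path; `resolve_tokenizer_path` is a hypothetical name:

```python
import os
from MaxText.globals import MAXTEXT_ASSETS_ROOT  # patched below to default to <repo>/src/maxtext/assets

def resolve_tokenizer_path(tokenizer_path: str) -> str:
  """Probe the new tokenizers/ subdirectory, falling back to the path as given."""
  candidates = (
      # bare names such as "tokenizer.llama2" resolve via their basename
      os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", os.path.basename(tokenizer_path)),
      # longer relative paths are probed as-is under tokenizers/
      os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", tokenizer_path),
  )
  return next(filter(os.path.exists, candidates), tokenizer_path)
```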
diff --git a/src/MaxText/configs/v5e/llama2_13b.sh b/src/MaxText/configs/v5e/llama2_13b.sh
index 50e7bc7f60..0604d97220 100644
--- a/src/MaxText/configs/v5e/llama2_13b.sh
+++ b/src/MaxText/configs/v5e/llama2_13b.sh
@@ -43,5 +43,5 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xl
python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-13b\
base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=8 remat_policy=qkv_proj_offloaded\
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=8 remat_policy=qkv_proj_offloaded\
steps=15 enable_checkpointing=false use_iota_embed=true
diff --git a/src/MaxText/configs/v5e/llama2_70b.sh b/src/MaxText/configs/v5e/llama2_70b.sh
index d470b5d051..bf2cb73d62 100644
--- a/src/MaxText/configs/v5e/llama2_70b.sh
+++ b/src/MaxText/configs/v5e/llama2_70b.sh
@@ -43,5 +43,5 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xl
python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-70b\
base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded\
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded\
steps=15 enable_checkpointing=false use_iota_embed=true
diff --git a/src/MaxText/configs/v5e/llama2_7b.sh b/src/MaxText/configs/v5e/llama2_7b.sh
index 72852a8d96..3fa110e03f 100644
--- a/src/MaxText/configs/v5e/llama2_7b.sh
+++ b/src/MaxText/configs/v5e/llama2_7b.sh
@@ -43,5 +43,5 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xl
python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-7b\
base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=4 remat_policy=save_qkv_proj\
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=4 remat_policy=save_qkv_proj\
steps=15 enable_checkpointing=false use_iota_embed=true
\ No newline at end of file
diff --git a/src/MaxText/configs/v5p/llama2_70b.sh b/src/MaxText/configs/v5p/llama2_70b.sh
index 99878b3c3f..bf69b52cef 100644
--- a/src/MaxText/configs/v5p/llama2_70b.sh
+++ b/src/MaxText/configs/v5p/llama2_70b.sh
@@ -46,7 +46,7 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gathe
python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-70b\
base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 remat_policy=save_dot_except_mlpwi per_device_batch_size=4\
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 remat_policy=save_dot_except_mlpwi per_device_batch_size=4\
steps=30 enable_checkpointing=false use_iota_embed=true max_target_length=4096\
profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=5 gcs_metrics=true\
dataset_type=$DATASET_TYPE reuse_example_batch=$REUSE_EXAMPLE_BATCH
diff --git a/src/MaxText/configs/v5p/llama2_7b.sh b/src/MaxText/configs/v5p/llama2_7b.sh
index edca887e4b..8d3e0a9206 100644
--- a/src/MaxText/configs/v5p/llama2_7b.sh
+++ b/src/MaxText/configs/v5p/llama2_7b.sh
@@ -46,7 +46,7 @@ fi
export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-7b\
base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 remat_policy=minimal per_device_batch_size=4\
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 remat_policy=minimal per_device_batch_size=4\
steps=30 enable_checkpointing=false use_iota_embed=true max_target_length=4096\
profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=5 gcs_metrics=true\
dataset_type=$DATASET_TYPE reuse_example_batch=$REUSE_EXAMPLE_BATCH
diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/MaxText/examples/demo_decoding.ipynb
index 9b913318e9..5103eeedfc 100644
--- a/src/MaxText/examples/demo_decoding.ipynb
+++ b/src/MaxText/examples/demo_decoding.ipynb
@@ -146,6 +146,8 @@
"from huggingface_hub import login\n",
"\n",
"MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)\n",
+ "MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n",
+ "MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n",
"\n",
"nest_asyncio.apply()"
]
@@ -255,7 +257,7 @@
" run_name=\"test\",\n",
" max_target_length=4,\n",
" max_prefill_predict_length=4,\n",
- " tokenizer_path=f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n",
+ " tokenizer_path=f\"{MAXTEXT_ASSETS_ROOT}/tokenizers/qwen3-tokenizer\",\n",
" load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}/0/items\",\n",
" model_name=MODEL_NAME,\n",
" async_checkpointing=False,\n",
@@ -312,7 +314,7 @@
"outputs": [],
"source": [
"tokenizer = _input_pipeline_utils.get_tokenizer(\n",
- " f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n",
+ " f\"{MAXTEXT_ASSETS_ROOT}/tokenizers/qwen3-tokenizer\",\n",
" \"huggingface\",\n",
" add_bos=True,\n",
" add_eos=False,\n",
@@ -416,7 +418,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -430,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.11"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/src/MaxText/examples/multimodal_gemma3_demo.ipynb b/src/MaxText/examples/multimodal_gemma3_demo.ipynb
index a664fe576b..b2453f9515 100644
--- a/src/MaxText/examples/multimodal_gemma3_demo.ipynb
+++ b/src/MaxText/examples/multimodal_gemma3_demo.ipynb
@@ -70,7 +70,10 @@
"import MaxText\n",
"\n",
"# Get the root directory of the MaxText\n",
- "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
+ "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n",
+ "MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n",
+ "MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n",
+ "\n",
"\n",
"# Define model name\n",
"MODEL_NAME = \"gemma3-4b\"\n",
@@ -96,7 +99,7 @@
"outputs": [],
"source": [
"!python3 -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
- " $MAXTEXT_REPO_ROOT/configs/base.yml \\\n",
+ " $MAXTEXT_PKG_DIR/configs/base.yml \\\n",
" model_name=$MODEL_NAME \\\n",
" hf_access_token=$HF_TOKEN \\\n",
" base_output_directory=$MODEL_CHECKPOINT_PATH \\\n",
@@ -118,9 +121,9 @@
"outputs": [],
"source": [
"!python -m MaxText.decode \\\n",
- " $MAXTEXT_REPO_ROOT/configs/base.yml \\\n",
+ " $MAXTEXT_PKG_DIR/configs/base.yml \\\n",
" model_name=$MODEL_NAME \\\n",
- " tokenizer_path=assets/tokenizer.gemma3 \\\n",
+ " tokenizer_path=$MAXTEXT_ASSETS_ROOT/tokenizers/tokenizer.gemma3 \\\n",
" load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n",
" per_device_batch_size=1 \\\n",
" run_name=ht_test max_prefill_predict_length=272 \\\n",
@@ -130,7 +133,7 @@
" scan_layers=false \\\n",
" use_multimodal=true \\\n",
" prompt='Describe image ' \\\n",
- " image_path=$MAXTEXT_REPO_ROOT/tests/assets/test_image.jpg \\\n",
+ " image_path=$MAXTEXT_PKG_DIR/tests/assets/test_image.jpg \\\n",
" attention='dot_product'"
]
},
@@ -162,7 +165,7 @@
"PER_DEVICE_BATCH_SIZE=1\n",
"\n",
"!python -m MaxText.sft_trainer \\\n",
- " $MAXTEXT_REPO_ROOT/configs/sft-vision-chartqa.yml \\\n",
+ " $MAXTEXT_PKG_DIR/configs/sft-vision-chartqa.yml \\\n",
" run_name=$WORKLOAD_NAME \\\n",
" model_name=$MODEL_NAME \\\n",
" tokenizer_path=$PRE_TRAINED_MODEL_TOKENIZER \\\n",
diff --git a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md
index 42633bc99b..90f8eaf470 100644
--- a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md
+++ b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md
@@ -66,7 +66,7 @@ If a ground-truth version isn't available, you'll need to debug the conversion m
3. After the conversion is done, run a decode to check the correctness of the generated code.
Example command:
```bash
-python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path=assets/tokenizer.gemma3 \
+python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \
load_parameters_path= per_device_batch_size=1 run_name=ht_test \
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true \
prompt='I love to' attention='dot_product'
@@ -76,7 +76,7 @@ If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean
4. To further validate the converted checkpoint, we recommend using [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#verifying-conversion-correctness) to compare the original checkpoint against the converted one:
```bash
python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
- tokenizer_path=assets/ \
+ tokenizer_path=assets/tokenizers/ \
load_parameters_path= \
model_name= \
scan_layers=false \
diff --git a/src/MaxText/globals.py b/src/MaxText/globals.py
index 3301f30176..547a1ae964 100644
--- a/src/MaxText/globals.py
+++ b/src/MaxText/globals.py
@@ -27,8 +27,8 @@
else MAXTEXT_PKG_DIR,
)
-# This is the assets root: with "tokenizer.gemma3"; &etc.
-MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_PKG_DIR, "assets"))
+# This is the assets root: with "tokenizers/"; &etc.
+MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"))
# This is the test assets root: with "golden_logits"; &etc.
MAXTEXT_TEST_ASSETS_ROOT = os.environ.get("MAXTEXT_TEST_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "tests", "assets"))
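The effect of this `globals.py` change, as a sketch: the environment variable still wins, but the default now lives under the repo root rather than the package directory. The literal paths below are illustrative values, not taken from the repository:

```python
import os

# Hypothetical locations for illustration; in the repo these come from globals.py.
MAXTEXT_PKG_DIR = "/repo/src/MaxText"
MAXTEXT_REPO_ROOT = "/repo"

# Old default: /repo/src/MaxText/assets
old_default = os.path.join(MAXTEXT_PKG_DIR, "assets")

# New default: /repo/src/maxtext/assets; MAXTEXT_ASSETS_ROOT overrides either.
new_default = os.environ.get(
    "MAXTEXT_ASSETS_ROOT",
    os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"),
)
```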
diff --git a/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh b/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh
index d124b849de..095bbbc6e4 100755
--- a/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh
+++ b/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh
@@ -104,7 +104,7 @@ TF_FORCE_GPU_ALLOW_GROWTH=true \
XLA_PYTHON_CLIENT_MEM_FRACTION=0.94 \
python3 -m MaxText.inference_microbenchmark $MAXENGINE_CONFIG_FILEPATH \
base_output_directory=$BASE_OUTPUT_DIRECTORY \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 \
model_name='llama2-70b' \
max_prefill_predict_length=$max_prefill_predict_length \
max_target_length=2048 \
diff --git a/src/MaxText/inference_mlperf/README.md b/src/MaxText/inference_mlperf/README.md
index 8731913fd1..f4d2566632 100644
--- a/src/MaxText/inference_mlperf/README.md
+++ b/src/MaxText/inference_mlperf/README.md
@@ -97,8 +97,8 @@ export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat
```sh
# Set the appropriate tokenizer path. For example, Llama2 models use tokenizer.llama2. You can find
-# other tokenizers under src/MaxText/assets/ directory.
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.llama2'
+# other tokenizers under the src/maxtext/assets/tokenizers/ directory.
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.llama2'
cd maxtext && \
python3 -m MaxText.decode src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```
@@ -120,7 +120,7 @@ export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama3.1-405b
2. Run the following maxtext script to generate and save an int8 quantized checkpoint
```sh
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken
export MODEL_SIZE=llama3.1-405b
export QUANTIZE_TYPE=int8
diff --git a/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh b/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
index 06d49143c2..c350647efa 100755
--- a/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
+++ b/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
@@ -81,7 +81,7 @@ if [[ -z ${CHECKPOINT} ]] ; then
fi
if [[ -z ${TOKENIZER_PATH} ]] ; then
- export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.llama2"
+ export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.llama2"
fi
if [ -z "$PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES" ];
diff --git a/src/MaxText/inference_mlperf/llama_offline_run.sh b/src/MaxText/inference_mlperf/llama_offline_run.sh
index 26b5be29ea..5a777144d1 100755
--- a/src/MaxText/inference_mlperf/llama_offline_run.sh
+++ b/src/MaxText/inference_mlperf/llama_offline_run.sh
@@ -59,7 +59,7 @@ if "$enable_batch_prefill"; then
fi
if [ -z "$TOKENIZER_PATH" ]; then
- TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2
+ TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2
fi
BATCH_STR=""
diff --git a/src/MaxText/inference_mlperf/mixtral_offline_run.sh b/src/MaxText/inference_mlperf/mixtral_offline_run.sh
index 9af36486a2..993a4d1144 100755
--- a/src/MaxText/inference_mlperf/mixtral_offline_run.sh
+++ b/src/MaxText/inference_mlperf/mixtral_offline_run.sh
@@ -52,7 +52,7 @@ if "$enable_profiler"; then
fi
if [ -z "$TOKENIZER_PATH" ]; then
- TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1
+ TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1
fi
BATCH_STR=""
diff --git a/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh b/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
index 2a98616e20..24727b8eab 100644
--- a/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
+++ b/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
@@ -86,7 +86,7 @@ if [[ -z ${CHECKPOINT} ]] ; then
fi
if [[ -z ${TOKENIZER_PATH} ]] ; then
- export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.llama2" # NOTE: you may need to change this path for your VM
+ export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.llama2" # NOTE: you may need to change this path for your VM
fi
if [ -z "$PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES" ];
diff --git a/src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh b/src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
index babf36ed62..39160c7889 100644
--- a/src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
+++ b/src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
@@ -57,7 +57,7 @@ echo
echo "LIBTPU_INIT_ARGS:${LIBTPU_INIT_ARGS}"
echo "XLA_FLAGS:${XLA_FLAGS}"
echo
-export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2
+export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2
export LOAD_PARAMETERS_PATH=gs://${USER}-bkt/checkpoints/quant_llama2-70b-chat/prod/int8_
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
diff --git a/src/MaxText/scratch_code/demo_from_config.ipynb b/src/MaxText/scratch_code/demo_from_config.ipynb
index d0f22a5dea..044bad8eb7 100644
--- a/src/MaxText/scratch_code/demo_from_config.ipynb
+++ b/src/MaxText/scratch_code/demo_from_config.ipynb
@@ -490,7 +490,7 @@
" base_num_query_heads=2,\n",
" base_num_kv_heads=2,\n",
" max_prefill_predict_length=4,\n",
- " # tokenizer_path=\"assets/llama3.1-tokenizer/\",\n",
+ " # tokenizer_path=\"assets/tokenizers/llama3.1-tokenizer/\",\n",
" # model_name=\"llama3.1-7b\",\n",
")\n",
"\n",
@@ -521,7 +521,7 @@
"from MaxText.globals import MAXTEXT_ASSETS_ROOT\n",
"\n",
"source_tokenizer = _input_pipeline_utils.get_tokenizer(\n",
- " os.path.join(MAXTEXT_ASSETS_ROOT, \"tokenizer_llama3.tiktoken\"),\n",
+ " os.path.join(MAXTEXT_ASSETS_ROOT, \"tokenizers\", \"tokenizer_llama3.tiktoken\"),\n",
" \"tiktoken\",\n",
" add_bos=True,\n",
" add_eos=False,\n",
diff --git a/src/MaxText/scratch_code/gemma_7b.sh b/src/MaxText/scratch_code/gemma_7b.sh
index 7fd4bf9e0d..3de03118d2 100644
--- a/src/MaxText/scratch_code/gemma_7b.sh
+++ b/src/MaxText/scratch_code/gemma_7b.sh
@@ -3,6 +3,6 @@ export M_PER_DEVICE_BATCH_SIZE=24
export M_MAX_PREFILL_PREDICT_LENGTH=1024
export M_MAX_TARGET_LENGTH=2048
-#python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
+#python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
-python3 -m MaxText.maxengine_server "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
+python3 -m MaxText.maxengine_server "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
\ No newline at end of file
diff --git a/src/MaxText/scratch_code/run_inference_microbenchmark.sh b/src/MaxText/scratch_code/run_inference_microbenchmark.sh
index 15cfcc8cab..8f2502864a 100644
--- a/src/MaxText/scratch_code/run_inference_microbenchmark.sh
+++ b/src/MaxText/scratch_code/run_inference_microbenchmark.sh
@@ -15,4 +15,4 @@ steps=10 \
scan_layers=false \
model_name=llama2-7b \
weight_dtype=bfloat16 \
-tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2
+tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2
diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh
index bbe169f83d..e23f253ad4 100644
--- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh
+++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh
@@ -10,7 +10,7 @@ DATE=$(date +%Y-%m-%d)
HF_CHECKPOINT_GCS_PATH="gs://maxtext-model-checkpoints/HuggingFace/gemma2-2b/${DATE}" # (optional)GCS path for HF model
MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma2-2b-it/2025-02-20-18-01/unscanned/checkpoints/0/items"
LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma2-2b_output" # HF requires a local dir
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.gemma"
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma"
MODEL_NAME="gemma2-2b"
PER_DEVICE_BATCH_SIZE=1
SCAN_LAYERS=false
@@ -22,7 +22,7 @@ echo "Starting Hugging Face model conversion for gemma2-2b..."
python3 -m MaxText.utils.ckpt_conversion.to_huggingface \
"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/base.yml" \
model_name="${MODEL_NAME}" \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.gemma" \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma" \
load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}" \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
max_prefill_predict_length=8 \
diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh
index 03a38b650b..cf767e6ac9 100644
--- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh
+++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh
@@ -11,7 +11,7 @@ MODEL_NAME="gemma2-2b"
# HF model id as golden model for verification
HF_MODEL_ID="google/gemma-2-2b-it"
# Tokenizer path for decoding
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.gemma"
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma"
PER_DEVICE_BATCH_SIZE=1
ASYNC_CHECKPOINTING=false
diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh
index 03b5eed175..e00ba56974 100644
--- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh
+++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh
@@ -10,7 +10,7 @@ DATE=$(date +%Y-%m-%d)
HF_CHECKPOINT_GCS_PATH="gs://maxtext-model-checkpoints/HuggingFace/gemma3-4b/${DATE}" # (optional)GCS path for HF model
MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/unscanned/checkpoints/0/items"
LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma3-4b_output" # HF requires a local dir
-TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.gemma3"
+TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma3"
MODEL_NAME="gemma3-4b"
PER_DEVICE_BATCH_SIZE=1
SCAN_LAYERS=false
@@ -21,7 +21,7 @@ echo "Starting Hugging Face model conversion for gemma3-4b..."
python3 -m MaxText.utils.ckpt_conversion.to_huggingface \
"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
model_name="gemma3-4b" \
- tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.gemma3" \
+ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma3" \
load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}" \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
run_name="ht_test" \
diff --git a/src/MaxText/assets/qwen3-tokenizer/tokenizer.json b/src/maxtext/assets/tokenizers/qwen3-tokenizer/tokenizer.json
similarity index 100%
rename from src/MaxText/assets/qwen3-tokenizer/tokenizer.json
rename to src/maxtext/assets/tokenizers/qwen3-tokenizer/tokenizer.json
diff --git a/src/MaxText/assets/qwen3-tokenizer/tokenizer_config.json b/src/maxtext/assets/tokenizers/qwen3-tokenizer/tokenizer_config.json
similarity index 100%
rename from src/MaxText/assets/qwen3-tokenizer/tokenizer_config.json
rename to src/maxtext/assets/tokenizers/qwen3-tokenizer/tokenizer_config.json
diff --git a/src/MaxText/assets/tokenizer b/src/maxtext/assets/tokenizers/tokenizer.default
similarity index 100%
rename from src/MaxText/assets/tokenizer
rename to src/maxtext/assets/tokenizers/tokenizer.default
diff --git a/src/MaxText/assets/tokenizer.gemma b/src/maxtext/assets/tokenizers/tokenizer.gemma
similarity index 100%
rename from src/MaxText/assets/tokenizer.gemma
rename to src/maxtext/assets/tokenizers/tokenizer.gemma
diff --git a/src/MaxText/assets/tokenizer.gemma3 b/src/maxtext/assets/tokenizers/tokenizer.gemma3
similarity index 100%
rename from src/MaxText/assets/tokenizer.gemma3
rename to src/maxtext/assets/tokenizers/tokenizer.gemma3
diff --git a/src/MaxText/assets/tokenizer.llama2 b/src/maxtext/assets/tokenizers/tokenizer.llama2
similarity index 100%
rename from src/MaxText/assets/tokenizer.llama2
rename to src/maxtext/assets/tokenizers/tokenizer.llama2
diff --git a/src/MaxText/assets/tokenizer.mistral-v1 b/src/maxtext/assets/tokenizers/tokenizer.mistral-v1
similarity index 100%
rename from src/MaxText/assets/tokenizer.mistral-v1
rename to src/maxtext/assets/tokenizers/tokenizer.mistral-v1
diff --git a/src/MaxText/assets/tokenizer.mistral-v3 b/src/maxtext/assets/tokenizers/tokenizer.mistral-v3
similarity index 100%
rename from src/MaxText/assets/tokenizer.mistral-v3
rename to src/maxtext/assets/tokenizers/tokenizer.mistral-v3
diff --git a/src/MaxText/assets/tokenizer_llama3.tiktoken b/src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken
similarity index 100%
rename from src/MaxText/assets/tokenizer_llama3.tiktoken
rename to src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken
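Taken together, the renames above move every bundled tokenizer into a `tokenizers/` subdirectory and give the previously extensionless `tokenizer` file (a SentencePiece model, per the tests below) an explicit `.default` suffix. A hypothetical migration helper summarizing the mapping for out-of-tree callers; this is not part of the patch:

```python
OLD_PREFIX = "src/MaxText/assets/"
NEW_PREFIX = "src/maxtext/assets/tokenizers/"

def migrate_asset_path(old: str) -> str:
  """Map a pre-rename tokenizer path to its post-rename location."""
  if not old.startswith(OLD_PREFIX):
    return old  # not a bundled asset; leave untouched
  name = old[len(OLD_PREFIX):]
  if name == "tokenizer":
    name = "tokenizer.default"  # the bare model file gained a suffix
  return NEW_PREFIX + name

assert migrate_asset_path("src/MaxText/assets/tokenizer_llama3.tiktoken") == (
    "src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken"
)
```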
diff --git a/tests/inference/test_llama2_7b_bf16.sh b/tests/inference/test_llama2_7b_bf16.sh
index 672611932c..ce7d78aa29 100755
--- a/tests/inference/test_llama2_7b_bf16.sh
+++ b/tests/inference/test_llama2_7b_bf16.sh
@@ -5,7 +5,7 @@ args=(
"-m"
"MaxText.decode"
"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml"
- "tokenizer_path=${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.llama2"
+ "tokenizer_path=${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.llama2"
"model_name=llama2-7b"
"load_parameters_path=gs://runner-maxtext-logs/direct_generate_param_only_checkpoint_2024-06-11-04-13/checkpoints/0/items/"
"checkpoint_is_quantized=false"
diff --git a/tests/inference/test_llama2_7b_int8.sh b/tests/inference/test_llama2_7b_int8.sh
index 50aa2c0dc9..ea3f8ebd89 100755
--- a/tests/inference/test_llama2_7b_int8.sh
+++ b/tests/inference/test_llama2_7b_int8.sh
@@ -5,7 +5,7 @@ args=(
"-m"
"MaxText.decode"
"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml"
- "tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2"
+ "tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2"
"model_name=llama2-7b"
"load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_"
"checkpoint_is_quantized=true"
diff --git a/tests/integration/decode_tests.py b/tests/integration/decode_tests.py
index 3cb61b83f3..a70e3286f3 100644
--- a/tests/integration/decode_tests.py
+++ b/tests/integration/decode_tests.py
@@ -45,7 +45,7 @@ class DecodeTests(unittest.TestCase):
"ici_tensor_parallelism=4",
"max_target_length=128",
"per_device_batch_size=1",
- rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"int8": [ # tests decode with int8 quantization
None,
@@ -60,7 +60,7 @@ class DecodeTests(unittest.TestCase):
"per_device_batch_size=1",
"quantization=int8",
"quantize_kvcache=True",
- rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"pdb_lt_1": [ # tests decode with per_device_batch_size < 1
None,
@@ -73,7 +73,7 @@ class DecodeTests(unittest.TestCase):
"ici_tensor_parallelism=4",
"max_target_length=128",
"per_device_batch_size=.25",
- rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"decode_sampling": [
None,
@@ -88,7 +88,7 @@ class DecodeTests(unittest.TestCase):
"steps=10",
"async_checkpointing=False",
"model_name=gemma-2b",
- rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.gemma')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.gemma')}",
"attention=dot_product",
"prompt=I love to",
"skip_jax_distributed_system=True",
diff --git a/tests/integration/generate_param_only_checkpoint_test.py b/tests/integration/generate_param_only_checkpoint_test.py
index 72777ed3ce..4c8bbdb3ff 100644
--- a/tests/integration/generate_param_only_checkpoint_test.py
+++ b/tests/integration/generate_param_only_checkpoint_test.py
@@ -127,7 +127,7 @@ def test_param_ckpt_generation_with_pre_generated_ckpt(capsys):
"""Tests the parameter-only checkpoint generation and decode flow with a pre-generated Gemma-2b model checkpoint."""
model_config = [
"model_name=gemma-2b",
- f"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.gemma')}",
+ f"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.gemma')}",
]
run_e2e_test_flow(
hardware="tpu",
diff --git a/tests/integration/gradient_accumulation_test.py b/tests/integration/gradient_accumulation_test.py
index 0fca7ac008..56131169bd 100644
--- a/tests/integration/gradient_accumulation_test.py
+++ b/tests/integration/gradient_accumulation_test.py
@@ -53,7 +53,7 @@ def test_grad_accumulate_same_loss(self):
"enable_goodput_recording=False",
"base_emb_dim=256",
"base_num_decoder_layers=4",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"steps=20",
]
# Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10
@@ -145,7 +145,7 @@ def test_sft_grad_accumulate_same_loss(self):
"enable_goodput_recording=False",
"base_emb_dim=256",
"base_num_decoder_layers=4",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"steps=3",
"gradient_accumulation_steps=2",
"use_sft=True",
diff --git a/tests/integration/simple_decoder_layer_test.py b/tests/integration/simple_decoder_layer_test.py
index ea1cedbfd7..5f7b0e24ba 100644
--- a/tests/integration/simple_decoder_layer_test.py
+++ b/tests/integration/simple_decoder_layer_test.py
@@ -38,7 +38,7 @@ def test_simple_decoder_layer(self):
"decoder_block=simple",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"steps=3",
]
)
@@ -55,7 +55,7 @@ def test_mlp_decoder_layer(self):
"decoder_block=simple_mlp",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"steps=3",
]
)
diff --git a/tests/integration/smoke/inference_microbenchmark_smoke_test.py b/tests/integration/smoke/inference_microbenchmark_smoke_test.py
index 3ae010542d..d5c5761e0b 100644
--- a/tests/integration/smoke/inference_microbenchmark_smoke_test.py
+++ b/tests/integration/smoke/inference_microbenchmark_smoke_test.py
@@ -35,7 +35,7 @@ def test(self):
[
None,
os.path.join(MAXTEXT_PKG_DIR, "configs", "tpu_smoke_test.yml"),
- rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"ici_autoregressive_parallelism=-1",
"ici_fsdp_parallelism=1",
"max_prefill_predict_length=1024",
diff --git a/tests/integration/smoke/train_gpu_smoke_test.py b/tests/integration/smoke/train_gpu_smoke_test.py
index 80d1710770..357d858e1b 100644
--- a/tests/integration/smoke/train_gpu_smoke_test.py
+++ b/tests/integration/smoke/train_gpu_smoke_test.py
@@ -36,7 +36,7 @@ def test_tiny_config(self):
"run_name=runner_test",
r"dataset_path=gs://maxtext-dataset",
"enable_checkpointing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"enable_goodput_recording=False",
"enable_checkpoint_cloud_logger=False",
"monitor_goodput=False",
diff --git a/tests/integration/smoke/train_int8_smoke_test.py b/tests/integration/smoke/train_int8_smoke_test.py
index dedf9d27c0..0ba43ff4a0 100644
--- a/tests/integration/smoke/train_int8_smoke_test.py
+++ b/tests/integration/smoke/train_int8_smoke_test.py
@@ -47,7 +47,7 @@ def test_tiny_config(self):
"steps=10",
"enable_checkpointing=False",
"quantization=int8",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"enable_goodput_recording=False",
"monitor_goodput=False",
"enable_checkpoint_cloud_logger=False",
diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py
index b839232e60..9a3641b767 100644
--- a/tests/integration/smoke/train_smoke_test.py
+++ b/tests/integration/smoke/train_smoke_test.py
@@ -46,7 +46,7 @@ def test_tiny_config(self):
"dataset_type=synthetic",
"steps=10",
"enable_checkpointing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"enable_goodput_recording=False",
"enable_checkpoint_cloud_logger=False",
"monitor_goodput=False",
@@ -74,7 +74,7 @@ def test_tiny_config_no_scan(self):
"dataset_type=synthetic",
"steps=10",
"enable_checkpointing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"enable_goodput_recording=False",
"enable_checkpoint_cloud_logger=False",
"monitor_goodput=False",
@@ -104,7 +104,7 @@ def test_tiny_config_explicit_shardmode(self):
"steps=10",
"shard_mode=explicit",
"enable_checkpointing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"enable_goodput_recording=False",
"enable_checkpoint_cloud_logger=False",
"monitor_goodput=False",
diff --git a/tests/integration/standalone_dl_ckpt_test.py b/tests/integration/standalone_dl_ckpt_test.py
index 64b92e6686..cee9c7ffe3 100644
--- a/tests/integration/standalone_dl_ckpt_test.py
+++ b/tests/integration/standalone_dl_ckpt_test.py
@@ -48,7 +48,7 @@ def test_standalone_dataloader(self):
"steps=100",
"enable_checkpointing=false",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
)
) # need to pass relative path to tokenizer
diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py
index 41f8d8a7c9..16dfbf1822 100644
--- a/tests/integration/train_tests.py
+++ b/tests/integration/train_tests.py
@@ -35,7 +35,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"synthetic": [ # tests base config with synthetic dataset
None,
@@ -47,7 +47,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
"dataset_type=synthetic",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"pdb_lt_1": [ # tests base config with per_device_batch_size < 1
None,
@@ -60,7 +60,7 @@ class TrainTests(unittest.TestCase):
"enable_goodput_recording=False",
"per_device_batch_size=0.25",
"ici_tensor_parallelism=4",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"tp_transpose": [ # tests base config with ici_tensor_transpose_parallelism=4
None,
@@ -71,7 +71,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"ici_tensor_transpose_parallelism=4",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"int8": [ # tests base config with int8
None,
@@ -83,7 +83,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"fp8": [ # tests base config with fp8
None,
@@ -95,7 +95,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"nanoo_fp8": [ # tests base config with nanoo_fp8
None,
@@ -107,7 +107,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"te_fp8_delayedscaling": [ # tests base config with te_fp8_delayedscaling
None,
@@ -119,7 +119,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"te_fp8_currentscaling": [ # tests base config with te_fp8_currentscaling
None,
@@ -131,7 +131,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"te_mxfp8": [ # tests base config with te_mxfp8
None,
@@ -143,7 +143,7 @@ class TrainTests(unittest.TestCase):
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"dropout": [ # tests base config with dropout
None,
@@ -157,7 +157,7 @@ class TrainTests(unittest.TestCase):
"max_target_length=128",
"per_device_batch_size=1",
"dropout_rate=0.02",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
],
"hf_input_pipeline": [ # test for train.py with TFDS c4, using HF input pipeline
None,
@@ -294,7 +294,7 @@ def test_gpu_cudnn_flash_te(self):
"enable_goodput_recording=False",
"attention=cudnn_flash_te",
"packing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(cudnn_flash_te)
@@ -317,7 +317,7 @@ def test_gpu_context_parallelism(self):
"context_parallel_strategy=all_gather",
"context_parallel_load_balance=True",
"packing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(context_parallel)
@@ -338,7 +338,7 @@ def test_gpu_tensor_parallelism(self):
"ici_fsdp_parallelism=-1",
"ici_tensor_parallelism=2",
"packing=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(tensor_parallel)
@@ -358,7 +358,7 @@ def test_gpu_optimizer_offload(self):
"dataset_type=synthetic",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(optimizer_offload)
@@ -379,7 +379,7 @@ def test_gpu_parameter_offload(self):
"dataset_type=synthetic",
"enable_checkpointing=False",
"enable_goodput_recording=False",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(parameter_offload)
@@ -397,7 +397,7 @@ def test_gpu_cudnn_flash_jax(self):
"attention=cudnn_flash_jax",
"packing=False",
"shardy=False", # The cudnn kernel is not compatible with shardy, see (b/425746362).
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(cudnn_flash_jax)
@@ -429,7 +429,7 @@ def test_tpu_zero1_gradient_accumulation(self):
"shard_optimizer_over_data=True",
"shard_mode=explicit",
"decoder_block=llama2",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(zero1_ga)
@@ -460,7 +460,7 @@ def test_gpu_zero1_gradient_accumulation(self):
"gradient_accumulation_steps=8",
"shard_optimizer_over_data=True",
"override_model_config=True",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(zero1_ga)
@@ -484,7 +484,7 @@ def test_gpu_packed_attention(self):
"attention=cudnn_flash_te",
"ici_fsdp_parallelism=-1",
"packing=True",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(packed_attention)
@@ -509,7 +509,7 @@ def test_gpu_ring_attention(self):
"context_parallel_strategy=ring",
"packing=False",
"hardware=gpu",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(ring_attention)
diff --git a/tests/integration/vision_encoder_test.py b/tests/integration/vision_encoder_test.py
index 289f053145..b6a576cc31 100644
--- a/tests/integration/vision_encoder_test.py
+++ b/tests/integration/vision_encoder_test.py
@@ -51,7 +51,7 @@ class VisionEncoderEmbeddingTest(unittest.TestCase):
None,
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
"model_name=gemma3-4b",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.gemma3')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.gemma3')}",
"use_multimodal=True",
"run_name=runner_test",
f"load_parameters_path={DEFAULT_LOAD_PARAMETERS_PATH}",
diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py
index 406a01fd8e..a4429c837b 100644
--- a/tests/unit/grain_data_processing_test.py
+++ b/tests/unit/grain_data_processing_test.py
@@ -54,7 +54,7 @@ def setUp(self):
grain_train_files=os.path.join(
temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
),
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
@@ -130,7 +130,7 @@ def setUp(self):
base_output_directory="gs://max-experiments/",
dataset_type="grain",
grain_train_files=grain_train_files,
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
@@ -168,7 +168,7 @@ def setUp(self):
base_output_directory="gs://max-experiments/",
dataset_type="grain",
grain_train_mixture_config_path=self.mixture_config_path,
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
@@ -203,7 +203,7 @@ def setUp(self):
temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
),
grain_worker_count=-1, # Enable auto-tuning
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
@@ -250,7 +250,7 @@ def setUp(self):
temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
),
grain_packing_type="best_fit", # Use best_fit packing
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
@@ -288,7 +288,7 @@ def setUp(self):
grain_train_files=os.path.join(temp_dir, "gcsfuse", "hf", "c4", "c4-train-00000-of-01637.parquet"),
grain_worker_count=1,
grain_per_worker_buffer_size=1,
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py
index 3f5a8a6704..2bbd9e921c 100644
--- a/tests/unit/pipeline_parallelism_test.py
+++ b/tests/unit/pipeline_parallelism_test.py
@@ -303,7 +303,7 @@ def test_full_train_circular(self):
"ici_pipeline_parallelism=4",
"num_layers_per_pipeline_stage=2",
"num_pipeline_microbatches=8",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations.
]
)
@@ -353,7 +353,7 @@ def test_full_train_non_circular(self):
"ici_pipeline_parallelism=4",
"num_layers_per_pipeline_stage=8",
"num_pipeline_microbatches=8",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations.
]
)
@@ -387,7 +387,7 @@ def test_subset_layers(self):
"num_pipeline_repeats=2",
"pipeline_parallel_layers=8",
"num_pipeline_microbatches=8",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations.
]
)
@@ -417,7 +417,7 @@ def test_full_train_fp8(self):
"enable_checkpointing=False",
"enable_goodput_recording=False",
"ici_pipeline_parallelism=4",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"quantization=fp8",
"scan_layers_per_stage=False",
"attention=dot_product",
@@ -449,7 +449,7 @@ def test_full_train_nanoo_fp8(self):
"enable_checkpointing=False",
"enable_goodput_recording=False",
"ici_pipeline_parallelism=4",
- rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+ rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"quantization=nanoo_fp8",
"scan_layers_per_stage=False",
"attention=dot_product",
diff --git a/tests/unit/tfds_data_processing_test.py b/tests/unit/tfds_data_processing_test.py
index f3f515e567..12e1372dce 100644
--- a/tests/unit/tfds_data_processing_test.py
+++ b/tests/unit/tfds_data_processing_test.py
@@ -43,7 +43,7 @@ def setUp(self):
data_sharding=["data"],
base_output_directory="gs://max-experiments/",
dataset_path="gs://maxtext-dataset/",
- tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+ tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
eval_interval=10,
)
diff --git a/tests/unit/tokenizer_test.py b/tests/unit/tokenizer_test.py
index c76ceba2b6..100c0076a7 100644
--- a/tests/unit/tokenizer_test.py
+++ b/tests/unit/tokenizer_test.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Tests for tokenizer
-"""
+"""Tests for tokenizer"""
import numpy as np
from MaxText import train_tokenizer
@@ -40,7 +39,10 @@ def setUpClass(cls):
vocab_model_name = "test_tokenizer"
cls.tokenizer_path = os.path.join(assets_path, vocab_model_name)
cls.source_tokenizer = _input_pipeline_utils.get_tokenizer(
- os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"), "sentencepiece", add_bos=False, add_eos=False
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
+ "sentencepiece",
+ add_bos=False,
+ add_eos=False,
)
os.environ["TFDS_DATA_DIR"] = dataset_path
read_config = tfds.ReadConfig(
@@ -81,7 +83,7 @@ def setUpClass(cls):
dataset_name = "c4/en:3.0.1"
dataset_path = "gs://maxtext-dataset"
cls.source_tokenizer = _input_pipeline_utils.get_tokenizer(
- os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
"tiktoken",
add_bos=False,
add_eos=False,
@@ -112,16 +114,16 @@ class HFTokenizerTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
source = "gs://maxtext-gemma/huggingface/gemma2-2b"
- destination = os.path.join(MAXTEXT_ASSETS_ROOT, "")
+ destination = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers")
subprocess.run(
["gcloud", "storage", "cp", "-R", source, destination],
check=True,
)
cls.hf_tokenizer = _input_pipeline_utils.get_tokenizer(
- os.path.join(MAXTEXT_ASSETS_ROOT, "gemma2-2b"), "huggingface", add_bos=False, add_eos=False
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "gemma2-2b"), "huggingface", add_bos=False, add_eos=False
)
cls.sp_tokenizer = _input_pipeline_utils.get_tokenizer(
- os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.gemma"), "sentencepiece", add_bos=False, add_eos=False
+ os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.gemma"), "sentencepiece", add_bos=False, add_eos=False
)
@pytest.mark.tpu_only