Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/unit/moe_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -460,6 +460,7 @@ def test_megablox(self):
megablox=True,
sparse_matmul=True,
per_device_batch_size=1,
max_target_length=128,
)

rng = jax.random.PRNGKey(1234)
Expand Down Expand Up @@ -488,6 +489,7 @@ def test_ragged_dot(self):
megablox=False,
sparse_matmul=True,
per_device_batch_size=1,
max_target_length=128,
)

rng = jax.random.PRNGKey(1234)
Expand Down Expand Up @@ -516,6 +518,7 @@ def test_dense(self):
megablox=False,
sparse_matmul=False,
per_device_batch_size=1,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down Expand Up @@ -545,6 +548,7 @@ def test_megablox_expert_parallelism(self):
sparse_matmul=True,
per_device_batch_size=4, # TODO(b/450900273): sharding error if pdbs=1
ici_expert_parallelism=4,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down Expand Up @@ -577,6 +581,7 @@ def test_moe_fsdp_two_stage_parallelism_tpu_only(self):
ici_fsdp_parallelism=2,
ici_fsdp_transpose_parallelism=2,
moe_fsdp_use_two_stage_all_gather=True,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down Expand Up @@ -652,6 +657,7 @@ def test_megablox_context_parallelism(self):
sparse_matmul=True,
per_device_batch_size=1,
ici_context_parallelism=4,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down Expand Up @@ -684,6 +690,7 @@ def test_megablox_expert_context_parallelism(self):
ici_context_parallelism=2,
ici_expert_parallelism=2,
packing=False,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down Expand Up @@ -715,6 +722,7 @@ def test_megablox_expert_tensor_parallelism(self):
per_device_batch_size=4,
ici_tensor_parallelism=2,
ici_expert_parallelism=2,
max_target_length=128,
)

rng = jax.random.PRNGKey(2345)
Expand Down
70 changes: 24 additions & 46 deletions tests/unit/train_compile_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_sequence_parallelism(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5e-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"ici_sequence_parallelism=16",
Expand Down Expand Up @@ -276,12 +276,12 @@ def test_remat_full(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5e-256",
"compile_topology=v6e-256",
"compile_topology_num_slices=1",
"per_device_batch_size=1",
"ici_fsdp_parallelism=16",
"ici_tensor_parallelism=16",
"max_target_length=2048",
"max_target_length=1024",
"fused_qkv=true",
"fused_mlp=true",
"remat_policy=full",
Expand Down Expand Up @@ -366,7 +366,7 @@ def test_moe_dropping_bf16(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v6e-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
Expand Down Expand Up @@ -457,7 +457,7 @@ def test_moe_dense_bf16(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v6e-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
Expand Down Expand Up @@ -503,7 +503,7 @@ def test_moe_pp_bf16(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v6e-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=2",
"model_name=mixtral-8x7b",
Expand All @@ -527,10 +527,10 @@ def test_moe_deepseek_scanned_bf16(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-671b",
"model_name=deepseek3-test",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=2",
Expand All @@ -552,10 +552,10 @@ def test_moe_deepseek_unscanned_bf16(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-671b",
"model_name=deepseek3-test",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=1",
Expand All @@ -575,10 +575,10 @@ def test_moe_deepseek_with_device_limit(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-64",
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-671b",
"model_name=deepseek3-test",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=1",
Expand All @@ -591,30 +591,6 @@ def test_moe_deepseek_with_device_limit(self):
)
)

@pytest.mark.cpu_only
def test_moe_deepseek_without_device_limit(self):
  """Run train_compile_main for deepseek3-671b with both routing-group knobs set to -1.

  The compile targets a single-slice v5p-256 topology using the base.yml
  config, sparse matmul without megablox, flash attention and bfloat16 for
  both activations and weights.
  """
  compiled_trainstep_file = "/tmp/test_moe_deepseek_without_device_limit.pickle"
  base_config = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
  # NOTE(review): judging by the test name, -1 for n_routing_groups and
  # topk_routing_group presumably turns device-limited routing off -- confirm
  # against the config documentation.
  routing_overrides = (
      "n_routing_groups=-1",
      "topk_routing_group=-1",
  )
  # The leading empty string fills the argv[0] slot of the argument tuple.
  argv = (
      "",
      base_config,
      f"compiled_trainstep_file={compiled_trainstep_file}",
      "compile_topology=v5p-256",
      "use_iota_embed=true",
      "compile_topology_num_slices=1",
      "model_name=deepseek3-671b",
      "sparse_matmul=True",
      "megablox=False",
      "per_device_batch_size=1",
      "max_target_length=1024",
      "attention=flash",
      "dtype=bfloat16",
      "weight_dtype=bfloat16",
  ) + routing_overrides
  train_compile_main(argv)

@pytest.mark.cpu_only
def test_moe_deepseek_pipeline_subset(self):
compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle"
Expand All @@ -623,15 +599,15 @@ def test_moe_deepseek_pipeline_subset(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v6e-256",
"compile_topology=v5p-64",
"compile_topology_num_slices=8",
"use_iota_embed=true",
"model_name=deepseek3-671b",
"model_name=deepseek3-test",
"megablox=True",
"sparse_matmul=False",
"capacity_factor=1",
"per_device_batch_size=1",
"max_target_length=2048",
"max_target_length=1024",
"pipeline_parallel_layers=56",
"ici_expert_parallelism=16",
"dcn_pipeline_parallelism=8",
Expand All @@ -646,11 +622,11 @@ def test_pipeline_subset(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v6e-256",
"compile_topology=v5p-128",
"compile_topology_num_slices=8",
"use_iota_embed=true",
"per_device_batch_size=1",
"max_target_length=2048",
"max_target_length=1024",
"pipeline_parallel_layers=56",
"base_num_decoder_layers=61", # Remainder of 5 will fail when sharded incorrectly.
"ici_expert_parallelism=16",
Expand All @@ -666,15 +642,15 @@ def test_moe_llama4_17b_16e(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-128",
"compile_topology_num_slices=1",
"model_name=llama4-17b-16e",
"per_device_batch_size=1",
"max_target_length=1024",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
"ici_fsdp_parallelism=32",
"ici_fsdp_parallelism=16",
"ici_tensor_parallelism=4",
)
)
Expand All @@ -687,7 +663,7 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-64",
"compile_topology=v5p-16",
"compile_topology_num_slices=1",
"model_name=gpt-oss-20b",
"per_device_batch_size=1",
Expand All @@ -709,7 +685,7 @@ def test_moe_gpt_oss_20b_dense_matmul(self):
"",
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-64",
"compile_topology=v5p-16",
"compile_topology_num_slices=1",
"model_name=gpt-oss-20b",
"per_device_batch_size=1",
Expand Down Expand Up @@ -767,6 +743,7 @@ def test_qwen3_next(self):
"compile_topology_num_slices=1",
"model_name=qwen3-next-80b-a3b",
"per_device_batch_size=1",
"max_target_length=1024",
)
)

Expand Down Expand Up @@ -811,5 +788,6 @@ def test_olmo3_7b(self):
"model_name=olmo3_7b",
"per_device_batch_size=1",
"scan_layers=True",
"max_target_length=1024",
)
)
Loading