diff --git a/.gitmodules b/.gitmodules
index 4b188d6bb1..e531c95507 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,7 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
+[submodule "3rdparty/nccl"]
+ path = 3rdparty/nccl
+ url = https://github.com/NVIDIA/nccl.git
+ branch = v2.30u1
diff --git a/3rdparty/nccl b/3rdparty/nccl
new file mode 160000
index 0000000000..146496ac88
--- /dev/null
+++ b/3rdparty/nccl
@@ -0,0 +1 @@
+Subproject commit 146496ac881bc504ed1a52be0ae7b707ce41e706
diff --git a/build_tools/jax.py b/build_tools/jax.py
index 5d9276b5e6..d3dcc7b453 100644
--- a/build_tools/jax.py
+++ b/build_tools/jax.py
@@ -103,16 +103,54 @@ def setup_jax_extension(
setup_mpi_flags(include_dirs, cxx_flags)
+ libraries = []
+ submod_lib_dir = None
+ submod_nccl_inc = None
+
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
cxx_flags.append("-DNVTE_WITH_CUBLASMP")
+ # NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it.
+ build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1")))
+ if build_with_nccl_ep:
+ cxx_flags.append("-DNVTE_WITH_NCCL_EP")
+ # Headers + libs come from the in-tree 3rdparty/nccl submodule build
+ # (auto-produced by setup.py).
+ libraries = ["nccl", "nccl_ep"]
+ # NCCL EP requires SM>=90 (Hopper+).
+ archs_env = os.getenv("NVTE_CUDA_ARCHS", "")
+ for a in archs_env.split(";"):
+ a_num = "".join(c for c in a if c.isdigit())
+ if a_num and int(a_num) < 90:
+ raise RuntimeError(
+ f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in"
+ " NVTE_CUDA_ARCHS."
+ )
+ submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve()
+ submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include"
+ if not (submod_ep_inc / "nccl_ep.h").exists():
+ raise RuntimeError(
+ f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. "
+ "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl."
+ )
+ include_dirs.append(submod_ep_inc)
+ submod_lib_dir = submod_root / "build" / "lib"
+ submod_nccl_inc = submod_root / "build" / "include"
+
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
- return Pybind11Extension(
+ ext = Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
- libraries=["nccl"],
+ libraries=libraries,
)
+ if submod_lib_dir is not None:
+ ext.library_dirs.append(str(submod_lib_dir))
+ ext.runtime_library_dirs.append(str(submod_lib_dir))
+ # Prefer submodule's nccl.h when present (matches the C++ side).
+ if (submod_nccl_inc / "nccl.h").exists():
+ ext.include_dirs.insert(0, str(submod_nccl_inc))
+ return ext
diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio
new file mode 100644
index 0000000000..446fb340d8
--- /dev/null
+++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg
new file mode 100644
index 0000000000..13b5d65183
--- /dev/null
+++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg
@@ -0,0 +1,120 @@
+
diff --git a/docs/examples/jax/moe.out b/docs/examples/jax/moe.out
new file mode 100644
index 0000000000..e79009eedf
--- /dev/null
+++ b/docs/examples/jax/moe.out
@@ -0,0 +1,12 @@
+# Numbers below were captured on 4x NVIDIA GB200.
+# Native JAX BF16 uses the ragged A2A baseline in single-process 4-GPU mode.
+# TE BF16 uses NCCL EP in 4-process mode with one GPU per process.
+
+# MOE_OUTPUT_START
+native JAX BF16 ragged A2A:
+Mean time: 17.085 ms
+
+TE _MoEBlock BF16 with NCCL EP:
+TE _MoEBlock BF16 output: shape=(8, 2048, 1024), dtype=bfloat16
+Mean time: 3.156 ms
+# MOE_OUTPUT_END
diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py
new file mode 100644
index 0000000000..b93080e0ec
--- /dev/null
+++ b/docs/examples/jax/moe.py
@@ -0,0 +1,461 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""JAX: BF16 Mixture-of-Experts with TransformerEngine.
+
+Companion source for ``moe.rst``. Code blocks between ``# MOE_*_START`` /
+``# MOE_*_END`` markers are pulled into the RST via ``literalinclude``.
+
+Run as a script to exercise the example end-to-end:
+
+ python docs/examples/jax/moe.py
+ python docs/examples/jax/moe.py --num-process=4 --process-id=0
+
+Launch one process for each ``process-id`` in ``[0, 4)``.
+
+The TransformerEngine path uses NCCL-backed EP and therefore requires a
+multi-process launch with one GPU per process. Both the native baseline and
+TransformerEngine path run in BF16; the current ``_MoEBlock`` wrapper uses
+no-op quantizer sets.
+"""
+
+# MOE_IMPORTS_START
+from dataclasses import dataclass
+from typing import Any
+import os
+import sys
+
+import jax
+import jax.numpy as jnp
+from flax.linen import partitioning as nn_partitioning
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+from moe_native import NativeMoEBlock
+
+# MOE_IMPORTS_END
+
+
+# MOE_CONFIG_START
+EP_AXIS = "ep"
+FSDP_AXIS = "fsdp"
+EP_SIZE = 2
+FSDP_SIZE = 2
+
+NUM_EXPERTS = 8
+TOPK = 2
+BATCH = 8
+SEQ = 2048
+HIDDEN = 1024
+INTERMEDIATE = 4096
+DTYPE = jnp.bfloat16
+
+LOGICAL_AXIS_RULES = (
+ ("exp", EP_AXIS),
+ ("embed", FSDP_AXIS),
+ ("mlp", None),
+ ("batch", (EP_AXIS, FSDP_AXIS)),
+)
+# MOE_CONFIG_END
+
+
+@dataclass
+class DemoState:
+ mesh: Mesh
+ mesh_resource: Any
+ native_model: NativeMoEBlock
+ te_model: Any
+ variables: Any
+ x: jax.Array
+ dy: jax.Array
+
+
+def _ensure_writable_triton_cache():
+ import tempfile
+
+ os.environ.setdefault(
+ "TRITON_CACHE_DIR",
+ os.path.join(tempfile.gettempdir(), "transformer_engine_triton_cache"),
+ )
+
+
+def _register_te_ffi_targets():
+ _ensure_writable_triton_cache()
+ import transformer_engine.jax.cpp_extensions # noqa: F401
+
+
+# MOE_MESH_SETUP_START
+def _read_mp_options():
+ num_process = int(os.environ.get("MP_NUM_PROCESS", "0") or "0")
+ process_id = int(os.environ.get("MP_PROCESS_ID", "0") or "0")
+ for i, arg in enumerate(sys.argv):
+ if arg.startswith("--num-process="):
+ num_process = int(arg.split("=", 1)[1])
+ elif arg == "--num-process" and i + 1 < len(sys.argv):
+ num_process = int(sys.argv[i + 1])
+ elif arg.startswith("--process-id="):
+ process_id = int(arg.split("=", 1)[1])
+ elif arg == "--process-id" and i + 1 < len(sys.argv):
+ process_id = int(sys.argv[i + 1])
+ return num_process, process_id
+
+
+def maybe_initialize_distributed():
+ num_process, process_id = _read_mp_options()
+ if num_process <= 1:
+ return
+ coordinator = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457")
+ jax.distributed.initialize(
+ coordinator_address=coordinator,
+ num_processes=num_process,
+ process_id=process_id,
+ local_device_ids=process_id,
+ )
+
+
+def build_ep_fsdp_mesh():
+ from transformer_engine.jax.sharding import MeshResource
+
+ required_devices = EP_SIZE * FSDP_SIZE
+ if len(jax.devices()) < required_devices:
+ raise RuntimeError(
+ f"MoE tutorial requires {required_devices} GPUs; only {len(jax.devices())} visible"
+ )
+
+ devices = mesh_utils.create_device_mesh(
+ (FSDP_SIZE, EP_SIZE),
+ devices=jax.devices()[:required_devices],
+ )
+ mesh = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS))
+ mesh_resource = MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)
+ return mesh, mesh_resource
+
+
+# MOE_MESH_SETUP_END
+
+
+# MOE_MODEL_SETUP_START
+def build_models(mesh, *, hidden=HIDDEN, intermediate=INTERMEDIATE):
+ _ensure_writable_triton_cache()
+
+ from transformer_engine.jax.flax import _MoEBlock as TEMoEBlock
+
+ native_model = NativeMoEBlock(
+ mesh=mesh,
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ intermediate_size=intermediate,
+ ep_axis=EP_AXIS,
+ data_parallelism_axes=(FSDP_AXIS,),
+ dtype=DTYPE,
+ )
+ te_model = TEMoEBlock(
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ intermediate_size=intermediate,
+ data_parallelism_axes=(FSDP_AXIS,),
+ apply_topk_weights_early=True,
+ dtype=DTYPE,
+ )
+ return native_model, te_model
+
+
+# MOE_MODEL_SETUP_END
+
+
+# MOE_INPUTS_SETUP_START
+def make_inputs(*, batch=BATCH, seq=SEQ, hidden=HIDDEN):
+ key = jax.random.PRNGKey(0)
+ k_init, k_x, k_dy = jax.random.split(key, 3)
+ x = jax.random.normal(k_x, (batch, seq, hidden), dtype=DTYPE)
+ dy = jax.random.normal(k_dy, (batch, seq, hidden), dtype=DTYPE)
+ return k_init, x, dy
+
+
+def shard_inputs_and_variables(mesh, variables, x, dy):
+ input_sharding = NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None))
+ gate_sharding = NamedSharding(mesh, P())
+ expert_sharding = NamedSharding(mesh, P(EP_AXIS, None, None))
+
+ params = variables["params"]
+ sharded_params = {
+ "gate_kernel": jax.device_put(params["gate_kernel"], gate_sharding),
+ "wi_0": jax.device_put(params["wi_0"], expert_sharding),
+ "wi_1": jax.device_put(params["wi_1"], expert_sharding),
+ "wo": jax.device_put(params["wo"], expert_sharding),
+ }
+ return {
+ "variables": {**variables, "params": sharded_params},
+ "x": jax.device_put(x, input_sharding),
+ "dy": jax.device_put(dy, input_sharding),
+ }
+
+
+# MOE_INPUTS_SETUP_END
+
+
+def _recv_capacity_per_rank(batch, seq):
+ num_procs = jax.process_count()
+ dp_size = num_procs // EP_SIZE
+ num_local_experts = NUM_EXPERTS // EP_SIZE
+ natural_recv_pr = (batch // dp_size) * seq * TOPK
+ slots_per_expert = (natural_recv_pr + num_local_experts - 1) // num_local_experts
+ return num_local_experts * slots_per_expert
+
+
+def bootstrap_te_ep(mesh, mesh_resource, *, batch=BATCH, seq=SEQ, hidden=HIDDEN):
+ from transformer_engine.jax.ep import ep_bootstrap
+ from transformer_engine.jax.moe import record_ep_bootstrap_signature_for_moe
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ world_size = jax.process_count()
+ max_tokens_per_rank = (batch // world_size) * seq
+ recv_capacity_per_rank = _recv_capacity_per_rank(batch, seq)
+
+ with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
+ ep_bootstrap(
+ world_size=world_size,
+ rank=jax.process_index(),
+ ep_size=EP_SIZE,
+ num_experts=NUM_EXPERTS,
+ max_tokens_per_rank=max_tokens_per_rank,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ hidden_dim=hidden,
+ allow_handle_mem_reloc=True,
+ max_token_dtype=DTYPE,
+ )
+ record_ep_bootstrap_signature_for_moe(
+ num_experts=NUM_EXPERTS,
+ max_tokens_per_rank=max_tokens_per_rank,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ hidden_dim=hidden,
+ ep_size=EP_SIZE,
+ )
+
+
+def _te_apply(te_model):
+ def apply_fn(variables, x, **kwargs):
+ out, _ = te_model.apply(variables, x, **kwargs)
+ return out
+
+ return apply_fn
+
+
+def setup_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ mesh, mesh_resource = build_ep_fsdp_mesh()
+ bootstrap_te_ep(mesh, mesh_resource, batch=batch, seq=seq, hidden=hidden)
+ native_model, te_model = build_models(mesh, hidden=hidden, intermediate=intermediate)
+ k_init, x, dy = make_inputs(batch=batch, seq=seq, hidden=hidden)
+
+ with jax.set_mesh(mesh), global_shard_guard(mesh_resource), nn_partitioning.axis_rules(
+ LOGICAL_AXIS_RULES
+ ):
+ variables = jax.jit(native_model.init)(k_init, x)
+ variables = jax.jit(native_model.init)(k_init, x)
+ jax.block_until_ready(jax.tree_util.tree_leaves(variables))
+ sharded = shard_inputs_and_variables(mesh, variables, x, dy)
+ return DemoState(
+ mesh=mesh,
+ mesh_resource=mesh_resource,
+ native_model=native_model,
+ te_model=te_model,
+ variables=sharded["variables"],
+ x=sharded["x"],
+ dy=sharded["dy"],
+ )
+
+
+def setup_te_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ mesh, mesh_resource = build_ep_fsdp_mesh()
+ bootstrap_te_ep(mesh, mesh_resource, batch=batch, seq=seq, hidden=hidden)
+ _, te_model = build_models(mesh, hidden=hidden, intermediate=intermediate)
+ k_init, x, dy = make_inputs(batch=batch, seq=seq, hidden=hidden)
+
+ with jax.set_mesh(mesh), global_shard_guard(mesh_resource), nn_partitioning.axis_rules(
+ LOGICAL_AXIS_RULES
+ ):
+ variables = jax.jit(te_model.init)(k_init, x)
+ jax.block_until_ready(jax.tree_util.tree_leaves(variables))
+ sharded = shard_inputs_and_variables(mesh, variables, x, dy)
+ return DemoState(
+ mesh=mesh,
+ mesh_resource=mesh_resource,
+ native_model=None,
+ te_model=te_model,
+ variables=sharded["variables"],
+ x=sharded["x"],
+ dy=sharded["dy"],
+ )
+
+
+def te_moe_supported():
+ try:
+ import importlib
+
+ _ensure_writable_triton_cache()
+
+ import transformer_engine.jax # noqa: F401
+
+ transformer_engine_jax = sys.modules["transformer_engine_jax"]
+ flax_mod = importlib.import_module("transformer_engine.jax.flax")
+ getattr(flax_mod, "_MoEBlock")
+ if jax.process_count() < EP_SIZE * FSDP_SIZE:
+ return (
+ False,
+ (
+ "TE EP requires a multi-process launch with one GPU per process; "
+ f"got process_count={jax.process_count()}"
+ ),
+ )
+ if jax.local_device_count() != 1:
+ return (
+ False,
+ (
+ "TE EP requires one local GPU per process; "
+ f"got local_device_count={jax.local_device_count()}"
+ ),
+ )
+ if transformer_engine_jax.get_device_compute_capability(0) < 100:
+ return False, "TE MoE grouped GEMM currently requires Blackwell (sm_100+)"
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ return False, str(exc)
+ return True, ""
+
+
+def compare_forward(demo):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ te_apply = _te_apply(demo.te_model)
+ with jax.set_mesh(demo.mesh), global_shard_guard(
+ demo.mesh_resource
+ ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES):
+ native_out = jax.jit(demo.native_model.apply)(demo.variables, demo.x)
+ te_out = jax.jit(te_apply)(demo.variables, demo.x)
+ native_out, te_out = jax.block_until_ready((native_out, te_out))
+
+ max_abs = jnp.max(jnp.abs(native_out.astype(jnp.float32) - te_out.astype(jnp.float32)))
+ print(f"max |native BF16 - TE BF16|: {float(max_abs):.4f}")
+ return native_out, te_out
+
+
+# MOE_CORRECTNESS_START
+def run_te_forward(demo):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ te_apply = _te_apply(demo.te_model)
+ with jax.set_mesh(demo.mesh), global_shard_guard(
+ demo.mesh_resource
+ ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES):
+ te_out = jax.jit(te_apply)(demo.variables, demo.x)
+ te_out.block_until_ready()
+
+ print(f"TE _MoEBlock BF16 output: shape={te_out.shape}, dtype={te_out.dtype}")
+ return te_out
+
+
+# MOE_CORRECTNESS_END
+
+
+# MOE_BENCH_START
+def _block_until_ready_tree(tree):
+ leaves = jax.tree_util.tree_leaves(tree)
+ if leaves:
+ jax.block_until_ready(leaves)
+
+
+def _time_fwd_bwd(apply_fn, demo, *, warmup_iters=5, timing_iters=10):
+ import time
+
+ autocast_kwargs = {"enabled": False, "mesh_resource": demo.mesh_resource}
+
+ def loss_fn(variables, inp, grad_target):
+ import transformer_engine.jax as te
+
+ with te.autocast(**autocast_kwargs):
+ out = apply_fn(variables, inp)
+ return jnp.vdot(out, grad_target)
+
+ train_step = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))
+
+ for _ in range(warmup_iters):
+ _block_until_ready_tree(train_step(demo.variables, demo.x, demo.dy))
+
+ start = time.perf_counter()
+ for _ in range(timing_iters):
+ _block_until_ready_tree(train_step(demo.variables, demo.x, demo.dy))
+ return (time.perf_counter() - start) * 1000.0 / timing_iters
+
+
+def run_benchmarks(demo, *, warmup_iters=5, timing_iters=10):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ te_apply = _te_apply(demo.te_model)
+ with jax.set_mesh(demo.mesh), global_shard_guard(
+ demo.mesh_resource
+ ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES):
+ print("native JAX BF16:")
+ native_ms = _time_fwd_bwd(
+ demo.native_model.apply,
+ demo,
+ warmup_iters=warmup_iters,
+ timing_iters=timing_iters,
+ )
+ print(f"Mean time: {native_ms:.3f} ms")
+
+ print("\nTE _MoEBlock BF16:")
+ te_ms = _time_fwd_bwd(
+ te_apply,
+ demo,
+ warmup_iters=warmup_iters,
+ timing_iters=timing_iters,
+ )
+ print(f"Mean time: {te_ms:.3f} ms")
+ return native_ms, te_ms
+
+
+def run_te_benchmark(demo, *, warmup_iters=5, timing_iters=10):
+ from transformer_engine.jax.sharding import global_shard_guard
+
+ te_apply = _te_apply(demo.te_model)
+ with jax.set_mesh(demo.mesh), global_shard_guard(
+ demo.mesh_resource
+ ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES):
+ print("TE _MoEBlock BF16:")
+ te_ms = _time_fwd_bwd(
+ te_apply,
+ demo,
+ warmup_iters=warmup_iters,
+ timing_iters=timing_iters,
+ )
+ print(f"Mean time: {te_ms:.3f} ms")
+ return te_ms
+
+
+# MOE_BENCH_END
+
+
+def main():
+ _register_te_ffi_targets()
+ maybe_initialize_distributed()
+
+ if len(jax.devices()) < EP_SIZE * FSDP_SIZE:
+ print(f"[skipped: need {EP_SIZE * FSDP_SIZE} GPUs for EP=2/FSDP=2]")
+ return
+
+ te_supported, te_reason = te_moe_supported()
+ if not te_supported:
+ print(f"[skipped TE comparison: {te_reason}]")
+ return
+
+ demo = setup_te_demo()
+ run_te_forward(demo)
+ run_te_benchmark(demo)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/examples/jax/moe.rst b/docs/examples/jax/moe.rst
new file mode 100644
index 0000000000..33ef33bbbd
--- /dev/null
+++ b/docs/examples/jax/moe.rst
@@ -0,0 +1,189 @@
+..
+ Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+JAX: BF16 Mixture-of-Experts with TransformerEngine
+===================================================
+
+This document walks through replacing a native JAX/Flax expert-parallel MoE
+block with TransformerEngine's experimental Flax ``_MoEBlock``.
+
+**Baseline.** The reference path is pure JAX/Flax BF16. It uses
+``jax.lax.ragged_all_to_all`` for expert-parallel token exchange and
+``jax.lax.ragged_dot`` for the grouped expert FFNs. The low-level ragged
+all-to-all setup lives in ``moe_native.py`` so the snippets below stay focused
+on model-level code.
+
+**TransformerEngine path.** This tutorial uses ``_MoEBlock`` in BF16 with
+NCCL-backed TE EP and the wrapper's current no-op quantizer sets. TE EP replaces
+the tutorial's previous TE-side ragged A2A exchange with ``tex.ep_dispatch`` and
+``tex.ep_combine`` over NCCL EP. Quantized MoE recipes are intentionally out of
+scope here.
+
+`<- Back to the JAX integration overview <../te_jax_integration.html>`_
+
+The forward path below summarizes the data flow for the native baseline and the
+TE replacement.
+
+.. figure:: media/jax_moe_native_vs_te_flow.svg
+ :alt: Side-by-side forward data flow for native JAX and TransformerEngine JAX MoE blocks.
+ :align: center
+ :width: 100%
+
+ Forward data flow for the tutorial's BF16 MoE block. The native baseline
+ keeps JAX ``ragged_all_to_all`` and ``ragged_dot``. TE keeps the same sharded
+ inputs and weights, but routes through TE fused router, NCCL EP
+ dispatch/combine, and grouped GEMM primitives while keeping dispatch, expert
+ compute, and combine inside one MoE VJP.
+
+1. Baseline: native JAX BF16 EP MoE
+-----------------------------------
+
+The example uses a 2x2 mesh: expert parallelism on ``ep`` and FSDP-style batch
+parallelism on ``fsdp``. The batch dimension is sharded over both axes, and
+expert weights are sharded over ``ep``. TE EP requires ``ep`` to be the inner
+axis and currently runs in multi-process mode with one GPU per process.
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_IMPORTS_START
+ :end-before: # MOE_IMPORTS_END
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_CONFIG_START
+ :end-before: # MOE_CONFIG_END
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_MESH_SETUP_START
+ :end-before: # MOE_MESH_SETUP_END
+
+The native baseline is exposed as a normal Flax module. Its implementation in
+``moe_native.py`` performs softmax top-k routing, forward
+``ragged_all_to_all`` over ``ep``, local source-major to expert-major chunk
+reordering, a concatenated ``wi_0|wi_1`` ``ragged_dot`` input projection,
+activation, ``wo`` ``ragged_dot`` output projection, reverse
+``ragged_all_to_all``, and weighted token combine.
+
+2. TransformerEngine ``_MoEBlock``
+----------------------------------
+
+The TE replacement registers the same gate and expert parameter names as the
+baseline, then delegates routing, dispatch, grouped FFN, combine,
+expert-parallel collectives, and VJP to ``transformer_engine.jax.moe.moe``.
+On this branch, the TE-side expert exchange is NCCL EP: ``_MoEBlock`` calls
+``tex.ep_dispatch`` before the grouped FFNs and ``tex.ep_combine`` after them.
+The native baseline remains unchanged and continues to use
+``jax.lax.ragged_all_to_all`` for the comparison numbers.
+
+``_MoEBlock`` is intentionally underscore-prefixed while the API stabilizes. Use
+it as an experimental integration point.
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_MODEL_SETUP_START
+ :end-before: # MOE_MODEL_SETUP_END
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_INPUTS_SETUP_START
+ :end-before: # MOE_INPUTS_SETUP_END
+
+3. TE EP smoke check
+--------------------
+
+The direct script path initializes the TE EP communicator, creates the
+``_MoEBlock`` variables, runs a BF16 forward pass, and reports the output shape
+and dtype.
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_CORRECTNESS_START
+ :end-before: # MOE_CORRECTNESS_END
+
+The native ragged A2A baseline remains in ``moe_native.py`` and is used for the
+baseline timings below. Because the native ragged A2A path runs in
+single-process 4-GPU mode while TE EP runs in one-process-per-GPU mode, the
+benchmark sweep times the two paths separately.
+
+4. Performance comparison
+-------------------------
+
+``run_te_benchmark`` runs a blocking JIT-compiled forward+backward loop with
+warmup. Even though quantization is disabled, the benchmark passes the active
+``MeshResource`` through TE's autocast context so ``_MoEBlock`` can resolve the
+``ep`` axis. The TE block folds top-k weights into the per-expert FFN
+intermediate with ``apply_topk_weights_early=True``; this is mathematically
+equivalent for the BF16 path because the down projection is linear.
+
+.. literalinclude:: moe.py
+ :language: python
+ :start-after: # MOE_BENCH_START
+ :end-before: # MOE_BENCH_END
+
+Run the full example with:
+
+.. code-block:: bash
+
+ for i in 0 1 2 3; do
+ python docs/examples/jax/moe.py --num-process=4 --process-id=$i > proc_$i.log 2>&1 &
+ done
+ wait
+
+Measured on four NVIDIA GB200 GPUs with the default tutorial shape
+``batch=8``, ``seq=2048``, ``hidden=1024``, ``intermediate=4096``,
+``num_experts=8``, and ``topk=2``:
+
+.. csv-table::
+ :header: "Path", "Mean fwd+bwd time", "Relative time"
+ :widths: 35, 25, 25
+
+ "Native JAX BF16 ragged A2A", "17.085 ms", "1.00x"
+ "TE ``_MoEBlock`` BF16 with NCCL EP", "3.156 ms", "0.18x"
+
+For this no-op-quantizer BF16 configuration, TE EP measured ``5.41x`` the
+native ragged A2A baseline throughput on this tutorial shape.
+
+A larger-shape sweep with the same blocking timing loop found TE EP ahead for
+each shape tried. The native column uses the unchanged ragged A2A baseline; the
+TE column uses NCCL EP. The default shape appears in both tables; the values
+differ slightly because the standalone tutorial run and sweep were timed
+separately.
+
+.. csv-table::
+ :header: "Batch", "Seq", "Hidden", "Intermediate", "Native BF16", "TE BF16", "TE speedup"
+ :widths: 10, 10, 12, 16, 16, 16, 14
+
+ "8", "1024", "1024", "4096", "8.543 ms", "2.075 ms", "4.12x"
+ "8", "2048", "1024", "4096", "17.085 ms", "3.217 ms", "5.31x"
+ "8", "4096", "1024", "4096", "38.811 ms", "5.349 ms", "7.26x"
+ "16", "2048", "1024", "4096", "39.194 ms", "5.355 ms", "7.32x"
+ "8", "1024", "2048", "8192", "19.329 ms", "4.110 ms", "4.70x"
+ "8", "2048", "2048", "8192", "42.505 ms", "6.254 ms", "6.80x"
+ "16", "2048", "2048", "8192", "88.134 ms", "10.542 ms", "8.36x"
+
+The result depends on token distribution, hidden size, intermediate size, and
+the target stack.
+
+.. raw:: html
+
+
+ Output:
+
+
+.. container:: program-output
+
+ .. literalinclude:: moe.out
+ :language: text
+ :start-after: # MOE_OUTPUT_START
+ :end-before: # MOE_OUTPUT_END
+
+Next steps
+----------
+
+* `Dense GEMMs `_: quantizing a single ``flax.linen.Dense`` GEMM.
+* `Collective GEMM `_: further speedups by communicating
+ between devices inside the GEMM.
+* `<- Hub <../te_jax_integration.html>`_
diff --git a/docs/examples/jax/moe_native.py b/docs/examples/jax/moe_native.py
new file mode 100644
index 0000000000..cf1f6db7d8
--- /dev/null
+++ b/docs/examples/jax/moe_native.py
@@ -0,0 +1,405 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Native JAX/Flax MoE baseline used by ``moe.rst``.
+
+This file intentionally contains the lower-level reference mechanics so the
+tutorial can focus on model-level code. It does not import TransformerEngine:
+the router, expert-parallel ragged all-to-all, local ragged chunk reorder, and
+ragged expert matmuls are implemented with JAX and Flax only.
+"""
+
+import inspect
+from functools import partial
+from typing import Any, Callable, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from jax.sharding import PartitionSpec as P
+
+
+def _forward_a2a_params(
+ all_tokens_per_expert: jnp.ndarray,
+ shard_id: jnp.ndarray,
+ num_ep: int,
+) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+ """Build ``ragged_all_to_all`` offsets/sizes for dispatch."""
+ num_experts = all_tokens_per_expert.shape[1]
+ experts_per_shard = num_experts // num_ep
+
+ local_tokens_per_expert = jax.lax.dynamic_slice(
+ all_tokens_per_expert,
+ start_indices=(shard_id, 0),
+ slice_sizes=(1, num_experts),
+ ).squeeze(0)
+ local_by_destination = local_tokens_per_expert.reshape(num_ep, experts_per_shard)
+ send_sizes = jnp.sum(local_by_destination, axis=1).astype(jnp.int32)
+ input_offsets = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]])
+
+ local_expert_start = shard_id * experts_per_shard
+ local_expert_columns = jax.lax.dynamic_slice(
+ all_tokens_per_expert,
+ start_indices=(0, local_expert_start),
+ slice_sizes=(num_ep, experts_per_shard),
+ )
+ recv_sizes = jnp.sum(local_expert_columns, axis=1).astype(jnp.int32)
+
+ sends_to_destination = jnp.sum(
+ all_tokens_per_expert.reshape(num_ep, num_ep, experts_per_shard),
+ axis=2,
+ ).astype(jnp.int32)
+ cumulative = jnp.cumsum(
+ jnp.concatenate(
+ [jnp.zeros((1, num_ep), dtype=jnp.int32), sends_to_destination],
+ axis=0,
+ ),
+ axis=0,
+ )
+ output_offsets = jax.lax.dynamic_slice(
+ cumulative,
+ start_indices=(shard_id, 0),
+ slice_sizes=(1, num_ep),
+ ).squeeze(0)
+
+ return input_offsets, send_sizes, output_offsets, recv_sizes
+
+
+def _reverse_a2a_params(
+ all_tokens_per_expert: jnp.ndarray,
+ shard_id: jnp.ndarray,
+ num_ep: int,
+) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+ """Build ``ragged_all_to_all`` offsets/sizes for combine."""
+ num_experts = all_tokens_per_expert.shape[1]
+ experts_per_shard = num_experts // num_ep
+ local_expert_start = shard_id * experts_per_shard
+
+ local_expert_columns = jax.lax.dynamic_slice(
+ all_tokens_per_expert,
+ start_indices=(0, local_expert_start),
+ slice_sizes=(num_ep, experts_per_shard),
+ )
+ send_sizes = jnp.sum(local_expert_columns, axis=1).astype(jnp.int32)
+ input_offsets = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]])
+
+ local_tokens_per_expert = jax.lax.dynamic_slice(
+ all_tokens_per_expert,
+ start_indices=(shard_id, 0),
+ slice_sizes=(1, num_experts),
+ ).squeeze(0)
+ local_by_destination = local_tokens_per_expert.reshape(num_ep, experts_per_shard)
+ recv_sizes = jnp.sum(local_by_destination, axis=1).astype(jnp.int32)
+
+ forward_sends_to = jnp.sum(
+ all_tokens_per_expert.reshape(num_ep, num_ep, experts_per_shard),
+ axis=2,
+ ).astype(jnp.int32)
+ reverse_sends_to = jnp.transpose(forward_sends_to)
+ cumulative = jnp.cumsum(
+ jnp.concatenate(
+ [jnp.zeros((1, num_ep), dtype=jnp.int32), reverse_sends_to],
+ axis=0,
+ ),
+ axis=0,
+ )
+ output_offsets = jax.lax.dynamic_slice(
+ cumulative,
+ start_indices=(shard_id, 0),
+ slice_sizes=(1, num_ep),
+ ).squeeze(0)
+
+ return input_offsets, send_sizes, output_offsets, recv_sizes
+
+
+def _reorder_ragged_chunks(
+ x: jnp.ndarray,
+ chunk_sizes: jnp.ndarray,
+ source_order: jnp.ndarray,
+ target_order: jnp.ndarray,
+) -> jnp.ndarray:
+ """Reorder a fixed-size ragged buffer from one chunk order to another."""
+ source_sizes = chunk_sizes[source_order]
+ source_starts = jnp.concatenate(
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(source_sizes)[:-1]]
+ )
+ source_ends = source_starts + source_sizes
+
+ target_sizes = chunk_sizes[target_order]
+ target_starts_by_position = jnp.concatenate(
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(target_sizes)[:-1]]
+ )
+ target_position_by_chunk = jnp.argsort(target_order)
+ target_start_by_chunk = target_starts_by_position[target_position_by_chunk]
+
+ rows = jnp.arange(x.shape[0], dtype=jnp.int32)
+ in_source_chunk = (rows[:, None] >= source_starts[None, :]) & (
+ rows[:, None] < source_ends[None, :]
+ )
+ valid = jnp.any(in_source_chunk, axis=1)
+ source_position = jnp.argmax(in_source_chunk, axis=1)
+ chunk_id = source_order[source_position]
+ row_in_chunk = rows - source_starts[source_position]
+ target_rows = target_start_by_chunk[chunk_id] + row_in_chunk
+ target_rows = jnp.where(valid, target_rows, 0)
+
+ updates = jnp.where(valid[:, None], x, jnp.zeros_like(x))
+ return jnp.zeros_like(x).at[target_rows].add(updates)
+
+
+def _route_tokens(
+ x_2d: jnp.ndarray,
+ gate_kernel: jnp.ndarray,
+ num_experts_per_tok: int,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+ """Softmax top-k router matching the tutorial's default TE path."""
+ logits = x_2d.astype(jnp.float32) @ gate_kernel.astype(jnp.float32)
+ probs = jax.nn.softmax(logits, axis=-1)
+ weights, experts = jax.lax.top_k(probs, num_experts_per_tok)
+ weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
+ return experts.astype(jnp.int32), weights.astype(x_2d.dtype)
+
+
+def _native_moe_local(
+ captured: dict,
+ *,
+ ep_axis: str,
+ num_experts: int,
+ num_experts_per_tok: int,
+ recv_buffer_rows: int,
+ dtype: jnp.dtype,
+) -> jnp.ndarray:
+ """One shard of the native EP MoE forward pass."""
+ x = captured["x"]
+ gate_kernel = captured["gate_kernel"]
+ wi_0 = captured["wi_0"]
+ wi_1 = captured["wi_1"]
+ wo = captured["wo"]
+
+ batch, sequence, hidden = x.shape
+ tokens = batch * sequence
+ x_2d = x.reshape(tokens, hidden)
+
+ selected_experts, routing_weights = _route_tokens(x_2d, gate_kernel, num_experts_per_tok)
+ flat_experts = selected_experts.reshape(-1)
+ flat_token_ids = jnp.repeat(jnp.arange(tokens, dtype=jnp.int32), num_experts_per_tok)
+ flat_weights = routing_weights.reshape(-1)
+
+ sort_order = jnp.argsort(flat_experts, stable=True)
+ sorted_experts = flat_experts[sort_order]
+ sorted_x = x_2d[flat_token_ids][sort_order]
+ tokens_per_expert = jnp.bincount(
+ sorted_experts,
+ length=num_experts,
+ minlength=num_experts,
+ ).astype(jnp.int32)
+
+ shard_id = jax.lax.axis_index(ep_axis)
+ num_ep = jax.lax.psum(1, ep_axis)
+ experts_per_shard = num_experts // num_ep
+
+ all_tokens_per_expert = jax.lax.all_gather(
+ tokens_per_expert[None, :],
+ axis_name=ep_axis,
+ axis=0,
+ tiled=True,
+ )
+
+ in_off, send_sz, out_off, recv_sz = _forward_a2a_params(all_tokens_per_expert, shard_id, num_ep)
+ x_recv = jax.lax.ragged_all_to_all(
+ sorted_x,
+ jnp.zeros((recv_buffer_rows, hidden), dtype=sorted_x.dtype),
+ in_off,
+ send_sz,
+ out_off,
+ recv_sz,
+ axis_name=ep_axis,
+ )
+
+ local_expert_start = shard_id * experts_per_shard
+ local_counts_by_source = jax.lax.dynamic_slice(
+ all_tokens_per_expert,
+ start_indices=(0, local_expert_start),
+ slice_sizes=(num_ep, experts_per_shard),
+ ).astype(jnp.int32)
+ local_chunk_sizes = local_counts_by_source.reshape(-1)
+ source_major_order = jnp.arange(num_ep * experts_per_shard, dtype=jnp.int32)
+ expert_major_order = source_major_order.reshape(num_ep, experts_per_shard).T.reshape(-1)
+ local_group_sizes = jnp.sum(local_counts_by_source, axis=0).astype(jnp.int32)
+
+ x_expert_major = _reorder_ragged_chunks(
+ x_recv,
+ local_chunk_sizes,
+ source_major_order,
+ expert_major_order,
+ )
+ wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1)
+ hidden_combined = jax.lax.ragged_dot(x_expert_major, wi_combined, local_group_sizes)
+ hidden_0, hidden_1 = jnp.split(hidden_combined, 2, axis=-1)
+ activated = jax.nn.silu(hidden_0) * hidden_1
+ expert_output = jax.lax.ragged_dot(activated, wo, local_group_sizes).astype(dtype)
+
+ source_major_output = _reorder_ragged_chunks(
+ expert_output,
+ local_chunk_sizes,
+ expert_major_order,
+ source_major_order,
+ )
+ in_off, send_sz, out_off, recv_sz = _reverse_a2a_params(all_tokens_per_expert, shard_id, num_ep)
+ returned = jax.lax.ragged_all_to_all(
+ source_major_output,
+ jnp.zeros_like(sorted_x),
+ in_off,
+ send_sz,
+ out_off,
+ recv_sz,
+ axis_name=ep_axis,
+ )
+
+ unsorted = jnp.zeros_like(returned).at[sort_order].set(returned)
+ token_outputs = unsorted.reshape(tokens, num_experts_per_tok, hidden)
+ weighted = token_outputs * flat_weights.reshape(tokens, num_experts_per_tok, 1)
+ return jnp.sum(weighted, axis=1).reshape(batch, sequence, hidden).astype(dtype)
+
+
+def native_moe_ep(
+ x: jnp.ndarray,
+ gate_kernel: jnp.ndarray,
+ wi_0: jnp.ndarray,
+ wi_1: jnp.ndarray,
+ wo: jnp.ndarray,
+ *,
+ mesh: Any,
+ ep_axis: str,
+ data_parallelism_axes: Tuple[str, ...],
+ num_experts: int,
+ num_experts_per_tok: int,
+ dtype: jnp.dtype,
+) -> jnp.ndarray:
+ """Run the native BF16 EP MoE baseline on an active JAX mesh."""
+ if num_experts % mesh.shape[ep_axis] != 0:
+ raise ValueError(
+ f"num_experts={num_experts} must be divisible by EP size={mesh.shape[ep_axis]}"
+ )
+
+ if data_parallelism_axes:
+ batch_axis = (ep_axis, *data_parallelism_axes)
+ else:
+ batch_axis = ep_axis
+
+ dp_size = 1
+ for axis in data_parallelism_axes:
+ dp_size *= mesh.shape[axis]
+
+ batch, sequence, _ = x.shape
+ required_batch_multiple = mesh.shape[ep_axis] * dp_size
+ if batch % required_batch_multiple != 0:
+ raise ValueError(f"batch={batch} must be divisible by ep*dp={required_batch_multiple}")
+
+ recv_buffer_rows = (batch // dp_size) * sequence * num_experts_per_tok
+ captured = {
+ "x": x,
+ "gate_kernel": gate_kernel,
+ "wi_0": wi_0,
+ "wi_1": wi_1,
+ "wo": wo,
+ }
+ in_specs = (
+ {
+ "x": P(batch_axis, None, None),
+ "gate_kernel": P(),
+ "wi_0": P(ep_axis, None, None),
+ "wi_1": P(ep_axis, None, None),
+ "wo": P(ep_axis, None, None),
+ },
+ )
+
+ body = partial(
+ _native_moe_local,
+ ep_axis=ep_axis,
+ num_experts=num_experts,
+ num_experts_per_tok=num_experts_per_tok,
+ recv_buffer_rows=recv_buffer_rows,
+ dtype=dtype,
+ )
+ shard_map_kwargs = {
+ "mesh": mesh,
+ "in_specs": in_specs,
+ "out_specs": P(batch_axis, None, None),
+ }
+ shard_map_params = inspect.signature(jax.shard_map).parameters
+ if "check_rep" in shard_map_params:
+ shard_map_kwargs["check_rep"] = False
+ elif "check_vma" in shard_map_params:
+ shard_map_kwargs["check_vma"] = False
+
+ return jax.shard_map(body, **shard_map_kwargs)(captured)
+
+
+class NativeMoEBlock(nn.Module):
+ """Native JAX/Flax BF16 EP MoE block used as the tutorial baseline."""
+
+ mesh: Any
+ num_experts: int = 8
+ num_experts_per_tok: int = 2
+ intermediate_size: int = 2048
+ ep_axis: str = "ep"
+ data_parallelism_axes: Tuple[str, ...] = ("fsdp",)
+ dtype: jnp.dtype = jnp.bfloat16
+ kernel_init: Optional[Callable] = None
+
+ def __post_init__(self):
+ if self.kernel_init is None:
+ object.__setattr__(
+ self,
+ "kernel_init",
+ nn.initializers.variance_scaling(
+ 1.0,
+ "fan_in",
+ "truncated_normal",
+ dtype=self.dtype,
+ ),
+ )
+ super().__post_init__()
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+ hidden = x.shape[-1]
+ gate_kernel = self.param(
+ "gate_kernel",
+ self.kernel_init,
+ (hidden, self.num_experts),
+ self.dtype,
+ )
+ wi_0 = self.param(
+ "wi_0",
+ self.kernel_init,
+ (self.num_experts, hidden, self.intermediate_size),
+ self.dtype,
+ )
+ wi_1 = self.param(
+ "wi_1",
+ self.kernel_init,
+ (self.num_experts, hidden, self.intermediate_size),
+ self.dtype,
+ )
+ wo = self.param(
+ "wo",
+ self.kernel_init,
+ (self.num_experts, self.intermediate_size, hidden),
+ self.dtype,
+ )
+ return native_moe_ep(
+ x,
+ gate_kernel,
+ wi_0,
+ wi_1,
+ wo,
+ mesh=self.mesh,
+ ep_axis=self.ep_axis,
+ data_parallelism_axes=self.data_parallelism_axes,
+ num_experts=self.num_experts,
+ num_experts_per_tok=self.num_experts_per_tok,
+ dtype=self.dtype,
+ )
diff --git a/docs/examples/jax/test_moe.py b/docs/examples/jax/test_moe.py
new file mode 100644
index 0000000000..24840af85e
--- /dev/null
+++ b/docs/examples/jax/test_moe.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Pytest entry points for ``moe.py``.
+
+Run with:
+
+ pytest -v docs/examples/jax/test_moe.py
+
+The tutorial uses a 2x2 EP/FSDP mesh, so tests skip when fewer than four GPUs
+are visible. TransformerEngine MoE tests also skip when the installed TE build
+does not expose the experimental ``_MoEBlock`` or when hardware support is
+missing.
+"""
+
+import importlib
+import os
+import sys
+import tempfile
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+
+requires_4gpu = pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs")
+
+
+os.environ.setdefault(
+ "TRITON_CACHE_DIR",
+ os.path.join(tempfile.gettempdir(), "transformer_engine_triton_cache"),
+)
+
+
+def _te_moe_available():
+ try:
+ import transformer_engine.jax # noqa: F401
+
+ mod = importlib.import_module("transformer_engine.jax.flax")
+ getattr(mod, "_MoEBlock")
+ transformer_engine_jax = sys.modules["transformer_engine_jax"]
+
+ if transformer_engine_jax.get_device_compute_capability(0) < 100:
+ return False, "TE MoE grouped GEMM requires Blackwell (sm_100+)"
+ if jax.process_count() < 4:
+ return False, "TE EP requires a multiprocess launch"
+ if jax.local_device_count() != 1:
+ return False, "TE EP requires one local GPU per process"
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ return False, str(exc)
+ return True, ""
+
+
+_te_supported, _te_reason = _te_moe_available()
+requires_te_moe = pytest.mark.skipif(not _te_supported, reason=_te_reason)
+
+
+def _small_native_state():
+ from jax.experimental import mesh_utils
+ from jax.sharding import Mesh
+ from moe import EP_AXIS, FSDP_AXIS, NativeMoEBlock
+
+ mesh = Mesh(
+ mesh_utils.create_device_mesh((2, 2), devices=jax.devices()[:4]),
+ (EP_AXIS, FSDP_AXIS),
+ )
+ model = NativeMoEBlock(
+ mesh=mesh,
+ num_experts=8,
+ num_experts_per_tok=2,
+ intermediate_size=64,
+ ep_axis=EP_AXIS,
+ data_parallelism_axes=(FSDP_AXIS,),
+ dtype=jnp.bfloat16,
+ )
+ x = jax.random.normal(jax.random.PRNGKey(1), (4, 16, 32), dtype=jnp.bfloat16)
+ dy = jax.random.normal(jax.random.PRNGKey(2), x.shape, dtype=jnp.bfloat16)
+ return mesh, model, x, dy
+
+
+@requires_4gpu
+def test_native_baseline_runs():
+ mesh, model, x, _ = _small_native_state()
+ with jax.set_mesh(mesh):
+ variables = jax.jit(model.init)(jax.random.PRNGKey(0), x)
+ out = jax.jit(model.apply)(variables, x)
+ out.block_until_ready()
+
+ assert out.shape == x.shape
+ assert out.dtype == x.dtype
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+@requires_4gpu
+def test_native_baseline_grads_are_finite():
+ mesh, model, x, dy = _small_native_state()
+
+ def loss_fn(variables, x):
+ return jnp.vdot(model.apply(variables, x), dy)
+
+ with jax.set_mesh(mesh):
+ variables = jax.jit(model.init)(jax.random.PRNGKey(0), x)
+ grads = jax.jit(jax.grad(loss_fn))(variables, x)
+ jax.block_until_ready(jax.tree_util.tree_leaves(grads))
+
+ for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
+ grad = np.asarray(grads["params"][name])
+ assert np.all(np.isfinite(grad)), f"{name} grad has NaN/Inf"
+ assert np.any(grad != 0.0), f"{name} grad is identically zero"
+
+
+@requires_4gpu
+@requires_te_moe
+def test_te_moe_matches_native_shape_and_dtype():
+ import moe
+
+ demo = moe.setup_demo(batch=4, seq=16, hidden=32, intermediate=64)
+ native_out, te_out = moe.compare_forward(demo)
+
+ assert native_out.shape == te_out.shape == demo.x.shape
+ assert native_out.dtype == te_out.dtype == demo.x.dtype
+ assert np.all(np.isfinite(np.asarray(te_out)))
+
+
+@requires_4gpu
+@requires_te_moe
+def test_benchmark_entrypoint_runs():
+ import moe
+
+ demo = moe.setup_demo(batch=4, seq=16, hidden=32, intermediate=64)
+ moe.run_benchmarks(demo, warmup_iters=1, timing_iters=1)
diff --git a/docs/examples/te_jax_integration.rst b/docs/examples/te_jax_integration.rst
index a15a10e0b3..492d1c21b5 100644
--- a/docs/examples/te_jax_integration.rst
+++ b/docs/examples/te_jax_integration.rst
@@ -24,6 +24,9 @@ Pick a topic
* - `Dense GEMMs `_
- **Available**
- ``nn.Dense`` → quantized GEMM; single-GPU speedup; multi-GPU speedup;
+ * - `Mixture-of-Experts `_
+ - **Available**
+ - Native BF16 EP MoE → experimental ``_MoEBlock``; BF16 performance;
* - `Collective GEMMs `_
- *Coming soon*
-
@@ -90,6 +93,7 @@ Conventions used across these documents
:hidden:
jax/dense
+ jax/moe
jax/collective_gemm
jax/attention
jax/expert_parallelism
diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py
new file mode 100644
index 0000000000..dae3710526
--- /dev/null
+++ b/examples/jax/ep/ep_moe.py
@@ -0,0 +1,395 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd.
+
+One process per GPU. Run via run_test_ep.sh.
+"""
+
+import argparse
+import sys
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+
+
+# ── Setup ───────────────────────────────────────────────────────────────────
+
+
+def _parse_args():
+ p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)")
+ p.add_argument("--coordinator-address", required=True)
+ p.add_argument("--process-id", type=int, required=True)
+ p.add_argument("--num-processes", type=int, required=True)
+ p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.")
+ p.add_argument("--top-k", type=int, default=2)
+ p.add_argument("--hidden", type=int, default=32)
+ p.add_argument("--hidden-out", type=int, default=32)
+ p.add_argument(
+ "--num-experts",
+ type=int,
+ default=None,
+ help="Total experts across the EP group. Default: num_processes.",
+ )
+ p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.")
+ p.add_argument(
+ "--check",
+ action="store_true",
+ default=True,
+ help="Verify fwd+bwd against a single-rank numpy reference.",
+ )
+ return p.parse_args()
+
+
+def _distributed_init(args):
+ jax.distributed.initialize(
+ coordinator_address=args.coordinator_address,
+ num_processes=args.num_processes,
+ process_id=args.process_id,
+ local_device_ids=[args.process_id],
+ )
+ assert (
+ jax.local_device_count() == 1
+ ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}"
+
+
+def _build_mesh_and_resource(args):
+ """Pick a (2, 2) mesh by default. Override via --dp-size."""
+ n = args.num_processes
+ if n < 4:
+ raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP")
+ if args.dp_size is None:
+ if n != 4:
+ raise ValueError(
+ f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override"
+ )
+ args.dp_size = 2
+ assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}"
+ args.ep_size = n // args.dp_size
+ if args.num_experts is None:
+ args.num_experts = args.num_processes
+ assert args.num_experts % args.ep_size == 0
+ args.num_local_experts = args.num_experts // args.ep_size
+ args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k
+
+ devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size)
+ mesh = Mesh(devs, ("dp", "ep"))
+ mr = MeshResource(dp_resource="dp", ep_resource="ep")
+ return mesh, mr
+
+
+def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts):
+ """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E."""
+ topk_idx = np.empty((num_tokens, top_k), dtype=np.int32)
+ for t in range(num_tokens):
+ for k in range(top_k):
+ topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts
+ return topk_idx
+
+
+def _make_inputs(args):
+ """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``.
+
+ B = num_processes (sharded across the compound (dp,ep) axis so each rank
+ holds one slot); S = args.num_tokens. Global numpy views (rank-0
+ reference) are kept 2D for the legacy reference implementation.
+ """
+ T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out
+ E = args.num_experts
+ dp_size = args.dp_size
+ ep_size = args.ep_size
+ num_procs = args.num_processes
+ dp_color = args.process_id // ep_size
+
+ rng_dp = np.random.default_rng(seed=42 + dp_color)
+ tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32)
+ topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts)
+ w_np = np.full((T, K), 1.0 / K, dtype=np.float32)
+
+ tokens_global_np = np.concatenate(
+ [
+ (
+ np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5
+ ).astype(np.float32)
+ for c in range(dp_size)
+ ],
+ axis=0,
+ )
+ topk_idx_global_np = np.concatenate(
+ [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0
+ )
+ w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32)
+
+ # Same seed on every rank → identical kernel array everywhere.
+ rng = np.random.default_rng(seed=42)
+ kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype(
+ np.float32
+ )
+
+ # Each rank contributes one [1, T, ...] slab; the global shape is
+ # [num_procs, T, ...] sharded on the first dim across (dp, ep).
+ mesh = args.mesh
+ dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None))
+ tokens = jax.make_array_from_process_local_data(
+ dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H)
+ ).astype(jnp.bfloat16)
+ topk_idx = jax.make_array_from_process_local_data(
+ dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K)
+ )
+ topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K))
+ kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16)
+ return (
+ tokens_global_np,
+ topk_idx_global_np,
+ w_global_np,
+ kernels_np,
+ tokens,
+ topk_idx,
+ topk_w,
+ kernels,
+ )
+
+
+# ── MoE step ────────────────────────────────────────────────────────────────
+
+
+def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size):
+ """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]``
+ (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``,
+ broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout
+ with ``H_out`` in place of ``H``."""
+ num_procs, recv_pr, H = recv_tokens.shape
+ H_out = kernels.shape[-1]
+ slots_per_expert = recv_pr // num_local_experts
+ # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H]
+ grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H)
+ # Contract H; batch over (ep, NLE) which are present on both sides.
+ out = jax.lax.dot_general(
+ grouped,
+ kernels.astype(grouped.dtype),
+ dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))),
+ )
+ # Output dim order from dot_general: batch dims first, then remaining lhs, rhs.
+ # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,)
+ # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out].
+ out = jnp.transpose(out, (2, 0, 1, 3, 4))
+ return out.reshape(num_procs, recv_pr, H_out)
+
+
+def _moe_step(args, topk_idx, tokens, topk_w, kernels):
+ """Jit'd MoE step: dispatch -> batched per-expert linear -> combine.
+
+ Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across
+ ``("dp","ep")``. Combine returns the same 3D shape.
+ """
+ B = args.num_processes
+ S = args.num_tokens
+ NLE = args.num_local_experts
+ dp_size, ep_size = args.dp_size, args.ep_size
+ mesh = args.mesh
+ in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...]
+ ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H]
+ ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr]
+ # Kernels are EP-replicated across dp colors; shard only the ep-rank axis.
+ kernel_spec = PartitionSpec("ep", None, None, None)
+
+ kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:])
+ ep_handle = ep_make_handle(args.top_k, dispatch_output_per_expert_alignment=16)
+
+ @jax.jit
+ def step(topk_idx, tokens, topk_w, local_kernels):
+ topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec))
+ tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec))
+ topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec))
+ local_kernels = jax.lax.with_sharding_constraint(
+ local_kernels, NamedSharding(mesh, kernel_spec)
+ )
+ recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch(
+ ep_handle, topk_idx, tokens, topk_w, args.recv_capacity_per_rank
+ )
+ recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3))
+ recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2))
+ expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size)
+ expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3))
+ return ep_combine(
+ ep_handle,
+ handle_mem,
+ _tc,
+ expert_out,
+ recv_topk_w,
+ num_local_tokens=(B, S),
+ out_sharding=(("dp", "ep"), None, None),
+ )
+
+ return step(topk_idx, tokens, topk_w, kernels)
+
+
+# ── Reference (numerical check) ─────────────────────────────────────────────
+
+
+def _reference_moe(tokens, topk_idx, topk_w, kernels):
+ """Single-rank dense MoE reference. tokens [T, H], output [T, H_out]."""
+ T, K = topk_idx.shape
+ H_out = kernels.shape[-1]
+ out = np.zeros((T, H_out), dtype=np.float32)
+ for t in range(T):
+ tok = tokens[t].astype(np.float32)
+ for k in range(K):
+ e = int(topk_idx[t, k])
+ out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32))
+ return out
+
+
+def _reference_grad(tokens, topk_idx, topk_w, kernels):
+ """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd."""
+ T, K = topk_idx.shape
+ H = tokens.shape[-1]
+ ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels)
+ grad = np.zeros((T, H), dtype=np.float32)
+ for t in range(T):
+ mixed = np.zeros_like(kernels[0])
+ for k in range(K):
+ mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])]
+ grad[t] = ref_out[t] @ mixed.T
+ return ref_out, grad
+
+
+# ── Main ────────────────────────────────────────────────────────────────────
+
+
+def main():
+ args = _parse_args()
+ _distributed_init(args)
+
+ dev = jax.local_devices()[0]
+ cap = getattr(dev, "compute_capability", None)
+ if cap is not None:
+ major, minor = (int(x) for x in str(cap).split("."))
+ if major * 10 + minor < 90:
+ print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})")
+ return
+
+ args.mesh, args.mr = _build_mesh_and_resource(args)
+
+ with args.mesh, global_shard_guard(args.mr):
+ ep_bootstrap(
+ world_size=args.num_processes,
+ rank=args.process_id,
+ ep_size=args.ep_size,
+ num_experts=args.num_experts,
+ max_tokens_per_rank=args.num_tokens,
+ recv_capacity_per_rank=args.recv_capacity_per_rank,
+ hidden_dim=args.hidden,
+ # XLA reallocates handle_mem between JIT executables.
+ allow_handle_mem_reloc=True,
+ )
+
+ (
+ tokens_global_np,
+ topk_idx_global_np,
+ w_global_np,
+ kernels_np,
+ tokens,
+ topk_idx,
+ topk_w,
+ kernels,
+ ) = _make_inputs(args)
+
+ def loss_fn(toks, idx, w, kern):
+ out = _moe_step(args, idx, toks, w, kern)
+ return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out
+
+ (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))(
+ tokens, topk_idx, topk_w, kernels
+ )
+ grad_tokens.block_until_ready()
+ out_fwd.block_until_ready()
+
+ if args.process_id == 0:
+ print(
+ f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} "
+ f"dp={args.dp_size} ep={args.ep_size} "
+ f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}"
+ )
+
+ if args.check:
+
+ def _norm(spec, ndim):
+ return tuple(spec) + (None,) * (ndim - len(spec))
+
+ # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can
+ # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both.
+ if args.dp_size > 1:
+ acceptable_specs = ((("dp", "ep"), None, None),)
+ else:
+ acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None))
+ assert (
+ _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs
+ ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})"
+ assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, (
+ f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}"
+ f" (expected one of {acceptable_specs})"
+ )
+
+ replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec())
+ out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd)
+ grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(
+ grad_tokens
+ )
+ out_global.block_until_ready()
+ grad_global.block_until_ready()
+
+ ref_out, ref_grad = _reference_grad(
+ tokens_global_np, topk_idx_global_np, w_global_np, kernels_np
+ )
+ ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum())
+ # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP
+ # column in a DP color sees identical inputs (and produces identical
+ # outputs), so collapse the ep dim to one replica before flattening
+ # to 2D against the dp-only reference.
+ dp_size, ep_size = args.dp_size, args.ep_size
+ global_out = (
+ np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32))
+ .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0]
+ .reshape(-1, ref_out.shape[-1])
+ )
+ global_grad = (
+ np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32))
+ .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0]
+ .reshape(-1, ref_grad.shape[-1])
+ )
+ if args.process_id == 0:
+ fwd_diff = np.abs(global_out - ref_out)
+ grad_diff = np.abs(global_grad - ref_grad)
+ print(
+ f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} "
+ f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)"
+ )
+ print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}")
+ print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}")
+ np.testing.assert_allclose(
+ global_out,
+ ref_out,
+ rtol=5e-2,
+ atol=5e-2,
+ err_msg=f"rank {args.process_id}: fwd mismatch",
+ )
+ np.testing.assert_allclose(
+ global_grad,
+ ref_grad,
+ rtol=5e-2,
+ atol=5e-2,
+ err_msg=f"rank {args.process_id}: bwd mismatch",
+ )
+ if args.process_id == 0:
+ print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})")
+
+
+if __name__ == "__main__":
+ main()
+ sys.exit(0)
diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh
new file mode 100755
index 0000000000..55b958f146
--- /dev/null
+++ b/examples/jax/ep/run_test_ep.sh
@@ -0,0 +1,85 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+#!/bin/bash
+
+NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
+
+if [ "${NUM_GPUS}" -lt 4 ]; then
+ echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING."
+ exit 0
+fi
+# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes.
+NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}"
+
+: ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
+
+# NCCL EP requires NVLink P2P among ranks on the node.
+echo "*** Checking NVLINK support ***"
+NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
+NVLINK_EXIT_CODE=$?
+if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \
+ || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
+ echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING"
+ exit 0
+fi
+echo "NVLINK support detected"
+
+SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py"
+export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}"
+COORD="${COORD:-127.0.0.1:12345}"
+TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}"
+
+XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+ --xla_gpu_graph_min_graph_size=1"
+export XLA_FLAGS="${XLA_BASE_FLAGS}"
+
+# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast.
+: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"}
+export NCCL_EP_JIT_CACHE_DIR
+mkdir -p "$NCCL_EP_JIT_CACHE_DIR"
+
+echo
+echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***"
+
+PIDS=()
+cleanup() {
+ for pid in "${PIDS[@]}"; do
+ kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true
+ done
+}
+trap cleanup EXIT INT TERM
+
+EXTRA_ARGS=${EXTRA_ARGS:-"--check"}
+
+for ((i=1; i "stdout_rank_${i}.txt" 2>&1 &
+ PIDS+=($!)
+done
+timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \
+ python -u "$SCRIPT" \
+ --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \
+ $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt
+wait
+
+HAS_FAILURE=0
+if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then
+ echo "... ep_moe FAILED"
+ HAS_FAILURE=1
+elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then
+ echo "... ep_moe INVALID (rank 0 produced no summary line)"
+ for ((i=1; i/dev/null
+ done
+ HAS_FAILURE=1
+else
+ echo "... ep_moe PASSED"
+fi
+rm -f stdout_rank_*.txt
+exit $HAS_FAILURE
diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh
index 8d767a4efb..7e5ce2cf0d 100755
--- a/qa/L1_cpp_distributed/test.sh
+++ b/qa/L1_cpp_distributed/test.sh
@@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm
+
+ # EP suites; runner self-skips on pre-Hopper GPUs.
+ bash ./run_test_ep.sh 4 ./build
fi
diff --git a/setup.py b/setup.py
index ec277b6349..34a3abfd99 100644
--- a/setup.py
+++ b/setup.py
@@ -83,6 +83,34 @@ def setup_common_extension() -> CMakeExtension:
cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr")
cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}")
+ # NCCL EP: on by default; auto-disabled if no arch >= 90.
+ # Set NVTE_BUILD_WITH_NCCL_EP=0/1 to force off/on.
+ nccl_ep_env = os.getenv("NVTE_BUILD_WITH_NCCL_EP")
+ explicit_nccl_ep = nccl_ep_env is not None
+ build_with_nccl_ep = bool(int(nccl_ep_env)) if explicit_nccl_ep else True
+
+ if build_with_nccl_ep:
+ arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()]
+ has_hopper_or_newer = any(t.lower() == "native" for t in arch_tokens) or any(
+ int(t.rstrip("af")) >= 90 for t in arch_tokens if t.rstrip("af").isdigit()
+ )
+ if not has_hopper_or_newer:
+ if explicit_nccl_ep:
+ raise RuntimeError(
+ "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90 in "
+ f"NVTE_CUDA_ARCHS (got '{archs}'). Add '90' or unset NVTE_BUILD_WITH_NCCL_EP."
+ )
+ print(
+ "[NCCL EP] No CUDA arch >= 90 in NVTE_CUDA_ARCHS"
+ f" ('{archs}'); auto-disabling NCCL EP (nvte_ep_* will throw at runtime)."
+ )
+ build_with_nccl_ep = False
+
+ if build_with_nccl_ep:
+ build_nccl_ep_submodule()
+ else:
+ cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF")
+
# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
@@ -128,6 +156,109 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
+def _discover_nccl_home() -> str:
+ """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig."""
+ env_home = os.environ.get("NCCL_HOME")
+ if env_home:
+ if (Path(env_home) / "include" / "nccl.h").exists():
+ return env_home
+ print(
+ f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
+ f"'{env_home}/include/nccl.h' was not found; falling back to system probes."
+ )
+
+ lib_names = ("libnccl.so", "libnccl.so.2")
+ # Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu).
+ lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu")
+ for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"):
+ p = Path(cand)
+ if (p / "include" / "nccl.h").exists() and any(
+ (p / sub / name).exists() for sub in lib_subdirs for name in lib_names
+ ):
+ return str(p)
+
+ try:
+ out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode()
+ for line in out.splitlines():
+ if "libnccl.so" in line and "=>" in line:
+ lib_path = Path(line.split("=>")[-1].strip())
+ # Walk upward so multiarch layouts (.../lib//libnccl.so)
+ # resolve to the prefix that contains include/nccl.h.
+ for root in (lib_path.parent.parent, lib_path.parent.parent.parent):
+ if (root / "include" / "nccl.h").exists():
+ return str(root)
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ pass
+
+ raise RuntimeError(
+ "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix."
+ )
+
+
+def build_nccl_ep_submodule() -> str:
+ """Build libnccl_ep.so from the 3rdparty/nccl submodule.
+
+ NCCL EP is on by default; the system NCCL core (libnccl.so) supplies the
+ headers and runtime symbols. Returns the submodule build directory.
+ """
+ nccl_root = current_file_path / "3rdparty" / "nccl"
+ if not (nccl_root / "Makefile").exists():
+ raise RuntimeError(
+ f"NCCL submodule not found at {nccl_root}. "
+ "Run `git submodule update --init --recursive`."
+ )
+
+ build_dir = nccl_root / "build"
+ nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so"
+
+ archs = cuda_archs() or "90"
+ arch_list = []
+ for a in str(archs).split(";"):
+ a = a.strip().rstrip("af")
+ if a and a.isdigit() and int(a) >= 90:
+ arch_list.append(a)
+ if not arch_list:
+ arch_list = ["90"]
+ gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list)
+
+ nproc = os.cpu_count() or 8
+ env = os.environ.copy()
+ env["NVCC_GENCODE"] = gencode
+ # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build
+ # outputs to the submodule's local build/ tree.
+ nccl_home = _discover_nccl_home()
+ env["NCCL_HOME"] = nccl_home
+ env["NCCL_EP_BUILDDIR"] = str(build_dir)
+
+ if not nccl_ep_lib.exists():
+ print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')")
+ subprocess.check_call(
+ ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"],
+ cwd=str(nccl_root),
+ env=env,
+ )
+
+ # TE's CMake expects nccl.h under 3rdparty/nccl/build/include/ for its
+ # version check. Mirror the top-level host headers from the system NCCL
+ # install — DON'T mirror nccl_device/ because the submodule ships its own
+ # newer copy at src/include/nccl_device/ with device-side templates that
+ # conflict with older system versions, and the JIT include path picks the
+ # submodule's.
+ nccl_include = build_dir / "include"
+ nccl_include.mkdir(parents=True, exist_ok=True)
+ for cand in (Path(nccl_home) / "include", Path("/usr/include")):
+ p = Path(cand)
+ if (p / "nccl.h").exists():
+ for name in ("nccl.h", "nccl_net.h", "nccl_tuner.h"):
+ src = p / name
+ dst = nccl_include / name
+ if src.exists() and not dst.exists():
+ dst.symlink_to(src)
+ break
+
+ return str(build_dir)
+
+
def git_check_submodules() -> None:
"""
Attempt to checkout git submodules automatically during setup.
diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt
index 44ad7c7384..5a68dbd69d 100644
--- a/tests/cpp_distributed/CMakeLists.txt
+++ b/tests/cpp_distributed/CMakeLists.txt
@@ -55,10 +55,31 @@ target_include_directories(test_comm_gemm PRIVATE ${test_comm_gemm_INCLUDES})
find_package(CUDAToolkit REQUIRED)
find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
+
+# ── NCCL library ──────────────────────────────────────────────────────────────
+# Search order: NCCL_HOME env → 3rdparty/nccl submodule build → system paths.
+set(NCCL_SUBMODULE_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build")
find_library(NCCL_LIB
NAMES nccl libnccl
- PATH_SUFFIXES lib
+ HINTS $ENV{NCCL_HOME}/lib ${NCCL_SUBMODULE_BUILD}/lib
+ PATH_SUFFIXES lib lib64
REQUIRED)
+
+# NCCL headers: prefer submodule build output (has the handle_init API),
+# then submodule src, then system (CUDA toolkit).
+set(NCCL_SUBMODULE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include")
+set(NCCL_SUBMODULE_SRC_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/src/include")
+if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h")
+ set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_INCLUDE}")
+elseif(EXISTS "${NCCL_SUBMODULE_SRC_INCLUDE}/nccl.h")
+ set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_SRC_INCLUDE}")
+elseif(DEFINED ENV{NCCL_HOME})
+ set(NCCL_INCLUDE_DIR "$ENV{NCCL_HOME}/include")
+endif()
+if(DEFINED NCCL_INCLUDE_DIR)
+ target_include_directories(test_comm_gemm PRIVATE ${NCCL_INCLUDE_DIR})
+endif()
+
list(APPEND test_comm_gemm_LINKER_LIBS
CUDA::cuda_driver
CUDA::cudart
@@ -74,3 +95,58 @@ target_compile_options(test_comm_gemm PRIVATE -O2 -fopenmp)
include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
+
+# ── EP distributed tests ──────────────────────────────────────────────────────
+# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h).
+# Headers + libs come from the in-tree 3rdparty/nccl submodule build.
+set(NCCL_EP_SUBMODULE_ROOT
+ "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl")
+find_library(NCCL_EP_LIB
+ NAMES nccl_ep libnccl_ep
+ HINTS ${NCCL_EP_SUBMODULE_ROOT}/build/lib
+ NO_DEFAULT_PATH
+ REQUIRED)
+
+set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include")
+if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h")
+ message(FATAL_ERROR
+ "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. "
+ "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.")
+endif()
+message(STATUS "EP test: NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}")
+
+# Collect NCCL include dirs shared by all EP test targets (nccl_ep.h + nccl.h).
+set(EP_TEST_NCCL_INCLUDES ${NCCL_EP_INCLUDE_DIR})
+if(DEFINED NCCL_INCLUDE_DIR)
+ list(APPEND EP_TEST_NCCL_INCLUDES ${NCCL_INCLUDE_DIR})
+ message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}")
+endif()
+
+set(EP_TEST_COMMON_INCLUDES
+ ${EP_TEST_NCCL_INCLUDES}
+ ${MPI_CXX_INCLUDE_PATH}
+ ../../transformer_engine/common/include
+ ../../transformer_engine/common
+ ${CMAKE_CURRENT_SOURCE_DIR})
+
+# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so
+# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc.
+set(EP_TEST_COMMON_LIBS
+ CUDA::cuda_driver
+ CUDA::cudart
+ GTest::gtest
+ ${TE_LIB}
+ CUDA::nvrtc
+ ${NCCL_LIB}
+ ${NCCL_EP_LIB}
+ MPI::MPI_CXX
+ OpenMP::OpenMP_CXX)
+
+# ── EP distributed tests (per-op + full pipeline + zero-copy symm) ───────────
+add_executable(test_ep test_ep.cu ../cpp/test_common.cu)
+target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES})
+target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS})
+
+# Do NOT use gtest_discover_tests — these binaries require multi-process
+# launch via run_test_ep.sh, not direct single-process execution.
+message(STATUS "EP distributed tests enabled: ${NCCL_EP_LIB}")
diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh
new file mode 100755
index 0000000000..13e86fa02d
--- /dev/null
+++ b/tests/cpp_distributed/run_test_ep.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+#
+# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU
+# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast.
+#
+# Usage:
+# bash run_test_ep.sh [num_gpus] [build_dir]
+#
+# Defaults:
+# num_gpus = number of GPUs visible to nvidia-smi
+# build_dir = /build
+#
+# Environment variables:
+# GTEST_FILTER — forwarded to all processes (e.g., "EPPipelineTest.*")
+# MPIRUN — override the mpirun binary (default: mpirun)
+# MPIRUN_EXTRA — extra flags forwarded to mpirun
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BUILD_DIR="${2:-${SCRIPT_DIR}/build}"
+NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}"
+MPIRUN="${MPIRUN:-mpirun}"
+
+# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90.
+MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
+ | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then
+ echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING."
+ exit 0
+fi
+
+TEST_BIN="${BUILD_DIR}/test_ep"
+if [[ ! -x "${TEST_BIN}" ]]; then
+ echo "ERROR: binary not found: ${TEST_BIN}"
+ echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make"
+ exit 1
+fi
+
+if (( NUM_GPUS < 2 )); then
+ echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping."
+ exit 0
+fi
+
+GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}"
+
+echo "=== EP Tests ==="
+echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}"
+echo
+
+"${MPIRUN}" -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} "${TEST_BIN}" ${GTEST_ARGS}
diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu
new file mode 100644
index 0000000000..bcf4ca3c98
--- /dev/null
+++ b/tests/cpp_distributed/test_ep.cu
@@ -0,0 +1,805 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*
+ * EP pipeline tests: smallest-scope first.
+ *
+ * EPDispatchTest/PrepareAndDispatch — exact recv values + per-expert counts
+ * EPCombineTest/Combine — round-trip: out == top_k * tokens
+ * EPCombineBwdTest/CombineBwdCheck — exact grad_expert values
+ * EPDispatchBwdTest/DispatchBwdCheck — exact grad_tokens
+ * EPDispatchBwdGradWeightsTest/RoundTrip — exact per-(t, k) grad_topk_weights
+ * EPPipelineTest/FullForwardBackward — fwd + bwd NaN/Inf check
+ *
+ * Routing: token t on rank r → expert (r * num_local_experts + t * top_k + k) % num_experts
+ * Token values: rank r, token t → all hidden dims = (r+1)*0.01 + t*0.001
+ *
+ * Closed-form expected values:
+ * dispatch recv: multiset of source-token values routed to this rank's experts
+ * combine: result[t] == top_k * tokens[t]
+ * combine_bwd: grad_expert[slot] == d_result[t] (no weighting)
+ * dispatch_bwd: grad_tokens[t] == top_k * d_result[t]
+ */
+
+#include "test_ep_common.h"
+
+#include
+#include
+#include
+#include
+
+// ── Deterministic routing helpers ─────────────────────────────────────────────
+
+// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is
+// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256.
+static inline float token_value(int rank, int t, int num_tokens) {
+ return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f);
+}
+
+// Per-element host-side conversion helpers used by templated test code.
+inline float tok_to_float(nv_bfloat16 v) { return __bfloat162float(v); }
+inline float tok_to_float(__half v) { return __half2float(v); }
+inline float tok_to_float(float v) { return v; }
+
+template T tok_from_float(float v);
+template <> inline nv_bfloat16 tok_from_float(float v) { return __float2bfloat16(v); }
+template <> inline __half tok_from_float<__half> (float v) { return __float2half(v); }
+template <> inline float tok_from_float (float v) { return v; }
+
+template
+static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) {
+ std::vector v(num_tokens * hidden_dim);
+ for (int t = 0; t < num_tokens; ++t) {
+ T val = tok_from_float(token_value(rank, t, num_tokens));
+ for (int h = 0; h < hidden_dim; ++h)
+ v[t * hidden_dim + h] = val;
+ }
+ return v;
+}
+
+static std::vector expected_token_counts(
+ int recv_rank, int num_processes, int num_tokens, int top_k,
+ int num_experts, int num_local_experts) {
+ int base = recv_rank * num_local_experts;
+ std::vector cnt(num_local_experts, 0);
+ for (int src = 0; src < num_processes; ++src) {
+ auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts);
+ for (int t = 0; t < num_tokens; ++t)
+ for (int k = 0; k < top_k; ++k) {
+ int64_t e = idx[t * top_k + k];
+ if (e >= base && e < base + num_local_experts) ++cnt[e - base];
+ }
+ }
+ return cnt;
+}
+
+static std::vector expected_recv_values_sorted(
+ int recv_rank, int num_processes, int num_tokens, int top_k,
+ int num_experts, int num_local_experts) {
+ int base = recv_rank * num_local_experts;
+ std::vector vals;
+ for (int src = 0; src < num_processes; ++src) {
+ auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts);
+ for (int t = 0; t < num_tokens; ++t)
+ for (int k = 0; k < top_k; ++k) {
+ int64_t e = idx[t * top_k + k];
+ if (e >= base && e < base + num_local_experts) {
+ float raw = token_value(src, t, num_tokens);
+ vals.push_back(__bfloat162float(__float2bfloat16(raw)));
+ }
+ }
+ }
+ std::sort(vals.begin(), vals.end());
+ return vals;
+}
+
+// 2^-5 relative tolerance for BF16 (matches mantissa precision with margin),
+// plus a small atol floor for near-zero expected values.
+static constexpr float kBf16Rtol = 1.0f / 32.0f;
+static constexpr float kBf16Atol = 1e-3f;
+static float bf16_tol(float magnitude) {
+ return kBf16Atol + kBf16Rtol * std::fabs(magnitude);
+}
+
+template
+static bool check_no_nan_inf(const T* dev, int count, const char* name) {
+ std::vector h(count);
+ cudaMemcpy(h.data(), dev, count * sizeof(T), cudaMemcpyDeviceToHost);
+ for (int i = 0; i < count; ++i) {
+ float v = tok_to_float(h[i]);
+ if (std::isnan(v) || std::isinf(v)) {
+ fprintf(stderr, "Rank %d: %s in %s[%d]\n",
+ g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i);
+ return false;
+ }
+ }
+ return true;
+}
+
+// ── Forward buffer set with RAII ──────────────────────────────────────────────
+
+template
+struct EPBuffers {
+ // Forward
+ DevBuf topk_idx;
+ DevBuf topk_weights;
+ DevBuf tokens;
+ DevBuf token_counts;
+ DevBuf handle_mem;
+ DevBuf recv_tokens;
+ DevBuf recv_topk_weights;
+ DevBuf result;
+ // Backward
+ DevBuf grad_result;
+ DevBuf grad_expert;
+ DevBuf grad_tokens;
+ DevBuf g_recv_topk_weights;
+ DevBuf grad_topk_weights;
+
+ uint64_t handle_id = 0;
+ size_t handle_mem_size = 0;
+ size_t recv_capacity = 0;
+ int top_k_ = 0;
+
+ void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts,
+ int ep_size, int max_tokens_per_rank, size_t alignment = 0) {
+ top_k_ = top_k;
+ recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2;
+
+ topk_idx.alloc(num_tokens * top_k);
+ topk_weights.alloc(num_tokens * top_k);
+ tokens.alloc(num_tokens * hidden_dim);
+ token_counts.alloc(num_local_experts);
+ recv_tokens.alloc(recv_capacity * hidden_dim);
+ recv_topk_weights.alloc(recv_capacity);
+ result.alloc(num_tokens * hidden_dim);
+
+ NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment};
+ handle_id = nvte_ep_register_layer(cfg, &handle_mem_size);
+ handle_mem.alloc(handle_mem_size);
+
+ grad_result.alloc(num_tokens * hidden_dim);
+ grad_expert.alloc(recv_capacity * hidden_dim);
+ grad_tokens.alloc(num_tokens * hidden_dim);
+ g_recv_topk_weights.alloc(recv_capacity);
+ grad_topk_weights.alloc(num_tokens * top_k);
+ }
+};
+
+// Bundled NVTETensor views over an EPBuffers, with the shapes the EP C API
+// expects.
+template
+struct EPTensors {
+ TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens;
+ TensorWrapper recv_tokens, recv_topk_weights, result;
+ TensorWrapper grad_result, grad_expert, grad_tokens;
+ TensorWrapper g_recv_topk_weights, grad_topk_weights;
+
+ EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim,
+ int num_local_experts) {
+ constexpr DType kTokDType = test::TypeInfo::dtype;
+ using Shape = std::vector;
+ topk_idx = TensorWrapper(b.topk_idx.get(),
+ Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64);
+ topk_weights = TensorWrapper(b.topk_weights.get(),
+ Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32);
+ token_counts = TensorWrapper(b.token_counts.get(),
+ Shape{(size_t)num_local_experts}, DType::kInt32);
+ handle_mem = TensorWrapper(b.handle_mem.get(),
+ Shape{b.handle_mem_size}, DType::kByte);
+ tokens = TensorWrapper(b.tokens.get(),
+ Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType);
+ recv_tokens = TensorWrapper(b.recv_tokens.get(),
+ Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType);
+ recv_topk_weights = TensorWrapper(b.recv_topk_weights.get(),
+ Shape{b.recv_capacity}, DType::kFloat32);
+ result = TensorWrapper(b.result.get(),
+ Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType);
+ grad_result = TensorWrapper(b.grad_result.get(),
+ Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType);
+ grad_expert = TensorWrapper(b.grad_expert.get(),
+ Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType);
+ grad_tokens = TensorWrapper(b.grad_tokens.get(),
+ Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType);
+ g_recv_topk_weights = TensorWrapper(b.g_recv_topk_weights.get(),
+ Shape{b.recv_capacity}, DType::kFloat32);
+ grad_topk_weights = TensorWrapper(b.grad_topk_weights.get(),
+ Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32);
+ }
+};
+
+// ── Shared fixture base ───────────────────────────────────────────────────────
+
+class EpOpTestBase : public ::testing::Test {
+ protected:
+ int ep_size_, num_experts_, num_local_experts_, hidden_dim_;
+ int max_tokens_per_rank_, top_k_, num_tokens_;
+
+ void SetUp() override {
+ if (g_sm_major < 9)
+ GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)";
+ ASSERT_GE(g_num_processes, 2);
+ ASSERT_TRUE(g_ep_initialized);
+
+ ep_size_ = g_ep_size;
+ num_experts_ = g_num_experts;
+ num_local_experts_ = num_experts_ / ep_size_;
+ hidden_dim_ = g_hidden_dim;
+ max_tokens_per_rank_ = g_max_tokens_per_rank;
+ top_k_ = 2;
+ num_tokens_ = 32;
+ }
+
+ template
+ void upload_inputs(EPBuffers& buf, int rank = -1) {
+ if (rank < 0) rank = g_process_id;
+ auto h_idx = routing_balanced(rank, num_tokens_, top_k_,
+ num_experts_, num_local_experts_);
+ std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_);
+ auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_);
+
+ NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(),
+ h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
+ NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(),
+ h_w.size() * sizeof(float), cudaMemcpyHostToDevice));
+ NVTE_CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(),
+ h_tok.size() * sizeof(T), cudaMemcpyHostToDevice));
+ }
+
+ NVTEEpLayerConfig layer_config(size_t alignment = 0) const {
+ return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment};
+ }
+
+ // NVTE_CHECK_CUDA (fprintf+exit) so this non-void helper stays legal.
+ template
+ int read_total_recv(const EPBuffers& buf) const {
+ std::vector cnt(num_local_experts_);
+ NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(),
+ num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
+ int total = 0;
+ for (int c : cnt) total += c;
+ return total;
+ }
+};
+
+// =============================================================================
+// EPDispatchTest: exact recv values and per-expert counts.
+// =============================================================================
+
+class EPDispatchTest : public EpOpTestBase {};
+
+TEST_F(EPDispatchTest, PrepareAndDispatch) {
+ EPBuffers<> buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ NVTE_CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes()));
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ // 1. Per-expert counts.
+ std::vector got_counts(num_local_experts_);
+ NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(),
+ num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
+ auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_,
+ num_experts_, num_local_experts_);
+ int total_recv = 0;
+ for (int i = 0; i < num_local_experts_; ++i) {
+ EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i;
+ total_recv += exp_counts[i];
+ }
+ ASSERT_LE(total_recv, static_cast(buf.recv_capacity))
+ << "total_recv exceeded recv_capacity — overflow would corrupt downstream memory";
+
+ // 2. Recv values: read only the filled prefix per local-expert zone, not the
+ // whole recv buffer — avoids false positives from legitimate-zero token values.
+ std::vector h_recv(buf.recv_capacity * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(),
+ h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+
+ std::vector got_vals;
+ got_vals.reserve(total_recv);
+ size_t slot = 0;
+ for (int e = 0; e < num_local_experts_; ++e) {
+ for (int i = 0; i < got_counts[e]; ++i) {
+ got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_]));
+ ++slot;
+ }
+ }
+ std::sort(got_vals.begin(), got_vals.end());
+
+ auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_,
+ top_k_, num_experts_, num_local_experts_);
+
+ ASSERT_EQ(got_vals.size(), exp_vals.size());
+ for (size_t i = 0; i < exp_vals.size(); ++i)
+ EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i]))
+ << "recv value mismatch at sorted index " << i;
+
+ // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k).
+ std::vector h_w(buf.recv_capacity);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(),
+ h_w.size() * sizeof(float), cudaMemcpyDeviceToHost));
+ const float exp_w = 1.0f / static_cast(top_k_);
+ for (int i = 0; i < total_recv; ++i)
+ EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]";
+
+ if (g_process_id == 0)
+ printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv);
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+// =============================================================================
+// EPCombineTest: round-trip identity expert → result == top_k * tokens.
+// =============================================================================
+
+class EPCombineTest : public EpOpTestBase {};
+
+TEST_F(EPCombineTest, Combine) {
+ EPBuffers<> buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{},
+ t.result.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ std::vector h_result(num_tokens_ * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(),
+ h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+ auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_);
+ for (int tok = 0; tok < num_tokens_; ++tok) {
+ float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_);
+ for (int p = 0; p < hidden_dim_; ++p) {
+ float got = __bfloat162float(h_result[tok * hidden_dim_ + p]);
+ EXPECT_NEAR(got, exp, bf16_tol(exp))
+ << "token " << tok << " rank " << g_process_id << " hidden " << p;
+ }
+ }
+
+ if (g_process_id == 0)
+ printf(" Combine: passed (result == top_k * tokens)\n");
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+// =============================================================================
+// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted).
+// =============================================================================
+
+class EPCombineBwdTest : public EpOpTestBase {};
+
+TEST_F(EPCombineBwdTest, CombineBwdCheck) {
+ EPBuffers<> buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{},
+ t.result.data(), stream));
+
+ std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f));
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(),
+ h_grad_r.size() * sizeof(nv_bfloat16),
+ cudaMemcpyHostToDevice, stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream));
+
+ ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{},
+ t.grad_expert.data(), NVTECommWindow{}, stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ int total_recv = read_total_recv(buf);
+
+ std::vector cnt(num_local_experts_);
+ NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(),
+ num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
+ std::vector h_ge(buf.recv_capacity * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(),
+ h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+
+ // Walk filled slots by per-expert zone (no v != 0 heuristic).
+ const float kExpGrad = 0.1f;
+ size_t slot = 0;
+ int filled = 0;
+ for (int e = 0; e < num_local_experts_; ++e) {
+ for (int i = 0; i < cnt[e]; ++i) {
+ for (int p = 0; p < hidden_dim_; ++p) {
+ float v = __bfloat162float(h_ge[slot * hidden_dim_ + p]);
+ EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad))
+ << "grad_expert expert " << e << " slot " << i
+ << " (linear " << slot << ") hidden " << p;
+ }
+ ++filled; ++slot;
+ }
+ }
+ EXPECT_EQ(filled, total_recv);
+
+ if (g_process_id == 0)
+ printf(" CombineBwdCheck: passed (filled=%d)\n", filled);
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+// =============================================================================
+// EPDispatchBwdTest: grad_tokens == top_k * d_result.
+// =============================================================================
+
+class EPDispatchBwdTest : public EpOpTestBase {};
+
+TEST_F(EPDispatchBwdTest, DispatchBwdCheck) {
+ EPBuffers<> buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{},
+ t.result.data(), stream));
+
+ std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f));
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(),
+ h_grad.size() * sizeof(nv_bfloat16),
+ cudaMemcpyHostToDevice, stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream));
+
+ ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{},
+ t.grad_expert.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{},
+ t.g_recv_topk_weights.data(), NVTECommWindow{},
+ t.grad_tokens.data(), t.grad_topk_weights.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ std::vector h_gt(num_tokens_ * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(),
+ h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+ const float kExpGrad = static_cast(top_k_) * 0.1f;
+ for (int tok = 0; tok < num_tokens_; ++tok)
+ for (int p = 0; p < hidden_dim_; ++p)
+ EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_ + p]), kExpGrad,
+ bf16_tol(kExpGrad))
+ << "grad_tokens token " << tok << " hidden " << p;
+
+ if (g_process_id == 0)
+ printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad);
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+// =============================================================================
+// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights.
+// =============================================================================
+
+class EPDispatchBwdGradWeightsTest : public EpOpTestBase {};
+
+TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) {
+ EPBuffers<> buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ // Distinct per-(rank, t, k) weights so each slot carries a unique value.
+ std::vector h_w(num_tokens_ * top_k_);
+ for (int tok = 0; tok < num_tokens_; ++tok)
+ for (int k = 0; k < top_k_; ++k)
+ h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k +
+ 0.0001f * (g_process_id + 1);
+ NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(),
+ h_w.size() * sizeof(float), cudaMemcpyHostToDevice));
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0,
+ buf.recv_topk_weights.bytes(), stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+
+ // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible.
+ std::vector h_nan(num_tokens_ * top_k_,
+ std::numeric_limits::quiet_NaN());
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(),
+ h_nan.size() * sizeof(float),
+ cudaMemcpyHostToDevice, stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream));
+
+ // g_recv_topk_weights := recv_topk_weights (the round-trip input).
+ auto g_recv_t = TensorWrapper(buf.recv_topk_weights.get(),
+ std::vector{buf.recv_capacity}, DType::kFloat32);
+ ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(),
+ NVTECommWindow{}, g_recv_t.data(), NVTECommWindow{},
+ t.grad_tokens.data(), t.grad_topk_weights.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ std::vector h_grad_w(num_tokens_ * top_k_);
+ NVTE_CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(),
+ h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost));
+
+ const float kTol = 1e-5f;
+ int errs = 0, k0_eq_k1 = 0;
+ for (int tok = 0; tok < num_tokens_; ++tok) {
+ for (int k = 0; k < top_k_; ++k) {
+ float got = h_grad_w[tok * top_k_ + k];
+ float exp = h_w[tok * top_k_ + k];
+ if (std::isnan(got) || std::fabs(got - exp) > kTol) {
+ if (errs < 8)
+ fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n",
+ g_process_id, tok, k, got, exp);
+ ++errs;
+ }
+ }
+ if (top_k_ >= 2 &&
+ std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f)
+ ++k0_eq_k1;
+ }
+ EXPECT_EQ(errs, 0);
+ EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]";
+
+ if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0)
+ printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_);
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+// =============================================================================
+// Integrated FwdBwd: NaN/Inf check end-to-end.
+// =============================================================================
+
+class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface {
+ protected:
+ template
+ void run_full_forward_backward() {
+ EPBuffers buf;
+ buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(buf);
+ EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t handle_id = buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(),
+ t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
+ NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
+ t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{},
+ t.result.data(), stream));
+
+ std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f));
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(),
+ h_grad.size() * sizeof(Tok),
+ cudaMemcpyHostToDevice, stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream));
+ NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream));
+
+ ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{},
+ t.grad_expert.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{},
+ t.g_recv_topk_weights.data(), NVTECommWindow{},
+ t.grad_tokens.data(), t.grad_topk_weights.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result"));
+ ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens"));
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+ }
+};
+
+TEST_P(EPPipelineTest, FullForwardBackward) {
+ const DType dtype = GetParam();
+ // NCCL EP backend currently asserts ncclBfloat16 in ncclEpDispatch
+ // (contrib/nccl_ep/nccl_ep.cc); skip FP16/FP32 until the backend supports them.
+ if (dtype != DType::kBFloat16) {
+ GTEST_SKIP() << test::typeName(dtype) << " not yet supported by NCCL EP backend";
+ }
+ switch (dtype) {
+ case DType::kBFloat16: run_full_forward_backward(); break;
+ case DType::kFloat16: run_full_forward_backward<__half> (); break;
+ case DType::kFloat32: run_full_forward_backward (); break;
+ default: FAIL() << "unsupported token dtype " << static_cast(dtype);
+ }
+ if (g_process_id == 0)
+ printf(" FullForwardBackward[%s]: passed\n", test::typeName(dtype).c_str());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ Dtypes, EPPipelineTest,
+ ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32),
+ [](const ::testing::TestParamInfo& info) {
+ return test::typeName(info.param);
+ });
+
+// =============================================================================
+// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached
+// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow).
+// Symm-mem requirements per spec: input&output of Dispatch, input of Combine,
+// input&output of Combine bwd, input of Dispatch bwd.
+// =============================================================================
+
+namespace {
+
+// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window.
+// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only.
+struct SymmBuf {
+ void* ptr = nullptr;
+ size_t bytes = 0;
+ ncclWindow_t win = nullptr;
+
+ SymmBuf() = default;
+ SymmBuf(const SymmBuf&) = delete;
+ SymmBuf& operator=(const SymmBuf&) = delete;
+ SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) {
+ o.ptr = nullptr; o.win = nullptr; o.bytes = 0;
+ }
+ ~SymmBuf() {
+ if (win) ncclCommWindowDeregister(g_ep_comm, win);
+ if (ptr) ncclMemFree(ptr);
+ }
+
+ void alloc(size_t n_bytes) {
+ bytes = n_bytes;
+ NVTE_CHECK_NCCL(ncclMemAlloc(&ptr, bytes));
+ NVTE_CHECK_CUDA(cudaMemset(ptr, 0, bytes));
+ NVTE_CHECK_NCCL(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win,
+ NCCL_WIN_COLL_SYMMETRIC));
+ }
+};
+
+// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0).
+static inline NVTECommWindow symm_window(const SymmBuf& b) {
+ return NVTECommWindow{b.win, /*offset=*/0};
+}
+
+} // namespace
+
+class EPZeroCopyTest : public EpOpTestBase {};
+
+// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact
+// vs HBM reference (same routing, same input).
+TEST_F(EPZeroCopyTest, IdentityAllSymm) {
+ // HBM reference run.
+ EPBuffers<> ref_buf;
+ ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(ref_buf);
+ EPTensors<> ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+
+ cudaStream_t stream;
+ NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
+
+ uint64_t ref_hid = ref_buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(), ref_t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(),
+ ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(),
+ NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{},
+ ref_t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.recv_tokens.data(), NVTECommWindow{},
+ ref_t.result.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_);
+ std::vector ref_result(num_tokens_ * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(),
+ ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+ NVTE_CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(),
+ ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+
+ // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm.
+ EPBuffers<> sym_buf; // alloc all buffers except the symm ones.
+ sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_,
+ ep_size_, max_tokens_per_rank_);
+ upload_inputs(sym_buf);
+
+ SymmBuf sym_tokens, sym_recv;
+ sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16));
+ sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16));
+
+ // Stage same tokens into the symm-mem input.
+ auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(),
+ h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice));
+
+ EPTensors<> sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_);
+ // Replace the tokens/recv_tokens views with ones pointing at the symm buffers.
+ sym_t.tokens = TensorWrapper(sym_tokens.ptr,
+ std::vector{(size_t)num_tokens_, (size_t)hidden_dim_}, DType::kBFloat16);
+ sym_t.recv_tokens = TensorWrapper(sym_recv.ptr,
+ std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, DType::kBFloat16);
+
+ uint64_t sym_hid = sym_buf.handle_id;
+ ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(), sym_t.token_counts.data(), /*alignment=*/0, stream));
+ ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(),
+ sym_t.tokens.data(), symm_window(sym_tokens),
+ sym_t.topk_weights.data(), NVTECommWindow{},
+ sym_t.recv_tokens.data(), symm_window(sym_recv),
+ sym_t.recv_topk_weights.data(), NVTECommWindow{}, stream));
+ ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.recv_tokens.data(),
+ symm_window(sym_recv), sym_t.result.data(), stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+ std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_);
+ std::vector sym_result(num_tokens_ * hidden_dim_);
+ NVTE_CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr,
+ sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+ NVTE_CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(),
+ sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost));
+
+ // Compare per filled recv slot (HBM ref vs symm) and full result.
+ int total_recv = read_total_recv(sym_buf);
+ for (int i = 0; i < total_recv * hidden_dim_; ++i)
+ ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i]))
+ << "recv mismatch at " << i;
+ for (size_t i = 0; i < sym_result.size(); ++i)
+ ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i]))
+ << "result mismatch at " << i;
+
+ if (g_process_id == 0)
+ printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv);
+
+ NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
+}
+
+
+// ── main ──────────────────────────────────────────────────────────────────────
+
+int main(int argc, char* argv[]) {
+ if (!ep_bootstrap(argc, argv)) return 0;
+ int ret = RUN_ALL_TESTS();
+ ep_teardown();
+ return ret;
+}
diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h
new file mode 100644
index 0000000000..135a39416e
--- /dev/null
+++ b/tests/cpp_distributed/test_ep_common.h
@@ -0,0 +1,184 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*
+ * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in
+ * each test binary's main() populates process-level globals.
+ * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include "../cpp/test_common.h"
+#include "util/logging.h"
+
+using transformer_engine::DType;
+using transformer_engine::TensorWrapper;
+
+#define CHECK_MPI(expr) \
+ do { \
+ int _err_mpi = (expr); \
+ NVTE_CHECK(_err_mpi == MPI_SUCCESS, "MPI error: ", _err_mpi); \
+ } while (false)
+
+// ── Process-level state ───────────────────────────────────────────────────────
+
+static int g_process_id = -1;
+static int g_num_processes = -1;
+
+static int g_sm_major = -1; // set by ep_bootstrap; -1 until then
+static int g_ep_size = -1;
+static int g_num_experts = -1;
+static int g_hidden_dim = 256;
+static int g_max_tokens_per_rank = 64;
+static NVTEDType g_max_token_dtype = kNVTEFloat32; // staging-buffer sizing
+static bool g_ep_initialized = false;
+static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown
+
+// RAII owner for a cudaMalloc'd device buffer; element-count API on top of
+// test::CudaPtr.
+template
+struct DevBuf {
+ test::CudaPtr ptr;
+ size_t count = 0;
+
+ DevBuf() = default;
+ explicit DevBuf(size_t n) { alloc(n); }
+
+ void alloc(size_t n) {
+ count = n;
+ ptr = (n > 0) ? test::cuda_alloc(n * sizeof(T)) : test::CudaPtr{};
+ }
+ void reset() {
+ ptr.reset();
+ count = 0;
+ }
+
+ T* get() const { return ptr.get(); }
+ size_t bytes() const { return count * sizeof(T); }
+};
+
+// ── Shared routing helper ─────────────────────────────────────────────────────
+
+// Balanced round-robin routing: token t on rank r maps top_k experts to
+// (r * num_local_experts + t * top_k + k) % num_experts
+static inline std::vector routing_balanced(
+ int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) {
+ std::vector idx(num_tokens * top_k);
+ for (int t = 0; t < num_tokens; ++t)
+ for (int k = 0; k < top_k; ++k)
+ idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts;
+ return idx;
+}
+
+// ── ncclUniqueId exchange via MPI ─────────────────────────────────────────────
+
+static void exchange_unique_id(ncclUniqueId* uid) {
+ if (g_process_id == 0) NVTE_CHECK_NCCL(ncclGetUniqueId(uid));
+ CHECK_MPI(MPI_Bcast(uid, sizeof(*uid), MPI_BYTE, 0, MPI_COMM_WORLD));
+}
+
+// ── CLI parsing ───────────────────────────────────────────────────────────────
+
+static void ep_parse_args(int argc, char* argv[]) {
+ for (int i = 1; i < argc; ++i) {
+ std::string a(argv[i]);
+ if (a.rfind("--max-token-dtype=", 0) == 0)
+ g_max_token_dtype = static_cast(std::stoi(a.substr(18)));
+ }
+}
+
+// ── Bootstrap / teardown ──────────────────────────────────────────────────────
+
+// Returns false if the binary should exit without running tests (wrong SM, etc.).
+static bool ep_bootstrap(int argc, char* argv[]) {
+ int mpi_initialized = 0;
+ MPI_Initialized(&mpi_initialized);
+ if (!mpi_initialized) CHECK_MPI(MPI_Init(&argc, &argv));
+ CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &g_process_id));
+ CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &g_num_processes));
+
+ ep_parse_args(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+
+ int device_count;
+ cudaGetDeviceCount(&device_count);
+ cudaSetDevice(g_process_id % device_count);
+
+ int device, major;
+ cudaGetDevice(&device);
+ cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
+ g_sm_major = major;
+ if (major < 9) {
+ if (g_process_id == 0)
+ printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major);
+ return false;
+ }
+ if (g_num_processes < 2) {
+ if (g_process_id == 0)
+ printf("SKIP: at least 2 processes required\n");
+ return false;
+ }
+
+ g_ep_size = g_num_processes;
+ g_num_experts = g_ep_size * 4; // 4 experts per rank
+
+ ncclUniqueId uid{};
+ exchange_unique_id(&uid);
+
+ NVTEEpGroupConfig group_config{};
+ group_config.ep_size = g_ep_size;
+ group_config.num_experts = g_num_experts;
+ group_config.max_tokens_per_rank = g_max_tokens_per_rank;
+ // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2.
+ group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2;
+ group_config.hidden_dim = g_hidden_dim;
+ group_config.max_token_dtype = g_max_token_dtype;
+
+ NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id));
+ nvte_ep_initialize(static_cast(g_ep_comm), group_config);
+
+ if (g_process_id == 0) {
+ printf("EP initialized: ep_size=%d num_experts=%d "
+ "hidden_dim=%d max_tokens_per_rank=%d\n",
+ g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank);
+ }
+
+ g_ep_initialized = true;
+ return true;
+}
+
+// Tear down in dependency order: backend's ep_group reads from ep_comm,
+// so destroy the group first, then the comm.
+static void ep_teardown() {
+ if (g_ep_initialized) {
+ nvte_ep_shutdown();
+ if (g_ep_comm != nullptr) {
+ ncclCommDestroy(g_ep_comm);
+ g_ep_comm = nullptr;
+ }
+ g_ep_initialized = false;
+ }
+ int finalized = 0;
+ MPI_Finalized(&finalized);
+ if (!finalized) MPI_Finalize();
+}
diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh
new file mode 100755
index 0000000000..a37ffc2952
--- /dev/null
+++ b/tests/jax/multi_process_launch_ep.sh
@@ -0,0 +1,67 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+#!/bin/bash
+
+SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}"
+TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}"
+
+
+XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+ --xla_gpu_graph_min_graph_size=1"
+
+export XLA_FLAGS="${XLA_BASE_FLAGS}"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"
+
+NUM_RUNS=$(nvidia-smi -L | wc -l)
+
+if [ "${NUM_RUNS}" -lt 4 ]; then
+ echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING."
+ exit 0
+fi
+# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes.
+NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}"
+
+OVERALL_RET=0
+
+for SCRIPT_NAME in $SCRIPT_NAMES; do
+ echo "=== Running ${SCRIPT_NAME} ==="
+ for ((i=1; i stdout_rank_${i}.txt 2>&1 &
+ done
+
+ timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \
+ python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt
+
+ wait
+
+ RET=0
+ if grep -q "FAILED" stdout_multi_process.txt; then
+ RET=1
+ fi
+ # Treat missing test summary on rank 0 as hang/crash rather than silent success.
+ if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then
+ echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash."
+ echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output."
+ RET=1
+ fi
+ if [ "$RET" -ne 0 ]; then
+ for ((i=1; i/dev/null || echo "(no log)"
+ done
+ fi
+
+ rm -f stdout_multi_process.txt stdout_rank_*.txt
+ if [ "$RET" -ne 0 ]; then
+ OVERALL_RET=1
+ fi
+done
+
+exit "$OVERALL_RET"
diff --git a/tests/jax/run_te_ep_moe.sh b/tests/jax/run_te_ep_moe.sh
new file mode 100755
index 0000000000..32d5f21956
--- /dev/null
+++ b/tests/jax/run_te_ep_moe.sh
@@ -0,0 +1,122 @@
+#!/usr/bin/env bash
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+#
+# Multiprocess (one-GPU-per-process) launcher for the TE-EP MoE custom_vjp
+# test suite. Forks one pytest invocation per visible GPU, passing each
+# its own --num-process=N --process-id=i, and waits for all of them. Each
+# child calls jax.distributed.initialize(..., local_device_ids=process_id)
+# so each Python process only sees its one GPU as a local device and the
+# participating processes form a global (ep, fsdp) mesh.
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+TEST_FILE="$TE_ROOT/tests/jax/test_te_ep_moe.py"
+PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini"
+
+NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}"
+if [ "$NUM_GPUS" -lt 4 ]; then
+ echo "[run_te_ep_moe.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2
+ exit 1
+fi
+
+export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}"
+export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}"
+export TE_EP_MOE_COORDINATOR_ADDRESS="${TE_EP_MOE_COORDINATOR_ADDRESS:-127.0.0.1:13457}"
+
+echo "============================================================"
+echo "TE-EP MoE MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)"
+echo " test file : $TEST_FILE"
+echo " coordinator : $TE_EP_MOE_COORDINATOR_ADDRESS"
+echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE"
+echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION"
+echo "============================================================"
+
+if [ -n "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then
+ LOG_DIR="$TE_EP_MOE_MP_LOG_DIR"
+ mkdir -p "$LOG_DIR"
+else
+ LOG_DIR=$(mktemp -d -t te_ep_moe_mp_XXXXXX)
+fi
+echo "Per-process logs: $LOG_DIR"
+
+PIDS=()
+
+cleanup() {
+ for pid in "${PIDS[@]:-}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ kill -TERM "$pid" 2>/dev/null || true
+ fi
+ done
+ sleep 1
+ for pid in "${PIDS[@]:-}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ kill -KILL "$pid" 2>/dev/null || true
+ fi
+ done
+}
+trap cleanup EXIT INT TERM
+
+for i in $(seq 0 $((NUM_GPUS - 1))); do
+ LOG_FILE="$LOG_DIR/proc_${i}.log"
+ PYTEST_CMD=(
+ python3 -m pytest -c "$PYTEST_INI"
+ "$TEST_FILE"
+ -p no:typeguard
+ -v -s
+ --num-process="$NUM_GPUS"
+ --process-id="$i"
+ )
+ if [ "$i" -eq 0 ]; then
+ echo "=== Live output from process 0 ==="
+ "${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" &
+ else
+ "${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 &
+ fi
+ PIDS+=("$!")
+done
+
+EXITS=()
+for pid in "${PIDS[@]}"; do
+ if wait "$pid"; then
+ EXITS+=("0")
+ else
+ EXITS+=("$?")
+ fi
+done
+
+echo
+echo "============================================================"
+echo "Per-process exit codes:"
+for i in "${!EXITS[@]}"; do
+ echo " proc $i -> ${EXITS[$i]}"
+done
+
+# Treat exit 0 (pass) and exit 5 (pytest "no tests collected", which the
+# file emits via pytest.skip(allow_module_level=True) on pre-Blackwell
+# GPUs) as success.
+FAILED=0
+for e in "${EXITS[@]}"; do
+ if [ "$e" != "0" ] && [ "$e" != "5" ]; then
+ FAILED=1
+ break
+ fi
+done
+
+echo
+if [ "$FAILED" -eq 0 ]; then
+ echo "[run_te_ep_moe.sh] all processes PASSED"
+ if [ -z "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then
+ rm -rf "$LOG_DIR"
+ fi
+ exit 0
+fi
+
+echo "[run_te_ep_moe.sh] at least one process FAILED"
+echo " retaining logs at $LOG_DIR for diagnosis"
+echo " process 0 tail:"
+tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true
+exit 1
diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py
new file mode 100644
index 0000000000..1472ab49fe
--- /dev/null
+++ b/tests/jax/test_multi_process_ep.py
@@ -0,0 +1,761 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives.
+
+Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``.
+Coverage:
+
+ - ``ep_bootstrap`` rejects when ``ep_resource`` is unset.
+ - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``)
+ round-trip an identity expert → output ≈ tokens.
+ - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form).
+ - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form).
+ - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under
+ skewed upstream gradients (no k-axis averaging).
+ - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI.
+
+Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU).
+"""
+
+import os
+import sys
+import unittest
+
+import jax
+import jax.experimental.multihost_utils as jmu
+import jax.numpy as jnp
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine
+from transformer_engine.jax.cpp_extensions.ep import (
+ ep_prepare,
+ ep_dispatch_fwd,
+ ep_combine_fwd,
+)
+
+
+# ── Test config ─────────────────────────────────────────────────────────────
+# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in
+# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even.
+
+NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP
+HIDDEN_DIM = 32
+TOP_K = 2
+TOKENS_PER_DP_SHARD = 4 # per device along dp
+
+
+def _factor_dp_ep(num_procs):
+ """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``.
+
+ NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment.
+ """
+ override = os.environ.get("NVTE_TEST_EP_MESH")
+ if override:
+ dp_str, ep_str = override.lower().split("x")
+ dp, ep = int(dp_str), int(ep_str)
+ if dp * ep != num_procs:
+ raise ValueError(
+ f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}"
+ )
+ if (NUM_LOCAL_EXPERTS * ep) % 4 != 0:
+ raise ValueError(
+ f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 "
+ "for NCCL EP TMA alignment"
+ )
+ return dp, ep
+ if num_procs != 4:
+ raise ValueError(
+ f"default mesh expects exactly 4 ranks (got {num_procs}); set "
+ "NVTE_TEST_EP_MESH=DPxEP to override"
+ )
+ return 2, 2
+
+
+def _build_mesh(dp, ep):
+ devs = np.asarray(jax.devices()).reshape(dp, ep)
+ return Mesh(devs, ("dp", "ep"))
+
+
+def _local_device_sm():
+ """Return SM major*10+minor of the first local CUDA device, or None."""
+ try:
+ dev = jax.local_devices()[0]
+ cap = getattr(dev, "compute_capability", None)
+ if cap is None:
+ return None
+ major, minor = (int(x) for x in str(cap).split("."))
+ return major * 10 + minor
+ except Exception:
+ return None
+
+
+class TestEP(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ sm = _local_device_sm()
+ if sm is not None and sm < 90:
+ raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})")
+ cls.num_procs = jax.process_count()
+ cls.rank = jax.process_index()
+ cls.dp, cls.ep = _factor_dp_ep(cls.num_procs)
+ cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep
+ # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color).
+ # Under PartitionSpec(("dp","ep"), None) each EP group sees
+ # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew.
+ T_per_ep_group = TOKENS_PER_DP_SHARD
+ active_experts = min(cls.num_experts, T_per_ep_group * TOP_K)
+ overconc = cls.num_experts // active_experts
+ cls.recv_capacity_per_rank = (
+ NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2
+ )
+ cls.mesh = _build_mesh(cls.dp, cls.ep)
+ cls.mr = MeshResource(dp_resource="dp", ep_resource="ep")
+ with cls.mesh, global_shard_guard(cls.mr):
+ ep_bootstrap(
+ world_size=cls.num_procs,
+ rank=cls.rank,
+ ep_size=cls.ep,
+ num_experts=cls.num_experts,
+ max_tokens_per_rank=TOKENS_PER_DP_SHARD,
+ recv_capacity_per_rank=cls.recv_capacity_per_rank,
+ hidden_dim=HIDDEN_DIM,
+ # XLA reallocates handle_mem between JIT executables.
+ allow_handle_mem_reloc=True,
+ )
+ # One handle key shared by all single-layer tests below.
+ cls.hk = ep_make_handle(TOP_K)
+
+ # ── Bootstrap precondition ────────────────────────────────────────────
+
+ def test_bootstrap_rejects_missing_ep_axis(self):
+ """ep_bootstrap raises when MeshResource has no ep_resource."""
+ with self.mesh, global_shard_guard(MeshResource()):
+ with self.assertRaisesRegex(ValueError, "ep_resource"):
+ ep_bootstrap(
+ world_size=self.num_procs,
+ rank=self.rank,
+ ep_size=self.ep,
+ num_experts=self.num_experts,
+ max_tokens_per_rank=TOKENS_PER_DP_SHARD,
+ recv_capacity_per_rank=self.recv_capacity_per_rank,
+ hidden_dim=HIDDEN_DIM,
+ )
+
+ # ── Helpers ───────────────────────────────────────────────────────────
+
+ def _make_identity_inputs(self, nonuniform=False):
+ """Identity routing + uniform weights — combined output ≈ tokens.
+
+ ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced).
+ ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` —
+ expert 0 absorbs the entire batch while the others split the second
+ slot evenly. Exercises a skewed per-expert load.
+ """
+ T_global = TOKENS_PER_DP_SHARD * self.dp
+ E = self.num_experts
+ topk_idx = np.empty((T_global, TOP_K), dtype=np.int32)
+ if nonuniform:
+ assert TOP_K == 2, "non-uniform pattern assumes top_k=2"
+ for t in range(T_global):
+ topk_idx[t, 0] = 0
+ topk_idx[t, 1] = 1 + (t % (E - 1))
+ else:
+ for t in range(T_global):
+ for k in range(TOP_K):
+ topk_idx[t, k] = (t * TOP_K + k) % E
+ topk_idx = jnp.asarray(topk_idx)
+ topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32)
+ tokens = jnp.asarray(
+ np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape(
+ T_global, HIDDEN_DIM
+ ),
+ dtype=jnp.bfloat16,
+ )
+ return T_global, topk_idx, tokens, topk_weights
+
+ def _make_random_inputs(self, seed=42, nonuniform=True):
+ """Random tokens + skewed top-2 routing (top1=0 always; top2 varies).
+
+ Non-uniform load by default — guarantees expert 0 receives every token
+ while the rest of the experts split the second slot. Use
+ ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern.
+ """
+ T_dp = TOKENS_PER_DP_SHARD * self.dp
+ E = self.num_experts
+ rng = np.random.default_rng(seed=seed)
+ tokens = jnp.asarray(
+ rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5,
+ dtype=jnp.bfloat16,
+ )
+ topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32)
+ if nonuniform:
+ assert TOP_K == 2, "non-uniform pattern assumes top_k=2"
+ for t in range(T_dp):
+ topk_idx_np[t, 0] = 0
+ topk_idx_np[t, 1] = 1 + (t % (E - 1))
+ else:
+ for t in range(T_dp):
+ a, b = t % E, (t + 1) % E
+ topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a)
+ topk_idx = jnp.asarray(topk_idx_np)
+ topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32))
+ return T_dp, tokens, topk_idx, topk_weights
+
+ # ── Individual primitives (cpp_extensions level) ──────────────────────
+
+ def test_two_handles_distinct_ids(self):
+ """Two ``ep_make_handle`` calls must yield distinct ``handle_id``s;
+ distinct logical layers cannot share a HandleEntry. Verified through a
+ jit so each ``ep_prepare`` bind path is exercised."""
+ _T, topk_idx, _tokens, _w = self._make_identity_inputs()
+ ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K)
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+
+ @jax.jit
+ def run(idx):
+ _tc_a, ha = ep_prepare(idx, ka)
+ _tc_b, hb = ep_prepare(idx, kb)
+ return ha, hb
+
+ hm_a, hm_b = run(idx_s)
+ hm_a.block_until_ready()
+ hm_b.block_until_ready()
+ self.assertNotEqual(ka.handle_id, kb.handle_id)
+
+ def test_two_layer_dispatch_no_handle_aliasing(self):
+ """Two ep_dispatch calls in one jit with distinct ``EpHandle``s must
+ not clobber each other's routing state. Different inputs per layer with
+ identity routing + uniform weights => both recv buffers must independently
+ identity-round-trip via ep_combine."""
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False)
+ tokens_b = (tokens.astype(jnp.float32) * -1.0 + 0.25).astype(tokens.dtype)
+ ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K)
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ ep_spec_2d = PartitionSpec(("dp", "ep"), None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+ ta = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec))
+ tb = jax.lax.with_sharding_constraint(tokens_b, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
+
+ def one_layer(hk, idx, toks, w_):
+ recv_t, recv_w, hm, tc = ep_dispatch(hk, idx, toks, w_, self.recv_capacity_per_rank)
+ recv_t = jax.lax.with_sharding_constraint(
+ recv_t, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ recv_w = jax.lax.with_sharding_constraint(
+ recv_w, NamedSharding(self.mesh, ep_spec_2d)
+ )
+ return ep_combine(
+ hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None)
+ )
+
+ @jax.jit
+ def run(idx, ta_, tb_, w_):
+ return one_layer(ka, idx, ta_, w_), one_layer(kb, idx, tb_, w_)
+
+ out_a, out_b = run(idx_s, ta, tb, w)
+ out_a.block_until_ready()
+ out_b.block_until_ready()
+ out_a_g = jmu.process_allgather(out_a, tiled=True)
+ out_b_g = jmu.process_allgather(out_b, tiled=True)
+
+ self.assertNotEqual(ka.handle_id, kb.handle_id)
+ if self.rank == 0:
+ np.testing.assert_allclose(
+ np.asarray(out_a_g.astype(jnp.float32)),
+ np.asarray(tokens.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+ np.testing.assert_allclose(
+ np.asarray(out_b_g.astype(jnp.float32)),
+ np.asarray(tokens_b.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_primitive_prepare(self):
+ """ep_prepare returns the expected shapes and a valid handle id."""
+ T_global, topk_idx, _tokens, _w = self._make_identity_inputs()
+ del T_global
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+
+ @jax.jit
+ def run(idx):
+ tc, hm = ep_prepare(idx, self.hk)
+ return tc, hm
+
+ tc, hm = run(idx_s)
+ tc.block_until_ready()
+ self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS))
+ self.assertEqual(hm.shape[0], self.dp * self.ep)
+ self.assertGreater(hm.shape[1], 0)
+
+ def _run_identity_round_trip(self, nonuniform):
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform)
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+ tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec))
+ w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
+
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ ep_spec_2d = PartitionSpec(("dp", "ep"), None)
+
+ @jax.jit
+ def run(idx, toks, w):
+ _tc, hm = ep_prepare(idx, self.hk)
+ recv_t, recv_w = ep_dispatch_fwd(
+ self.hk, hm, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_t = jax.lax.with_sharding_constraint(
+ recv_t, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ recv_w = jax.lax.with_sharding_constraint(
+ recv_w, NamedSharding(self.mesh, ep_spec_2d)
+ )
+ # Apply the weighted hadamard inline (combine FFI is unweighted).
+ mask = (recv_w != 0).astype(jnp.float32)[..., None]
+ weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype(
+ recv_t.dtype
+ )
+ weighted = jax.lax.with_sharding_constraint(
+ weighted, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ out = ep_combine_fwd(
+ self.hk,
+ hm,
+ weighted,
+ T_global,
+ out_partition_spec=(("dp", "ep"), None),
+ )
+ return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec))
+
+ out = run(idx_s, tok_s, w_s)
+ out.block_until_ready()
+ # Allgather so the rank-0 numpy comparison sees the full global tensor.
+ out_global = jmu.process_allgather(out, tiled=True)
+
+ # Identity expert + uniform weights → out ≈ tokens (rank-0 check).
+ if self.rank == 0:
+ np.testing.assert_allclose(
+ np.asarray(out_global.astype(jnp.float32)),
+ np.asarray(tokens.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_primitive_dispatch_combine_identity_uniform(self):
+ """Round-robin routing → identity round-trip via the primitive layer."""
+ self._run_identity_round_trip(nonuniform=False)
+
+ def test_primitive_dispatch_combine_identity_nonuniform(self):
+ """Skewed routing (top1=0 always) → identity round-trip via the primitive layer."""
+ self._run_identity_round_trip(nonuniform=True)
+
+ def test_primitive_dispatch_combine_identity_bwd_uniform(self):
+ """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens.
+
+ Identity routing + uniform top-k weights ⇒ dispatch∘combine is the
+ identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens.
+ """
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False)
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ ep_spec_2d = PartitionSpec(("dp", "ep"), None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ def loss_fn(toks):
+ toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
+ idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
+ recv_t, recv_w, hm, tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_t = jax.lax.with_sharding_constraint(
+ recv_t, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ recv_w = jax.lax.with_sharding_constraint(
+ recv_w, NamedSharding(self.mesh, ep_spec_2d)
+ )
+ out = ep_combine(
+ self.hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None)
+ )
+ return 0.5 * (out.astype(jnp.float32) ** 2).sum()
+
+ grad = jax.jit(jax.grad(loss_fn))(tokens)
+ grad.block_until_ready()
+ grad_global = jmu.process_allgather(grad, tiled=True)
+
+ if self.rank == 0:
+ np.testing.assert_allclose(
+ np.asarray(grad_global.astype(jnp.float32)),
+ np.asarray(tokens.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_dispatch_combine_3d_input_output(self):
+ """3D input ``[B, S, H]`` sharded on the first dim only —
+ ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape
+ and combine returns a matching 3D ``[B, S, H]`` output. End-to-end
+ round trip recovers the original tokens under identity routing +
+ uniform top-k weights."""
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False)
+ # B is sharded across all (dp*ep) ranks; S held in one piece per rank.
+ B, S, H = T_global, 1, tokens.shape[-1]
+ tokens_3d = tokens.reshape(B, S, H)
+ topk_idx_3d = topk_idx.reshape(B, S, -1)
+ topk_w_3d = topk_w.reshape(B, S, -1)
+ spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ out_spec_3d = (("dp", "ep"), None, None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d))
+ tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d))
+ w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d))
+
+ ep_t = PartitionSpec(("dp", "ep"), None, None)
+ ep_w = PartitionSpec(("dp", "ep"), None)
+
+ @jax.jit
+ def run(idx, toks, w):
+ recv_t, recv_w, hm, _tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t))
+ recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w))
+ out = ep_combine(
+ self.hk,
+ hm,
+ _tc,
+ recv_t,
+ recv_w,
+ num_local_tokens=(B, S),
+ out_sharding=out_spec_3d,
+ )
+ return out
+
+ out = run(idx_s, tok_s, w_s)
+ out.block_until_ready()
+ out_global = jmu.process_allgather(out, tiled=True)
+
+ if self.rank == 0:
+ self.assertEqual(out_global.shape, (B, S, H))
+ np.testing.assert_allclose(
+ np.asarray(out_global.astype(jnp.float32)),
+ np.asarray(tokens_3d.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_dispatch_combine_dp_only_first_dim(self):
+ """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must
+ accept it. JAX SPMD slices the missing ep axis locally so the kernel
+ still sees ``T/(dp*ep)`` tokens per rank."""
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False)
+ dp_only = PartitionSpec("dp", None)
+ with self.mesh, global_shard_guard(self.mr):
+ idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only))
+ tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only))
+ w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only))
+
+ ep_t = PartitionSpec(("dp", "ep"), None, None)
+ ep_w = PartitionSpec(("dp", "ep"), None)
+
+ @jax.jit
+ def run(idx, toks, w):
+ recv_t, recv_w, hm, _tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t))
+ recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w))
+ out = ep_combine(
+ self.hk,
+ hm,
+ _tc,
+ recv_t,
+ recv_w,
+ num_local_tokens=T_global,
+ out_sharding=(("dp", "ep"), None),
+ )
+ return out
+
+ out = run(idx_s, tok_s, w_s)
+ out.block_until_ready()
+ out_global = jmu.process_allgather(out, tiled=True)
+
+ if self.rank == 0:
+ np.testing.assert_allclose(
+ np.asarray(out_global.astype(jnp.float32)),
+ np.asarray(tokens.astype(jnp.float32)),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ # ── Custom-VJP tests ─────────────────────────────────────────────────
+
+ def test_dispatch_vjp_fwd_bwd(self):
+ """ep_dispatch fwd + jax.grad w.r.t. tokens.
+
+ Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears
+ TOP_K times in recv_tokens (all routes fit recv_capacity), so
+ grad_tokens = TOP_K * tokens (closed form).
+ """
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs()
+ del T_global
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ def loss_fn(toks):
+ toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
+ idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
+ recv_tokens, _recv_w, _hm, _tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_tokens = jax.lax.with_sharding_constraint(
+ recv_tokens, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum()
+
+ loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens)
+ grad_tokens.block_until_ready()
+ grad_global = jmu.process_allgather(grad_tokens, tiled=True)
+
+ self.assertTrue(np.isfinite(float(loss)))
+ self.assertEqual(grad_tokens.shape, tokens.shape)
+ if self.rank == 0:
+ np.testing.assert_allclose(
+ np.asarray(grad_global.astype(jnp.float32)),
+ np.asarray(tokens.astype(jnp.float32)) * float(TOP_K),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_combine_vjp_fwd_bwd(self):
+ """ep_combine fwd + jax.grad w.r.t. expert_out.
+
+ Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c
+ (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled
+ slots — so max|grad_eo| ≈ c / TOP_K.
+ """
+ T_global, topk_idx, tokens, topk_w = self._make_identity_inputs()
+ eo_const = 0.5
+ expert_out = jnp.full(
+ (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM),
+ eo_const,
+ dtype=jnp.bfloat16,
+ )
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ def loss_fn(eo):
+ eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d))
+ toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec))
+ idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
+ _recv_tokens, recv_w, hm, tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_w = jax.lax.with_sharding_constraint(
+ recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None))
+ )
+ combined = ep_combine(self.hk, hm, tc, eo, recv_w, T_global)
+ # Pin combined to dp-sharded so autodiff transpose feeds
+ # ep_combine_bwd a per-shard cotangent.
+ combined = jax.lax.with_sharding_constraint(
+ combined, NamedSharding(self.mesh, dp_spec)
+ )
+ return 0.5 * (combined.astype(jnp.float32) ** 2).sum()
+
+ loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out)
+ grad_eo.block_until_ready()
+
+ self.assertTrue(np.isfinite(float(loss)))
+ self.assertEqual(grad_eo.shape, expert_out.shape)
+ for shard in grad_eo.addressable_shards:
+ arr = np.asarray(shard.data.astype(jnp.float32))
+ self.assertTrue(np.all(np.isfinite(arr)))
+ self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots")
+ np.testing.assert_allclose(
+ arr.max(),
+ eo_const / float(TOP_K),
+ atol=5e-2,
+ rtol=5e-2,
+ )
+
+ def test_dispatch_bwd_exact_per_k_topk_weights(self):
+ """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t.
+
+ Guards against a regression where the bwd would average across the k
+ axis (per-token mean instead of per-slot exact recovery).
+ """
+ T_dp, tokens, topk_idx, topk_w = self._make_random_inputs()
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ def loss_fn(idx_in, tok_in, w_in):
+ idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec))
+ tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec))
+ w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec))
+ _recv_t, recv_w, _h, _tc = ep_dispatch(
+ self.hk, idx_in, tok_in, w_in, self.recv_capacity_per_rank
+ )
+ # Per-slot index scale ⇒ each slot's contribution differs.
+ scale = jnp.asarray(
+ np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0
+ )
+ return jnp.sum(recv_w * scale)
+
+ grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w)
+ grad_topk_w.block_until_ready()
+ grad_global = jmu.process_allgather(grad_topk_w, tiled=True)
+
+ if self.rank == 0:
+ grad_np = np.asarray(grad_global).astype(np.float32)
+ mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp))
+ self.assertEqual(
+ mismatch,
+ 0,
+ f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed "
+ f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].",
+ )
+
+ # ── HLO reshard guard ────────────────────────────────────────────────
+ # Compile-only: assert XLA inserts no cross-device collectives outside
+ # the EP FFI. EP-axis flux is carried by the FFI itself.
+
+ def test_z_no_unexpected_reshard_in_hlo_fwd(self):
+ """Compiled fwd HLO must not insert XLA collectives outside the EP FFI."""
+ T_dp, tokens, topk_idx, topk_w = self._make_random_inputs()
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ ep_spec_2d = PartitionSpec(("dp", "ep"), None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ @jax.jit
+ def run(idx, toks, w):
+ idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec))
+ toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec))
+ recv_t, recv_w, hm, tc = ep_dispatch(
+ self.hk, idx, toks, w, self.recv_capacity_per_rank
+ )
+ recv_t = jax.lax.with_sharding_constraint(
+ recv_t, NamedSharding(self.mesh, ep_spec_3d)
+ )
+ recv_w = jax.lax.with_sharding_constraint(
+ recv_w, NamedSharding(self.mesh, ep_spec_2d)
+ )
+ out = ep_combine(
+ self.hk, hm, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None)
+ )
+ return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec))
+
+ compiled = run.lower(topk_idx, tokens, topk_w).compile()
+ hlo = compiled.as_text()
+ # Match instruction names; "all-gather-start" and "all-gather-done"
+ # bracket a single async all-gather.
+ for op in ("all-gather-start", "all-to-all", "collective-permute"):
+ self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}")
+ # XLA drops trailing-None entries from the spec; compare as a tuple.
+ # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep".
+ expected = (("dp", "ep"),) if self.dp > 1 else ("ep",)
+ self.assertEqual(tuple(compiled.output_shardings.spec), expected)
+
+ def test_z_no_unexpected_reshard_in_hlo_bwd(self):
+ """Compiled bwd HLO must not insert XLA collectives outside the EP FFI."""
+ T_dp, tokens, topk_idx, topk_w = self._make_random_inputs()
+ rng = np.random.default_rng(seed=44)
+ expert_out = jnp.asarray(
+ rng.standard_normal(
+ (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32
+ )
+ * 0.5,
+ dtype=jnp.bfloat16,
+ )
+ dp_spec = PartitionSpec(("dp", "ep"), None)
+ ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
+ ep_spec_2d = PartitionSpec(("dp", "ep"), None)
+
+ with self.mesh, global_shard_guard(self.mr):
+
+ def fwd(eo, toks, idx, w):
+ eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d))
+ toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
+ idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec))
+ w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec))
+ _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
+ rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d))
+ combined = ep_combine(
+ self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)
+ )
+ return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec))
+
+ # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd
+ # the expected sharding without relying on XLA-transpose propagation.
+ def bwd_only(eo, toks, idx, w, g):
+ _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w)
+ g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec))
+ grads = vjp_fn(g)
+ return (
+ jax.lax.with_sharding_constraint(
+ grads[0], NamedSharding(self.mesh, ep_spec_3d)
+ ),
+ jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)),
+ )
+
+ g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16)
+ compiled = (
+ jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile()
+ )
+ hlo = compiled.as_text()
+ for op in ("all-gather-start", "all-to-all", "collective-permute"):
+ self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}")
+
+
+# ── Entry point ──────────────────────────────────────────────────────────────
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 4:
+ print("Usage: python test_multi_process_ep.py ")
+ sys.exit(1)
+
+ coord_addr = sys.argv[1]
+ proc_id = int(sys.argv[2])
+ num_procs = int(sys.argv[3])
+
+ jax.distributed.initialize(
+ coordinator_address=coord_addr,
+ num_processes=num_procs,
+ process_id=proc_id,
+ local_device_ids=[proc_id],
+ )
+
+ loader = unittest.TestLoader()
+ target = os.environ.get("TARGET_TEST")
+ if target:
+ name = target.split(".")[-1]
+ suite = loader.loadTestsFromName(name, TestEP)
+ else:
+ suite = loader.loadTestsFromTestCase(TestEP)
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(suite)
+ sys.exit(0 if result.wasSuccessful() else 1)
diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py
new file mode 100644
index 0000000000..febae4bd2e
--- /dev/null
+++ b/tests/jax/test_te_ep_moe.py
@@ -0,0 +1,791 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Multi-process (one-GPU-per-process) tests for the TE-EP MoE custom_vjp.
+
+The launcher ``tests/jax/run_te_ep_moe.sh`` forks one pytest process per
+visible GPU (mirroring ``run_multiprocess_moe_vjp.sh``). Each process binds
+to exactly one device via
+``jax.distributed.initialize(..., local_device_ids=process_id)``; the
+participating processes form a global ``(ep, fsdp)`` mesh through JAX's
+distributed runtime.
+
+How to run
+----------
+
+You typically do NOT invoke pytest on this file directly -- use the
+launcher, which passes ``--num-process=N --process-id=i`` to each
+forked process. Driving it directly with only one process will skip
+every test because :func:`jax.distributed.initialize` requires
+multiple participants, and the TE EP NCCL primitives require at
+least four ranks.
+
+ bash tests/jax/run_te_ep_moe.sh
+
+What this suite covers
+----------------------
+
+This file is the TE-EP-only successor to ``test_moe_vjp.py`` and
+``test_multiprocess_moe_vjp.py``. Each test exercises one MoE-block
+run and bundles every check that single run supports — shape, dtype,
+finiteness AND numerical parity vs a pure-JAX reference. Variations
+on the block are pytest parametrize values rather than separate test
+classes:
+
+* ``test_forward`` covers the forward across a curated set of
+ configurations (apply_topk_weights_early on/off, align_size=0/128,
+ softmax/sigmoid scoring, optional expert_bias). Each config asserts
+ shape, dtype, finiteness and numerical parity vs the reference in
+ one run.
+* ``test_backward`` mirrors that for gradients.
+* ``TestTeEpMoeAuxLoss`` covers the second return value end-to-end
+ (returned + parity + aux-only grad propagates to gate + combined
+ main+aux grads stay finite) in two consolidated tests.
+* ``TestTeEpMoEBlockFlax`` exercises the Flax wrapper with the same
+ parity reference.
+* ``TestZZZTeEpMoeBootstrap`` verifies the per-process NCCL bootstrap
+ rejects a mismatched signature.
+
+FP8 / MXFP8 recipes are deferred — the ``quantizer_sets`` plumbing
+has not yet been re-wired across the TE-EP ``shard_map`` boundary
+(see ``.pr3036-review/INTEGRATION_DESIGN.md``).
+"""
+
+import os
+
+os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
+os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5")
+
+import sys
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from flax.linen import partitioning as nn_partitioning
+
+
+def _init_distributed(num_process: int, process_id: int) -> bool:
+ """Initialize jax.distributed for this pytest process.
+
+ Returns True on a real multi-process launch, False otherwise so
+ the module can fast-skip when pytest collects it without the
+ launcher.
+ """
+ if num_process <= 1:
+ return False
+ coord = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457")
+ jax.distributed.initialize(
+ coordinator_address=coord,
+ num_processes=num_process,
+ process_id=process_id,
+ local_device_ids=process_id,
+ )
+ assert jax.local_device_count() == 1, "one GPU per process is required for TE EP"
+ assert (
+ jax.device_count() == num_process
+ ), f"global device_count {jax.device_count()} != num_process {num_process}"
+ return True
+
+
+def _read_mp_options():
+ num = int(os.environ.get("MP_NUM_PROCESS", "0") or "0")
+ pid = int(os.environ.get("MP_PROCESS_ID", "0") or "0")
+ for i, a in enumerate(sys.argv):
+ if a.startswith("--num-process="):
+ num = int(a.split("=", 1)[1])
+ elif a == "--num-process" and i + 1 < len(sys.argv):
+ num = int(sys.argv[i + 1])
+ elif a.startswith("--process-id="):
+ pid = int(a.split("=", 1)[1])
+ elif a == "--process-id" and i + 1 < len(sys.argv):
+ pid = int(sys.argv[i + 1])
+ return num, pid
+
+
+_MP_NUM_PROCESS, _MP_PROCESS_ID = _read_mp_options()
+_MP_ACTIVE = _init_distributed(_MP_NUM_PROCESS, _MP_PROCESS_ID)
+
+if not _MP_ACTIVE:
+ pytest.skip(
+ "test_te_ep_moe.py requires the multiprocess launcher (run_te_ep_moe.sh). Skipping.",
+ allow_module_level=True,
+ )
+
+from transformer_engine_jax import get_device_compute_capability
+
+# Grouped GEMM in the MoE custom_vjp requires Blackwell (sm_100+). The
+# TE EP NCCL primitives themselves need SM>=90, but the FFN body uses
+# grouped_gemm, so the file as a whole gates on sm_100+.
+if get_device_compute_capability(0) < 100:
+ pytest.skip(
+ "MoE TE EP tests require Blackwell (sm_100+) for grouped GEMM",
+ allow_module_level=True,
+ )
+
+from transformer_engine.jax.flax import _MoEBlock as MoEBlock
+from transformer_engine.jax.moe import moe, record_ep_bootstrap_signature_for_moe
+from transformer_engine.jax.ep import ep_bootstrap
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+
+
+# -----------------------------------------------------------------------------
+# Mesh / shape config
+# -----------------------------------------------------------------------------
+
+EP_AXIS = "ep"
+FSDP_AXIS = "fsdp"
+EP_SIZE = 2
+assert (
+ jax.device_count() % EP_SIZE == 0
+), f"device_count {jax.device_count()} must be divisible by EP_SIZE={EP_SIZE}"
+FSDP_SIZE = jax.device_count() // EP_SIZE
+NUM_DEVICES_REQUIRED = EP_SIZE * FSDP_SIZE
+
+LOGICAL_AXIS_RULES = (
+ ("exp", EP_AXIS),
+ ("embed", FSDP_AXIS),
+ ("mlp", None),
+ ("batch", (EP_AXIS, FSDP_AXIS)),
+)
+
+# Small shapes so the parity tests stay tight on bf16. The block still
+# has all four ranks participating in dispatch/combine.
+DTYPE = jnp.bfloat16
+BATCH = EP_SIZE * FSDP_SIZE * 2 # 8 on 4-GPU, 16 on 8-GPU
+SEQ = 32
+HIDDEN = 64
+INTER = 128
+NUM_EXPERTS = 8
+TOPK = 2
+
+# bf16 grouped_gemm + softmax-topk + ep all-to-all stack drifts ~1e-1 vs a
+# fp32 numpy reference. Keep these tight enough to catch real bugs but
+# loose enough to absorb expected bf16 rounding.
+FWD_ATOL = 5e-2
+FWD_RTOL = 5e-2
+GRAD_FFN_ATOL = 1e-1
+GRAD_FFN_RTOL = 1e-1
+GRAD_GATE_ATOL = 5e-1
+GRAD_GATE_RTOL = 5e-1
+
+# Two TE EP runs that should be bitwise-equal modulo XLA fusion order
+# (align_size rounding, etc.).
+TE_TO_TE_ATOL = 5e-3
+TE_TO_TE_RTOL = 5e-3
+
+# Aux loss is computed in float32 from the SAME logits as the routing
+# path. Numerical drift between TE-EP and the reference is dominated by
+# the bf16-rounded softmax inside the topk kernel.
+AUX_ATOL = 1e-3
+AUX_RTOL = 1e-3
+
+
+# -----------------------------------------------------------------------------
+# Fixtures
+# -----------------------------------------------------------------------------
+
+
+def _compute_worst_case_recv_pr():
+ """Worst-case per-rank recv buffer across every config in _CONFIGS.
+
+ Bootstrap reserves NCCL EP buffers; per-call recv_pr <= bootstrap
+ recv_pr is fine. We size with the largest align_size in _CONFIGS so
+ the align128 config still fits the same singleton bootstrap.
+ """
+ num_procs = jax.device_count()
+ dp_size = num_procs // EP_SIZE
+ num_local_experts = NUM_EXPERTS // EP_SIZE
+ natural_recv_pr = (BATCH // dp_size) * SEQ * TOPK
+ natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts
+ worst_align = 128
+ worst_spe = ((natural_spe + worst_align - 1) // worst_align) * worst_align
+ return num_local_experts * worst_spe
+
+
+@pytest.fixture(scope="module")
+def mesh():
+ if jax.device_count() < NUM_DEVICES_REQUIRED:
+ pytest.skip(
+ f"Need >={NUM_DEVICES_REQUIRED} devices for ep={EP_SIZE} x fsdp={FSDP_SIZE};"
+ f" have {jax.device_count()}"
+ )
+ # ``ep`` must be the inner axis: ``ep_bootstrap`` forms NCCL EP groups
+ # from consecutive global ranks via ``dp_color = rank // ep_size``, so
+ # only an (outer_fsdp, inner_ep) device layout groups ranks correctly.
+ devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE))
+ mesh_obj = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS))
+
+ num_procs = jax.process_count()
+ max_tokens_per_rank = (BATCH // num_procs) * SEQ
+ recv_capacity_per_rank = _compute_worst_case_recv_pr()
+
+ # Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather
+ # and cannot run from inside jax.jit. Sized to the worst-case recv_pr
+ # across _CONFIGS so every parametrized config is bootstrap-compatible.
+ with mesh_obj, global_shard_guard(MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)):
+ ep_bootstrap(
+ world_size=num_procs,
+ rank=jax.process_index(),
+ ep_size=EP_SIZE,
+ num_experts=NUM_EXPERTS,
+ max_tokens_per_rank=max_tokens_per_rank,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ hidden_dim=HIDDEN,
+ allow_handle_mem_reloc=True,
+ max_token_dtype=DTYPE,
+ )
+ record_ep_bootstrap_signature_for_moe(
+ num_experts=NUM_EXPERTS,
+ max_tokens_per_rank=max_tokens_per_rank,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ hidden_dim=HIDDEN,
+ ep_size=EP_SIZE,
+ )
+ return mesh_obj
+
+
+# -----------------------------------------------------------------------------
+# Pure-JAX reference MoE (no EP). Mirrors the exact math of TE's fused
+# router primitive (see tests/jax/test_fused_router.py for the same
+# reference applied to the standalone router kernel):
+#
+# softmax + post-softmax (use_pre_softmax=False, the default):
+# 1. top_k by raw logits
+# 2. softmax over just the K selected logits (so weights sum to 1)
+#
+# sigmoid + optional expert_bias:
+# 1. scores = sigmoid(logits)
+# 2. top_k by (scores + expert_bias) [bias only steers selection]
+# 3. weights = scores at top_k positions, normalized when K > 1
+#
+# Then for both:
+# * weights *= scaling_factor (we leave scaling_factor=1.0 in this
+# suite, matching _make_block's default).
+# * per-expert FFN: silu(layer_w0) * layer_w1 → wo.
+# -----------------------------------------------------------------------------
+
+
+@partial(
+ jax.jit,
+ static_argnames=(
+ "num_experts",
+ "num_experts_per_tok",
+ "aux_loss_coeff",
+ "score_function",
+ ),
+)
+def _pure_jax_moe_reference(
+ x,
+ gate_kernel,
+ wi_0,
+ wi_1,
+ wo,
+ expert_bias=None,
+ *,
+ num_experts,
+ num_experts_per_tok,
+ aux_loss_coeff: float = 0.0,
+ score_function: str = "softmax",
+):
+ B, S, H = x.shape
+ T = B * S
+ K = num_experts_per_tok
+ x_2d = x.reshape(T, H)
+
+ gate_kernel_cast = gate_kernel.astype(x.dtype)
+ logits = (x_2d @ gate_kernel_cast).astype(jnp.float32) # [T, E]
+
+ if score_function == "softmax":
+ # use_pre_softmax=False: topk on raw logits, then softmax over K.
+ top_logits, top_indices = jax.lax.top_k(logits, k=K)
+ weights = jax.nn.softmax(top_logits, axis=-1) # [T, K], sums to 1
+ elif score_function == "sigmoid":
+ scores = jax.nn.sigmoid(logits) # [T, E]
+ if expert_bias is not None and expert_bias.shape != (0,):
+ scores_for_routing = scores + expert_bias.astype(jnp.float32)[None, :]
+ _, top_indices = jax.lax.top_k(scores_for_routing, k=K)
+ weights = jnp.take_along_axis(scores, top_indices, axis=-1)
+ else:
+ weights, top_indices = jax.lax.top_k(scores, k=K)
+ # Sigmoid weights are normalized when K > 1 (matches the kernel).
+ if K > 1:
+ weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20)
+ else:
+ raise ValueError(f"Unsupported score_function={score_function!r}")
+
+ routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32)
+ routing_weights_full = routing_weights_full.at[jnp.arange(T)[:, None], top_indices].set(weights)
+
+ # FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't
+ # change the math (wo is linear), so the reference is identical for
+ # both placements.
+ layer_w0 = jnp.einsum("th,ehm->tem", x_2d, wi_0)
+ layer_w1 = jnp.einsum("th,ehm->tem", x_2d, wi_1)
+ intermediate = jax.nn.silu(layer_w0.astype(jnp.float32)) * layer_w1.astype(jnp.float32)
+ intermediate = intermediate.astype(x.dtype)
+ expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H]
+ output_2d = jnp.einsum("te,teh->th", routing_weights_full.astype(x.dtype), expert_out)
+ output = output_2d.reshape(B, S, H).astype(x.dtype)
+
+ if aux_loss_coeff > 0.0:
+ # tex.fused_moe_aux_loss formula (matches the same
+ # reference_aux_loss helper from test_fused_router.py). The
+ # "aux scores" use the same score_function but always with
+ # K-normalised sigmoid (when sigmoid) / plain softmax (when
+ # softmax) — see tex.fused_topk_with_score_function_fwd with
+ # compute_aux_scores=True.
+ if score_function == "softmax":
+ aux_scores = jax.nn.softmax(logits, axis=-1)
+ else: # sigmoid
+ aux_scores = jax.nn.sigmoid(logits)
+ if K > 1:
+ aux_scores = aux_scores / (aux_scores.sum(axis=-1, keepdims=True) + 1e-20)
+ routing_map = (routing_weights_full > 0).astype(jnp.int32)
+ tokens_per_expert = jnp.sum(routing_map, axis=0) # [E]
+ sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E]
+ aux_loss = (num_experts * aux_loss_coeff / (K * (T**2))) * jnp.sum(
+ sum_probs_per_expert * tokens_per_expert.astype(jnp.float32)
+ )
+ aux_loss = aux_loss.astype(x.dtype)
+ else:
+ aux_loss = jnp.zeros((), dtype=x.dtype)
+ return output, aux_loss
+
+
+# -----------------------------------------------------------------------------
+# Helpers
+# -----------------------------------------------------------------------------
+
+
+def _make_block(
+ *,
+ apply_topk_weights_early=False,
+ align_size=0,
+ aux_loss_coeff=0.0,
+ use_expert_bias=False,
+ score_function="softmax",
+ bias_init=None,
+):
+ kwargs = dict(
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ intermediate_size=INTER,
+ data_parallelism_axes=(FSDP_AXIS,),
+ apply_topk_weights_early=apply_topk_weights_early,
+ align_size=align_size,
+ aux_loss_coeff=aux_loss_coeff,
+ use_expert_bias=use_expert_bias,
+ score_function=score_function,
+ dtype=DTYPE,
+ )
+ # Custom bias_init lets tests inject a non-zero expert_bias without
+ # poking variables['params'] post-init.
+ if bias_init is not None:
+ kwargs["bias_init"] = bias_init
+ return MoEBlock(**kwargs)
+
+
+def _strong_expert_bias_init(key, shape, dtype):
+ """Half +5, half -5 — large enough to force topk onto the +ve half."""
+ del key
+ n = shape[0]
+ return jnp.concatenate(
+ [
+ jnp.full((n // 2,), 5.0, dtype=dtype),
+ jnp.full((n - n // 2,), -5.0, dtype=dtype),
+ ]
+ )
+
+
+def _shard_inputs(x, mesh):
+ # Match the layout moe.py re-pins to: outer dp axes, then ep innermost.
+ return jax.lax.with_sharding_constraint(
+ x, NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None))
+ )
+
+
+def _ctx(mesh):
+ """Combined mesh + global_shard_guard + axis_rules context."""
+
+ class _Combo:
+ def __enter__(self_inner):
+ self_inner._m = mesh.__enter__()
+ self_inner._gs = global_shard_guard(
+ MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)
+ )
+ self_inner._gs.__enter__()
+ self_inner._ar = nn_partitioning.axis_rules(LOGICAL_AXIS_RULES)
+ self_inner._ar.__enter__()
+ return self_inner._m
+
+ def __exit__(self_inner, *args):
+ self_inner._ar.__exit__(*args)
+ self_inner._gs.__exit__(*args)
+ mesh.__exit__(*args)
+
+ return _Combo()
+
+
+def _init_apply(block, mesh, x, key):
+ with _ctx(mesh):
+ x_sh = _shard_inputs(x, mesh)
+ variables = jax.jit(block.init)(key, x_sh)
+ jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0])
+ output, aux = jax.jit(block.apply)(variables, x_sh)
+ jax.block_until_ready(output)
+ return variables, output, aux
+
+
+def _grad_step(block, variables, mesh, x, *, include_aux=False):
+ """Run jax.grad of mean(out^2) [+ aux if include_aux] vs params."""
+ with _ctx(mesh):
+ x_sh = _shard_inputs(x, mesh)
+
+ def loss_fn(variables, x):
+ output, aux = block.apply(variables, x)
+ loss = jnp.mean(output.astype(jnp.float32) ** 2)
+ if include_aux and aux is not None:
+ loss = loss + aux.astype(jnp.float32)
+ return loss
+
+ grads = jax.jit(jax.grad(loss_fn))(variables, x_sh)
+ jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0])
+ return grads
+
+
+def _grad_aux_only(block, variables, mesh, x):
+ """Jit'd grad of just the aux loss scalar — proves it reaches the
+ gate even when no main-output contribution is present."""
+ with _ctx(mesh):
+ x_sh = _shard_inputs(x, mesh)
+
+ def aux_only(variables, x):
+ _, aux = block.apply(variables, x)
+ return aux.astype(jnp.float32)
+
+ grads = jax.jit(jax.grad(aux_only))(variables, x_sh)
+ jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0])
+ return grads
+
+
+def _unwrap(x):
+ return x.value if hasattr(x, "value") else x
+
+
+def _to_global_numpy(arr, mesh):
+ """Replicate a sharded JAX array onto every rank and return as numpy.
+
+ Triggers an all-gather inside JIT. The resulting addressable_data(0)
+ contains the full global array on every process, so we can run the
+ pure-JAX reference and compare against it from any process.
+ """
+ rep = NamedSharding(mesh, P())
+ with mesh:
+ full = jax.jit(lambda a: jax.lax.with_sharding_constraint(a, rep))(arr)
+ full.block_until_ready()
+ return np.asarray(jax.device_get(full.addressable_data(0)))
+
+
+def _params_global_numpy(variables, mesh):
+ """Pull every entry of variables['params'] to a replicated numpy array."""
+ params = variables["params"]
+ return {name: _to_global_numpy(_unwrap(p), mesh) for name, p in params.items()}
+
+
+def _make_inputs(key):
+ """Generate a globally-identical input tensor on every process."""
+ return jax.random.normal(key, (BATCH, SEQ, HIDDEN), dtype=DTYPE)
+
+
+# -----------------------------------------------------------------------------
+# Tests
+# -----------------------------------------------------------------------------
+
+
+# -----------------------------------------------------------------------------
+# Parametrize variants exercised by both the forward and the backward
+# parity tests. Each config is one MoE-block configuration the suite
+# wants covered; the test body checks shape, dtype, finiteness AND
+# numerical parity vs the same pure-JAX reference (which understands
+# the same set of knobs).
+# -----------------------------------------------------------------------------
+
+_CONFIGS = [
+ pytest.param(
+ dict(score_function="softmax"),
+ id="softmax",
+ ),
+ pytest.param(
+ dict(score_function="softmax", apply_topk_weights_early=True),
+ id="softmax-topk-early",
+ ),
+ pytest.param(
+ dict(score_function="softmax", align_size=128),
+ id="softmax-align128",
+ ),
+ pytest.param(
+ dict(score_function="sigmoid"),
+ id="sigmoid",
+ ),
+ pytest.param(
+ dict(score_function="sigmoid", use_expert_bias=True),
+ id="sigmoid-bias-zero",
+ ),
+ pytest.param(
+ dict(
+ score_function="sigmoid",
+ use_expert_bias=True,
+ bias_init=_strong_expert_bias_init,
+ ),
+ id="sigmoid-bias-strong",
+ ),
+]
+
+
+def _reference_kwargs_from_config(config, params_np):
+ """Pick out the reference-relevant pieces of a parametrize config."""
+ return dict(
+ score_function=config.get("score_function", "softmax"),
+ expert_bias=(
+ jnp.asarray(params_np["expert_bias"]) if config.get("use_expert_bias", False) else None
+ ),
+ )
+
+
+class TestTeEpMoeForward:
+ """Per-config forward correctness in a single run: shape, dtype,
+ finiteness AND numerical parity vs the pure-JAX reference."""
+
+ @pytest.mark.parametrize("config", _CONFIGS)
+ def test_forward(self, mesh, config):
+ block = _make_block(**config)
+ x = _make_inputs(jax.random.PRNGKey(0))
+ variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1))
+
+ # Shape / dtype / finiteness (cheap; on the local shard).
+ assert output.shape == x.shape
+ assert output.dtype == x.dtype
+ out_local = np.asarray(jax.device_get(output.addressable_data(0)))
+ assert np.all(np.isfinite(out_local)), "output has NaN/Inf"
+ assert aux is None, "aux_loss should be None when aux_loss_coeff == 0"
+
+ # Numerical parity (replicated global view -> single rank's numpy).
+ params_np = _params_global_numpy(variables, mesh)
+ x_np = np.asarray(jax.device_get(x))
+ out_te_np = _to_global_numpy(output, mesh)
+
+ out_ref, _ = _pure_jax_moe_reference(
+ jnp.asarray(x_np),
+ jnp.asarray(params_np["gate_kernel"]),
+ jnp.asarray(params_np["wi_0"]),
+ jnp.asarray(params_np["wi_1"]),
+ jnp.asarray(params_np["wo"]),
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ **_reference_kwargs_from_config(config, params_np),
+ )
+ np.testing.assert_allclose(
+ out_te_np.astype(np.float32),
+ np.asarray(jax.device_get(out_ref)).astype(np.float32),
+ atol=FWD_ATOL,
+ rtol=FWD_RTOL,
+ err_msg=f"forward parity breach for config={config}",
+ )
+
+
+class TestTeEpMoeBackward:
+ """Per-config backward correctness in a single run: per-tensor
+ grads finite, non-zero AND parity vs the pure-JAX reference."""
+
+ @pytest.mark.parametrize("config", _CONFIGS)
+ def test_backward(self, mesh, config):
+ block = _make_block(**config)
+ x = _make_inputs(jax.random.PRNGKey(2))
+ variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(3))
+ grads_te = _grad_step(block, variables, mesh, x)
+
+ # Reference grads via jax.grad over the pure-JAX MoE with the
+ # same config.
+ params_np = _params_global_numpy(variables, mesh)
+ x_np = np.asarray(jax.device_get(x))
+ ref_kwargs = _reference_kwargs_from_config(config, params_np)
+ ref_expert_bias = ref_kwargs.pop("expert_bias")
+
+ def loss_fn(params, x):
+ out, _ = _pure_jax_moe_reference(
+ x,
+ params["gate_kernel"],
+ params["wi_0"],
+ params["wi_1"],
+ params["wo"],
+ ref_expert_bias,
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ **ref_kwargs,
+ )
+ return jnp.mean(out.astype(jnp.float32) ** 2)
+
+ grads_ref = jax.jit(jax.grad(loss_fn))(
+ {k: jnp.asarray(v) for k, v in params_np.items() if k != "expert_bias"},
+ jnp.asarray(x_np),
+ )
+ grads_ref_np = {k: np.asarray(jax.device_get(v)) for k, v in grads_ref.items()}
+
+ for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
+ # Per-tensor: finite + non-zero + parity in one pass.
+ g_te = _to_global_numpy(_unwrap(grads_te["params"][name]), mesh)
+ assert np.all(np.isfinite(g_te)), f"{name} grad has NaN/Inf [config={config}]"
+ assert np.any(g_te != 0.0), f"{name} grad identically zero [config={config}]"
+ atol, rtol = (
+ (GRAD_GATE_ATOL, GRAD_GATE_RTOL)
+ if name == "gate_kernel"
+ else (GRAD_FFN_ATOL, GRAD_FFN_RTOL)
+ )
+ np.testing.assert_allclose(
+ g_te.astype(np.float32),
+ grads_ref_np[name].astype(np.float32),
+ atol=atol,
+ rtol=rtol,
+ err_msg=f"grad parity breach on {name} [config={config}]",
+ )
+
+
+class TestTeEpMoeAuxLoss:
+ """Aux-loss path. Consolidated into:
+ * ``test_aux_loss``: one run that checks the returned scalar's
+ shape / dtype / finiteness / magnitude AND numerical parity vs the
+ reference AND that the aux-only bwd propagates to gate_kernel.
+ * ``test_combined_loss_grads``: one run for joint main+aux bwd
+ finite + non-zero per tensor.
+ """
+
+ def test_aux_loss(self, mesh):
+ coeff = 1e-2
+ block = _make_block(aux_loss_coeff=coeff)
+ x = _make_inputs(jax.random.PRNGKey(20))
+ variables, _, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(21))
+
+ # Shape / dtype / finiteness / magnitude.
+ assert aux is not None, "aux_loss should be returned when coeff > 0"
+ assert aux.shape == (), f"aux_loss must be 0-d scalar, got {aux.shape}"
+ assert aux.dtype == DTYPE, f"aux_loss dtype {aux.dtype} != {DTYPE}"
+ aux_np = _to_global_numpy(aux, mesh)
+ assert np.isfinite(aux_np), "aux_loss is NaN/Inf"
+ assert abs(float(aux_np)) < 1e2, f"aux_loss looks unreasonable: {aux_np}"
+
+ # Numerical parity vs the reference.
+ params_np = _params_global_numpy(variables, mesh)
+ x_np = np.asarray(jax.device_get(x))
+ _, aux_ref = _pure_jax_moe_reference(
+ jnp.asarray(x_np),
+ jnp.asarray(params_np["gate_kernel"]),
+ jnp.asarray(params_np["wi_0"]),
+ jnp.asarray(params_np["wi_1"]),
+ jnp.asarray(params_np["wo"]),
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ aux_loss_coeff=coeff,
+ )
+ np.testing.assert_allclose(
+ float(aux_np),
+ float(jax.device_get(aux_ref)),
+ atol=AUX_ATOL,
+ rtol=AUX_RTOL,
+ )
+
+ # Aux-only bwd must propagate to gate_kernel — proves the
+ # fused_moe_aux_loss_bwd → topk(compute_aux_scores)_bwd chain is
+ # wired.
+ aux_grads = _grad_aux_only(block, variables, mesh, x)
+ g_gate = np.asarray(
+ jax.device_get(_unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0))
+ )
+ assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss"
+ assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel"
+
+ def test_combined_loss_grads(self, mesh):
+ """Joint main + aux loss bwd: per-tensor finite + non-zero in
+ one pass."""
+ block = _make_block(aux_loss_coeff=1e-2)
+ x = _make_inputs(jax.random.PRNGKey(22))
+ variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23))
+ grads = _grad_step(block, variables, mesh, x, include_aux=True)
+ for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
+ g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)))
+ assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux"
+ assert np.any(g_local != 0.0), f"{name} grad zero under main+aux"
+
+
+class TestTeEpMoEBlockFlax:
+ """Flax wrapper end-to-end in one run: shape/dtype/finiteness on the
+ forward, numerical parity vs the same reference, and per-tensor
+ grad finiteness + non-zeroness."""
+
+ def test_init_apply_parity(self, mesh):
+ block = _make_block()
+ x = _make_inputs(jax.random.PRNGKey(12))
+ variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(13))
+
+ assert aux is None
+ assert output.shape == x.shape
+ assert output.dtype == x.dtype
+ out_local = np.asarray(jax.device_get(output.addressable_data(0)))
+ assert np.all(np.isfinite(out_local))
+
+ params_np = _params_global_numpy(variables, mesh)
+ x_np = np.asarray(jax.device_get(x))
+ out_te_np = _to_global_numpy(output, mesh)
+ out_ref, _ = _pure_jax_moe_reference(
+ jnp.asarray(x_np),
+ jnp.asarray(params_np["gate_kernel"]),
+ jnp.asarray(params_np["wi_0"]),
+ jnp.asarray(params_np["wi_1"]),
+ jnp.asarray(params_np["wo"]),
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ )
+ np.testing.assert_allclose(
+ out_te_np.astype(np.float32),
+ np.asarray(jax.device_get(out_ref)).astype(np.float32),
+ atol=FWD_ATOL,
+ rtol=FWD_RTOL,
+ )
+
+ grads = _grad_step(block, variables, mesh, x)
+ for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
+ g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)))
+ assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf"
+ assert np.any(g_local != 0.0), f"{name} grad zero"
+
+
+# Keep the bootstrap-signature test last in the module (the "ZZZ" prefix
+# ensures pytest's alphabetic class ordering picks it last): it
+# intentionally mismatches the NCCL EP bootstrap signature, which
+# permanently taints the per-process bootstrap cache for the rest of
+# the file.
+class TestZZZTeEpMoeBootstrap:
+ """Per-process NCCL bootstrap re-bootstrap rejection."""
+
+ def test_bootstrap_signature_mismatch_raises(self, mesh):
+ block_a = _make_block()
+ x_a = _make_inputs(jax.random.PRNGKey(14))
+ _init_apply(block_a, mesh, x_a, jax.random.PRNGKey(15))
+
+ # Different hidden dim → different bootstrap signature.
+ bigger_hidden = HIDDEN * 2
+ x_b = jax.random.normal(jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE)
+ block_b = MoEBlock(
+ num_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOPK,
+ intermediate_size=INTER,
+ data_parallelism_axes=(FSDP_AXIS,),
+ dtype=DTYPE,
+ )
+ with pytest.raises(ValueError, match="bootstrapped"):
+ _init_apply(block_b, mesh, x_b, jax.random.PRNGKey(17))
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 8f96432ed8..18c4af7b09 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -437,6 +437,96 @@ if (NVTE_WITH_CUSOLVERMP)
message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}")
endif()
+# ── NCCL EP (on by default, HT mode only) ─────────────────────────────────
+# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to
+# skip NCCL EP entirely — useful on older images whose system NCCL is below
+# the 2.30.4 EP minimum.
+option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON)
+if(NVTE_WITH_NCCL_EP)
+# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize.
+# ── NCCL EP headers ────────────────────────────────────────────────────────
+# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build
+# (auto-built by setup.py via build_nccl_ep_submodule).
+set(NCCL_EP_SUBMODULE_ROOT
+ "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl")
+set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include")
+if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h")
+ message(FATAL_ERROR
+ "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. "
+ "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.")
+endif()
+message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}")
+
+# ── libnccl_ep.so ──────────────────────────────────────────────────────────
+set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib")
+find_library(NCCL_EP_LIB
+ NAMES nccl_ep libnccl_ep
+ HINTS ${NCCL_EP_LIB_DIR}
+ NO_DEFAULT_PATH
+ REQUIRED)
+
+# ── NCCL + GIN headers ─────────────────────────────────────────────────────
+# libnccl.so and all GIN headers (ncclGin.h, ncclWindow_t, ncclDevComm_t)
+# ship with the base CUDA Toolkit OR the 3rdparty/nccl submodule build
+# (preferred when present; auto-built by setup.py via build_nccl_ep_submodule).
+if(NOT NCCL_LIB)
+ find_library(NCCL_LIB
+ NAMES nccl libnccl
+ HINTS ${NCCL_EP_LIB_DIR} ${CUDAToolkit_LIBRARY_DIR}
+ PATH_SUFFIXES lib lib64
+ REQUIRED)
+endif()
+
+set(NCCL_SUBMODULE_INCLUDE
+ "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include")
+if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h")
+ set(NCCL_INCLUDE_DIRS_FOR_TE ${NCCL_SUBMODULE_INCLUDE})
+else()
+ set(NCCL_INCLUDE_DIRS_FOR_TE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+endif()
+
+# Diagnostic: log detected NCCL header version (minimum enforced at runtime).
+find_file(_nvte_nccl_header_path nccl.h
+ PATHS ${NCCL_INCLUDE_DIRS_FOR_TE} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+ NO_DEFAULT_PATH)
+if(_nvte_nccl_header_path)
+ file(READ "${_nvte_nccl_header_path}" _nvte_nccl_h)
+ string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}")
+ set(_nvte_nccl_major "${CMAKE_MATCH_1}")
+ string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}")
+ set(_nvte_nccl_minor "${CMAKE_MATCH_1}")
+ string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}")
+ set(_nvte_nccl_patch "${CMAKE_MATCH_1}")
+ if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch)
+ message(STATUS "NCCL header: ${_nvte_nccl_header_path} (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})")
+ endif()
+endif()
+
+target_include_directories(transformer_engine PRIVATE
+ ${NCCL_EP_INCLUDE_DIR}
+ ${NCCL_INCLUDE_DIRS_FOR_TE}) # covers nccl.h + nccl_device/
+
+target_link_libraries(transformer_engine PUBLIC
+ ${NCCL_EP_LIB}
+ ${NCCL_LIB})
+
+# Embed rpath so the installed wheel finds libnccl_ep.so at runtime.
+# libnccl.so is already on the system via the Toolkit — no rpath needed for it.
+set_target_properties(transformer_engine PROPERTIES
+ INSTALL_RPATH "$ORIGIN;${NCCL_EP_LIB_DIR}")
+
+target_sources(transformer_engine PRIVATE
+ ep/ep_backend.cpp
+ ep/ep_api.cpp)
+
+message(STATUS "NCCL EP enabled: ${NCCL_EP_LIB}")
+message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}")
+else()
+ # NCCL EP off: export throwing nvte_ep_* stubs so framework bindings link.
+ target_sources(transformer_engine PRIVATE ep/ep_api_stub.cpp)
+ message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) — using nvte_ep_* stubs")
+endif()
+
# Number of philox4x32 rounds for stochastic rounding (build-time constant).
set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS})
if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR)
diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp
new file mode 100644
index 0000000000..89d8b38607
--- /dev/null
+++ b/transformer_engine/common/ep/ep_api.cpp
@@ -0,0 +1,76 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file ep_api.cpp
+ * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton.
+ */
+
+#include
+#include
+
+#include "../common.h"
+#include "../util/logging.h"
+#include "ep_backend.h"
+
+using transformer_engine::ep::EPBackend;
+
+void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) {
+ NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null");
+ EPBackend::initialize(static_cast(ep_comm), group_config);
+}
+
+void nvte_ep_shutdown(void) { EPBackend::shutdown(); }
+
+uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) {
+ NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null");
+ return EPBackend::get().register_layer(layer_config, handle_mem_size);
+}
+
+void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts,
+ size_t dispatch_output_per_expert_alignment, cudaStream_t stream) {
+ void* mem_ptr = nvte_tensor_data(handle.mem);
+ NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null");
+ EPBackend::get().prepare(handle.id, topk_idx, token_counts, mem_ptr,
+ dispatch_output_per_expert_alignment, stream);
+}
+
+void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens,
+ NVTECommWindow tokens_win, NVTETensor topk_weights,
+ NVTECommWindow topk_weights_win, NVTETensor recv_tokens,
+ NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights,
+ NVTECommWindow recv_topk_weights_win, cudaStream_t stream) {
+ void* mem_ptr = nvte_tensor_data(handle.mem);
+ NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null");
+ EPBackend::get().dispatch(handle.id, mem_ptr, topk_idx, tokens, tokens_win, topk_weights,
+ topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights,
+ recv_topk_weights_win, stream);
+}
+
+void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win,
+ NVTETensor result, cudaStream_t stream) {
+ void* mem_ptr = nvte_tensor_data(handle.mem);
+ NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null");
+ EPBackend::get().combine(handle.id, mem_ptr, expert_out, expert_out_win, result, stream);
+}
+
+void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
+ NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win,
+ NVTETensor grad_tokens, NVTETensor grad_topk_weights,
+ cudaStream_t stream) {
+ void* mem_ptr = nvte_tensor_data(handle.mem);
+ NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null");
+ EPBackend::get().dispatch_bwd(handle.id, mem_ptr, grad, grad_win, g_recv_topk_weights,
+ g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream);
+}
+
+void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
+ NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win,
+ cudaStream_t stream) {
+ void* mem_ptr = nvte_tensor_data(handle.mem);
+ NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null");
+ EPBackend::get().combine_bwd(handle.id, mem_ptr, grad, grad_win, grad_expert_out,
+ grad_expert_out_win, stream);
+}
diff --git a/transformer_engine/common/ep/ep_api_stub.cpp b/transformer_engine/common/ep/ep_api_stub.cpp
new file mode 100644
index 0000000000..fe4127d87d
--- /dev/null
+++ b/transformer_engine/common/ep/ep_api_stub.cpp
@@ -0,0 +1,61 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file ep_api_stub.cpp
+ * \brief Throwing nvte_ep_* stubs compiled when NVTE_WITH_NCCL_EP=OFF.
+ */
+
+#include
+
+#include "../util/logging.h"
+
+namespace {
+[[noreturn]] void ep_not_built() {
+ NVTE_ERROR(
+ "NCCL EP is not built into this TransformerEngine. Rebuild TE with "
+ "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\").");
+}
+} // namespace
+
+void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); }
+
+void nvte_ep_shutdown(void) {}
+
+uint64_t nvte_ep_register_layer(NVTEEpLayerConfig /*layer_config*/, size_t* /*handle_mem_size*/) {
+ ep_not_built();
+}
+
+void nvte_ep_prepare(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*token_counts*/,
+ size_t /*dispatch_output_per_expert_alignment*/, cudaStream_t /*stream*/) {
+ ep_not_built();
+}
+
+void nvte_ep_dispatch(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/,
+ NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/,
+ NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/,
+ NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/,
+ NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) {
+ ep_not_built();
+}
+
+void nvte_ep_combine(NVTEEpHandle /*handle*/, NVTETensor /*expert_out*/,
+ NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/,
+ cudaStream_t /*stream*/) {
+ ep_not_built();
+}
+
+void nvte_ep_dispatch_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/,
+ NVTETensor /*g_recv_topk_weights*/,
+ NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/,
+ NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) {
+ ep_not_built();
+}
+
+void nvte_ep_combine_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/,
+ NVTETensor /*grad_expert_out*/, NVTECommWindow /*grad_expert_out_win*/,
+ cudaStream_t /*stream*/) {
+ ep_not_built();
+}
diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp
new file mode 100644
index 0000000000..a5ae99b089
--- /dev/null
+++ b/transformer_engine/common/ep/ep_backend.cpp
@@ -0,0 +1,513 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file ep_backend.cpp
+ * \brief EPBackend implementation. See ep_backend.h for the op flow.
+ */
+
+#include "ep_backend.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../common.h"
+#include "../util/cuda_runtime.h"
+#include "../util/logging.h"
+
+namespace transformer_engine {
+namespace ep {
+
+namespace {
+
+// Build a by-value ncclEpTensor_t descriptor. `sizes` is caller-owned and must
+// outlive any NCCL EP call that consumes the descriptor.
+inline ncclEpTensor_t make_tensor(void* data, unsigned int ndim, ncclDataType_t datatype,
+ size_t* sizes) {
+ ncclEpTensor_t t = NCCL_EP_TENSOR_INIT;
+ t.ndim = ndim;
+ t.datatype = datatype;
+ t.data = data;
+ t.sizes = sizes;
+ return t;
+}
+
+// Payload descriptor: prefer the symmem window when set, else fall back to the
+// NVTETensor's raw device pointer.
+inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWindow& win,
+ unsigned int ndim, ncclDataType_t datatype,
+ size_t* sizes) {
+ ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT;
+ desc.ndim = ndim;
+ desc.datatype = datatype;
+ desc.sizes = sizes;
+ if (win.window != nullptr) {
+ desc.win_hdl = win.window;
+ desc.win_offset = win.offset;
+ } else {
+ desc.data = nvte_tensor_data(t);
+ NVTE_CHECK(desc.data != nullptr, "payload tensor data must not be null");
+ }
+ return desc;
+}
+
+} // namespace
+
+// ---------------------------------------------------------------------------
+// Singleton + bootstrap
+// ---------------------------------------------------------------------------
+
+EPBackend& EPBackend::instance() {
+ static EPBackend inst;
+ return inst;
+}
+
+EPBackend& EPBackend::get() {
+ EPBackend& inst = instance();
+ NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first.");
+ return inst;
+}
+
+void EPBackend::validate_config(const NVTEEpGroupConfig& config) {
+ NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size);
+ NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts);
+ NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ",
+ config.max_tokens_per_rank);
+ NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ",
+ config.max_recv_tokens_per_rank);
+ NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim);
+ NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes,
+ "max_token_dtype out of range, got ", static_cast(config.max_token_dtype));
+ const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype));
+ NVTE_CHECK(config.hidden_dim * elem_bytes >= 16,
+ "hidden_dim * sizeof(max_token_dtype) must be >= 16 (NCCL EP 16B row alignment); "
+ "got hidden_dim=",
+ config.hidden_dim, ", element_bytes=", elem_bytes);
+ NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts,
+ ") must be divisible by ep_size (", config.ep_size, ")");
+ NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
+ config.max_num_sms);
+
+ int device, major;
+ NVTE_CHECK_CUDA(cudaGetDevice(&device));
+ NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+ NVTE_CHECK(major >= 9,
+ "NCCL EP requires SM_90+ (Hopper or later), "
+ "but current device has compute capability ",
+ major, ".x");
+
+ // NCCL EP needs CUDA multicast (NVLS); init hangs without it.
+ NVTE_CHECK(cuda::supports_multicast(device),
+ "NCCL EP requires CUDA multicast (NVLS) support on device ", device,
+ " but CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED reports 0.");
+}
+
+void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) {
+ EPBackend& inst = instance();
+ std::lock_guard lock(inst.mutex_);
+ NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process.");
+ NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null");
+
+ // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin).
+ constexpr int kMinNcclVersion = 23004;
+ int nccl_version = 0;
+ NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version));
+ NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ",
+ nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100,
+ " at runtime.");
+
+ validate_config(config);
+
+ int comm_size = 0;
+ NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size));
+ NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (",
+ config.ep_size, "). Pass the EP sub-communicator, not the world comm.");
+
+ inst.init(ep_comm, config);
+}
+
+void EPBackend::shutdown() {
+ EPBackend& inst = instance();
+ std::lock_guard lock(inst.mutex_);
+ if (!inst.initialized_) return;
+ for (auto& kv : inst.handles_) {
+ if (kv.second.cached_handle != nullptr) {
+ ncclEpHandleDestroy(kv.second.cached_handle);
+ kv.second.cached_handle = nullptr;
+ kv.second.cached_handle_mem = nullptr;
+ }
+ }
+ inst.handles_.clear();
+ // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive.
+ if (inst.ep_group_ != nullptr) {
+ ncclEpGroupDestroy(inst.ep_group_);
+ inst.ep_group_ = nullptr;
+ }
+ inst.ep_comm_ = nullptr; // borrowed — caller destroys
+ inst.initialized_ = false;
+}
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) {
+ switch (dtype) {
+ case kNVTEFloat32:
+ return ncclFloat32;
+ case kNVTEFloat16:
+ return ncclFloat16;
+ case kNVTEBFloat16:
+ return ncclBfloat16;
+ case kNVTEInt32:
+ return ncclInt32;
+ case kNVTEInt64:
+ return ncclInt64;
+ case kNVTEByte:
+ return ncclUint8;
+ case kNVTEFloat8E4M3:
+ return ncclFloat8e4m3;
+ case kNVTEFloat8E5M2:
+ return ncclFloat8e5m2;
+ default:
+ NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast(dtype));
+ }
+ return ncclFloat32; // unreachable
+}
+
+// Open a fresh ncclEpHandle over handle_mem. Caller (or cache) owns the result.
+ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk,
+ size_t dispatch_output_per_expert_alignment) {
+ size_t hm_sizes[1] = {handle_mem_size};
+ ncclEpTensor_t routing_desc = make_tensor(handle_mem, 1, ncclUint8, hm_sizes);
+ ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT;
+ hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment;
+ ncclEpHandle_t handle;
+ NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk,
+ &routing_desc));
+ return handle;
+}
+
+// ---------------------------------------------------------------------------
+// Lifecycle
+// ---------------------------------------------------------------------------
+
+// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may
+// already be gone) and release in-memory state only.
+EPBackend::~EPBackend() {
+ std::lock_guard lock(mutex_);
+ if (!initialized_) return;
+ handles_.clear();
+ ep_group_ = nullptr;
+ ep_comm_ = nullptr;
+ initialized_ = false;
+}
+
+void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) {
+ NVTE_CHECK(!initialized_, "EPBackend already initialized");
+
+ group_config_ = group_config;
+
+ ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT;
+ cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT;
+ cfg.num_experts = static_cast(group_config.num_experts);
+ cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank);
+ const size_t elem_bytes = typeToSize(static_cast(group_config.max_token_dtype));
+ cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes);
+ cfg.rdma_buffer_size = NCCL_EP_AUTO;
+ cfg.num_qp_per_rank = NCCL_EP_AUTO;
+ cfg.num_channels = NCCL_EP_AUTO;
+ cfg.max_num_sms = group_config.max_num_sms > 0
+ ? static_cast(group_config.max_num_sms)
+ : NCCL_EP_AUTO;
+ // Must be > 0; NCCL EP errors out on 0.
+ cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank);
+
+ NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg));
+
+ ep_comm_ = ep_comm;
+
+ initialized_ = true;
+}
+
+// ---------------------------------------------------------------------------
+// Per-handle_id config cache
+// ---------------------------------------------------------------------------
+
+uint64_t EPBackend::insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment) {
+ if (handle_cache_cap_ == 0) {
+ const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE");
+ handle_cache_cap_ = (cap_env != nullptr) ? std::max(1, std::atoi(cap_env)) : 8192;
+ }
+ NVTE_CHECK(handles_.size() < handle_cache_cap_, "EP handle cache full (", handle_cache_cap_,
+ " entries). Raise via NVTE_EP_HANDLE_CACHE_SIZE.");
+ uint64_t id = next_handle_id_.fetch_add(1, std::memory_order_relaxed);
+ handles_.emplace(id, HandleEntry{handle_mem_size, alignment, top_k});
+ return id;
+}
+
+EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) {
+ auto it = handles_.find(handle_id);
+ NVTE_CHECK(it != handles_.end(), "ep op on handle_id=", handle_id,
+ " with no cached config — call ep_prepare first.");
+ return it->second;
+}
+
+ncclEpHandle_t EPBackend::get_or_open_handle(HandleEntry& cfg, void* handle_mem) {
+ if (cfg.cached_handle != nullptr && cfg.cached_handle_mem == handle_mem) {
+ return cfg.cached_handle;
+ }
+ if (cfg.cached_handle != nullptr) {
+ NVTE_CHECK(group_config_.allow_handle_mem_reloc != 0,
+ "EP handle_mem relocated for cached handle (old=",
+ reinterpret_cast(cfg.cached_handle_mem),
+ ", new=", reinterpret_cast(handle_mem),
+ "). Set NVTEEpGroupConfig.allow_handle_mem_reloc=1 to allow rebuild.");
+ ncclEpHandleDestroy(cfg.cached_handle);
+ cfg.cached_handle = nullptr;
+ cfg.cached_handle_mem = nullptr;
+ }
+ ncclEpHandle_t h = open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment);
+ cfg.cached_handle = h;
+ cfg.cached_handle_mem = handle_mem;
+ return h;
+}
+
+// ---------------------------------------------------------------------------
+// Per-step operations
+// ---------------------------------------------------------------------------
+
+uint64_t EPBackend::register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) {
+ NVTE_CHECK(initialized_, "EPBackend not initialized");
+ NVTE_CHECK(layer_config.top_k > 0, "NVTEEpLayerConfig.top_k must be > 0");
+ NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null");
+ ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT;
+ hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment;
+ size_t hm_size = 0;
+ NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size,
+ layer_config.top_k));
+ *handle_mem_size = hm_size;
+ std::lock_guard lock(mutex_);
+ return insert_new_entry(hm_size, layer_config.top_k,
+ layer_config.dispatch_output_per_expert_alignment);
+}
+
+void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts,
+ void* handle_mem, size_t dispatch_output_per_expert_alignment,
+ cudaStream_t stream) {
+ NVTE_CHECK(initialized_, "EPBackend not initialized");
+ NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null");
+
+ NVTEShape idx_shape = nvte_tensor_shape(topk_idx);
+ void* idx_data = nvte_tensor_data(topk_idx);
+ NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null");
+
+ const size_t num_tokens = idx_shape.data[0];
+ const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1;
+ const size_t num_local_experts =
+ static_cast(group_config_.num_experts / group_config_.ep_size);
+
+ size_t idx_sizes[2] = {num_tokens, top_k};
+ ncclEpTensor_t nccl_topk_idx = make_tensor(idx_data, 2, ncclInt64, idx_sizes);
+
+ // ncclEpUpdateHandle writes per-expert counts via expert_counters.
+ size_t cnt_sizes[1] = {num_local_experts};
+ ncclEpTensor_t token_counts_desc;
+ void* token_counts_data = (token_counts != nullptr) ? nvte_tensor_data(token_counts) : nullptr;
+ if (token_counts_data != nullptr) {
+ token_counts_desc = make_tensor(token_counts_data, 1, ncclInt32, cnt_sizes);
+ }
+ ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT;
+ layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr;
+
+ std::lock_guard lock(mutex_);
+ HandleEntry& cfg = lookup_config(handle_id);
+ NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment,
+ "ep_prepare: alignment mismatch for handle_id=", handle_id, " (cached=", cfg.alignment,
+ ", got=", dispatch_output_per_expert_alignment, ")");
+ ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem);
+ NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream));
+}
+
+void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx,
+ const NVTETensor tokens, const NVTECommWindow& tokens_win,
+ const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win,
+ NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win,
+ NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win,
+ cudaStream_t stream) {
+ NVTE_CHECK(initialized_, "EPBackend not initialized");
+ NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null");
+
+ NVTEShape tok_shape = nvte_tensor_shape(tokens);
+ NVTEDType tok_dtype = nvte_tensor_type(tokens);
+ NVTE_CHECK(typeToSize(static_cast(tok_dtype)) <=
+ typeToSize(static_cast(group_config_.max_token_dtype)),
+ "tokens dtype (", static_cast(tok_dtype), ") wider than group max_token_dtype (",
+ static_cast(group_config_.max_token_dtype), ")");
+
+ const size_t num_tokens = tok_shape.data[0];
+ const size_t hidden_dim = tok_shape.data[1];
+
+ size_t tok_sizes[2] = {num_tokens, hidden_dim};
+ ncclEpTensor_t nccl_tokens_in =
+ make_payload_tensor(tokens, tokens_win, 2, nvte_dtype_to_nccl(tok_dtype), tok_sizes);
+
+ const bool is_forward = (topk_weights != nullptr);
+
+ // Routing is cached in handle_mem by ep_prepare; dispatch only needs
+ // topk_weights to reconstruct the sparse-to-dense prob map.
+ size_t weights_in_sizes[2] = {0, 0};
+ ncclEpTensor_t nccl_topk_weights_in;
+ if (is_forward) {
+ NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch");
+ NVTEShape idx_shape = nvte_tensor_shape(topk_idx);
+ const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1;
+ weights_in_sizes[0] = num_tokens;
+ weights_in_sizes[1] = top_k;
+ nccl_topk_weights_in =
+ make_payload_tensor(topk_weights, topk_weights_win, 2, ncclFloat32, weights_in_sizes);
+ }
+
+ NVTEShape recv_shape = nvte_tensor_shape(recv_tokens);
+ NVTEDType recv_dtype = nvte_tensor_type(recv_tokens);
+ NVTE_CHECK(typeToSize(static_cast(recv_dtype)) <=
+ typeToSize(static_cast(group_config_.max_token_dtype)),
+ "recv_tokens dtype (", static_cast(recv_dtype),
+ ") wider than group max_token_dtype (",
+ static_cast(group_config_.max_token_dtype), ")");
+
+ size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]};
+ ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2,
+ nvte_dtype_to_nccl(recv_dtype), recv_sizes);
+
+ size_t weights_out_sizes[1] = {recv_shape.data[0]};
+ ncclEpTensor_t nccl_topk_weights_out;
+ if (is_forward) {
+ NVTE_CHECK(recv_topk_weights != nullptr,
+ "recv_topk_weights must not be null in forward dispatch");
+ NVTEShape recv_w_shape = nvte_tensor_shape(recv_topk_weights);
+ NVTE_CHECK(recv_w_shape.ndim == 1, "recv_topk_weights must be 1D [recv_capacity]");
+ nccl_topk_weights_out = make_payload_tensor(recv_topk_weights, recv_topk_weights_win, 1,
+ ncclFloat32, weights_out_sizes);
+ }
+
+ ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT;
+ in_struct.tokens = &nccl_tokens_in;
+ in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr;
+
+ ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT;
+ out_struct.tokens = &nccl_tokens_out;
+ out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr;
+
+ ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT;
+ dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS;
+
+ std::lock_guard lock(mutex_);
+ HandleEntry& cfg = lookup_config(handle_id);
+ ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem);
+ NVTE_CHECK_NCCL(ncclEpDispatch(h, &in_struct, &out_struct,
+ /*layout_info=*/nullptr, &dispatch_cfg, stream));
+}
+
+void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out,
+ const NVTECommWindow& expert_out_win, NVTETensor result,
+ cudaStream_t stream) {
+ NVTE_CHECK(initialized_, "EPBackend not initialized");
+ NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null");
+
+ NVTEShape exp_shape = nvte_tensor_shape(expert_out);
+ NVTEDType exp_dtype = nvte_tensor_type(expert_out);
+
+ size_t exp_sizes[2] = {exp_shape.data[0], exp_shape.data[1]};
+ ncclEpTensor_t nccl_expert_in =
+ make_payload_tensor(expert_out, expert_out_win, 2, nvte_dtype_to_nccl(exp_dtype), exp_sizes);
+
+ NVTEShape res_shape = nvte_tensor_shape(result);
+ void* res_data = nvte_tensor_data(result);
+ NVTEDType res_dtype = nvte_tensor_type(result);
+ NVTE_CHECK(res_data != nullptr, "result data must not be null");
+
+ size_t res_sizes[2] = {res_shape.data[0], res_shape.data[1]};
+ ncclEpTensor_t nccl_result_out =
+ make_tensor(res_data, 2, nvte_dtype_to_nccl(res_dtype), res_sizes);
+
+ ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT;
+ in_struct.tokens = &nccl_expert_in;
+
+ ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT;
+ out_struct.tokens = &nccl_result_out;
+
+ std::lock_guard lock(mutex_);
+ HandleEntry& cfg = lookup_config(handle_id);
+ ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem);
+ NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, /*config=*/nullptr, stream));
+}
+
+void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad,
+ const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights,
+ const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens,
+ NVTETensor grad_topk_weights, cudaStream_t stream) {
+ NVTE_CHECK(initialized_, "EPBackend not initialized");
+ NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null");
+
+ NVTEShape g_shape = nvte_tensor_shape(grad);
+ NVTEDType g_dtype = nvte_tensor_type(grad);
+ size_t g_sizes[2] = {g_shape.data[0], g_shape.data[1]};
+ ncclEpTensor_t nccl_tok_in =
+ make_payload_tensor(grad, grad_win, 2, nvte_dtype_to_nccl(g_dtype), g_sizes);
+
+ // g_recv_topk_weights must be 1D [recv_capacity] — caller flattens.
+ NVTEShape gw_shape = nvte_tensor_shape(g_recv_topk_weights);
+ NVTE_CHECK(gw_shape.ndim == 1,
+ "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims");
+ size_t gw_sizes[1] = {gw_shape.data[0]};
+ ncclEpTensor_t nccl_w_in =
+ make_payload_tensor(g_recv_topk_weights, g_recv_topk_weights_win, 1, ncclFloat32, gw_sizes);
+
+ NVTEShape gt_shape = nvte_tensor_shape(grad_tokens);
+ void* gt_data = nvte_tensor_data(grad_tokens);
+ NVTE_CHECK(gt_data != nullptr, "grad_tokens data must not be null");
+ size_t gt_sizes[2] = {gt_shape.data[0], gt_shape.data[1]};
+ ncclEpTensor_t nccl_tok_out = make_tensor(gt_data, 2, nvte_dtype_to_nccl(g_dtype), gt_sizes);
+
+ NVTEShape gtw_shape = nvte_tensor_shape(grad_topk_weights);
+ void* gtw_data = nvte_tensor_data(grad_topk_weights);
+ NVTE_CHECK(gtw_data != nullptr, "grad_topk_weights data must not be null");
+ NVTE_CHECK(gtw_shape.ndim == 2, "grad_topk_weights must be 2D [T, top_k]");
+ size_t gtw_sizes[2] = {gtw_shape.data[0], gtw_shape.data[1]};
+ ncclEpTensor_t nccl_w_out = make_tensor(gtw_data, 2, ncclFloat32, gtw_sizes);
+
+ ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT;
+ in_struct.tokens = &nccl_tok_in;
+ in_struct.topk_weights = &nccl_w_in;
+
+ ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT;
+ out_struct.tokens = &nccl_tok_out;
+ out_struct.topk_weights = &nccl_w_out;
+
+ ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT;
+ cfg.pass_direction = NCCL_EP_BWD_PASS;
+
+ std::lock_guard lock(mutex_);
+ HandleEntry& entry = lookup_config(handle_id);
+ ncclEpHandle_t h = get_or_open_handle(entry, handle_mem);
+ NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, &cfg, stream));
+}
+
+void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad,
+ const NVTECommWindow& grad_win, NVTETensor grad_expert_out,
+ const NVTECommWindow& grad_expert_out_win, cudaStream_t stream) {
+ // Backward of combine = reverse-direction dispatch.
+ dispatch(handle_id, handle_mem, /*topk_idx=*/nullptr, grad, grad_win, /*topk_weights=*/nullptr,
+ /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, grad_expert_out_win,
+ /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream);
+}
+
+} // namespace ep
+} // namespace transformer_engine
diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h
new file mode 100644
index 0000000000..e82c974c3f
--- /dev/null
+++ b/transformer_engine/common/ep/ep_backend.h
@@ -0,0 +1,122 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file ep_backend.h
+ * \brief Internal NCCL EP singleton; not part of the public API.
+ *
+ * Per handle_id the cache stores config only (no device pointers), so
+ * handle_mem may be relocated between ops. Cap: NVTE_EP_HANDLE_CACHE_SIZE
+ * (default 8192); overflow throws.
+ */
+
+#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_
+#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace transformer_engine {
+namespace ep {
+
+/*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */
+class EPBackend {
+ public:
+ /*! \brief Access the singleton. Aborts if not initialized. */
+ static EPBackend& get();
+
+ /*! \brief Bootstrap from an existing EP sub-communicator.
+ * ep_comm is borrowed; the caller keeps it alive until shutdown() returns
+ * and must span exactly config.ep_size ranks.
+ */
+ static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config);
+
+ /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */
+ static void shutdown();
+
+ // Host-only: reserve a fresh handle_id, cache the layer config, and report
+ // the handle_mem buffer size the caller must allocate.
+ uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size);
+
+ void prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts,
+ void* handle_mem, size_t dispatch_output_per_expert_alignment, cudaStream_t stream);
+
+ void dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx,
+ const NVTETensor tokens, const NVTECommWindow& tokens_win,
+ const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win,
+ NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win,
+ NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win,
+ cudaStream_t stream);
+
+ void combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out,
+ const NVTECommWindow& expert_out_win, NVTETensor result, cudaStream_t stream);
+
+ // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32.
+ void dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad,
+ const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights,
+ const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens,
+ NVTETensor grad_topk_weights, cudaStream_t stream);
+
+ void combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad,
+ const NVTECommWindow& grad_win, NVTETensor grad_expert_out,
+ const NVTECommWindow& grad_expert_out_win, cudaStream_t stream);
+
+ private:
+ EPBackend() = default;
+ ~EPBackend();
+ EPBackend(const EPBackend&) = delete;
+ EPBackend& operator=(const EPBackend&) = delete;
+
+ // ep_comm is borrowed — caller retains ownership across the backend lifetime.
+ void init(ncclComm_t ep_comm, NVTEEpGroupConfig config);
+
+ static EPBackend& instance(); // Meyers singleton accessor
+ static void validate_config(const NVTEEpGroupConfig& config);
+
+ static ncclDataType_t nvte_dtype_to_nccl(NVTEDType dtype);
+ // Open a transient ncclEpHandle over handle_mem. num_topk=-1 for paths
+ // that don't carry per-token weights.
+ ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk,
+ size_t dispatch_output_per_expert_alignment);
+
+ ncclEpGroup_t ep_group_{nullptr};
+ ncclComm_t ep_comm_{nullptr};
+ NVTEEpGroupConfig group_config_{};
+ bool initialized_{false};
+ std::mutex mutex_;
+ struct HandleEntry {
+ size_t handle_mem_size;
+ size_t alignment;
+ int top_k;
+ // Persistent ncclEpHandle bound to cached_handle_mem. Lazily opened on first
+ // op; reused while handle_mem ptr is unchanged. Destroyed in shutdown().
+ ncclEpHandle_t cached_handle{nullptr};
+ void* cached_handle_mem{nullptr};
+ };
+ std::unordered_map handles_;
+ std::atomic next_handle_id_{1}; // 0 reserved as "no id"
+ size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE
+
+ // Caller must hold mutex_. Throws on cap overflow.
+ uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment);
+ HandleEntry& lookup_config(uint64_t handle_id);
+ // Caller must hold mutex_. Returns the cached handle if handle_mem matches.
+ // On mismatch: if group_config_.allow_handle_mem_reloc != 0, destroys the
+ // stale handle and opens a fresh one; otherwise throws.
+ ncclEpHandle_t get_or_open_handle(HandleEntry& cfg, void* handle_mem);
+};
+
+} // namespace ep
+} // namespace transformer_engine
+
+#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_
diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h
new file mode 100644
index 0000000000..088ea7f0c3
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/comm_window.h
@@ -0,0 +1,32 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file comm_window.h
+ * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops.
+ * Pass ``{NULL, 0}`` to use the raw-pointer path.
+ */
+
+#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_
+#define TRANSFORMER_ENGINE_COMM_WINDOW_H_
+
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */
+typedef struct {
+ ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */
+ uint64_t offset; /*!< Byte offset of the payload within ``window``. */
+} NVTECommWindow;
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_
diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h
new file mode 100644
index 0000000000..22e7ec48ac
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/ep.h
@@ -0,0 +1,177 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file ep.h
+ * \brief Public C API for Expert Parallelism. Per-step ops are allocation-free
+ * and CUDA graph-capturable.
+ */
+
+#ifndef TRANSFORMER_ENGINE_EP_H_
+#define TRANSFORMER_ENGINE_EP_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* ── Config structs ─────────────────────────────────────────────────────── */
+/* TODO: add a struct_size/version field to these configs (and align with other
+ * TE public structs) once a TE-wide convention for ABI versioning lands. */
+
+/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */
+typedef struct {
+ int ep_size; /*!< EP world size. */
+ int num_experts; /*!< Total experts across all ranks. */
+ int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */
+ /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */
+ int max_recv_tokens_per_rank;
+ int hidden_dim; /*!< Token hidden dimension. */
+ int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */
+ /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */
+ int allow_handle_mem_reloc;
+ /*! Widest token dtype the group will dispatch. Sizes NCCL EP staging buffers
+ * at group create. Tensors passed to nvte_ep_dispatch may use any dtype whose
+ * element size is <= sizeof(max_token_dtype). */
+ NVTEDType max_token_dtype;
+} NVTEEpGroupConfig;
+
+/*! \brief Per-layer EP configuration. */
+typedef struct {
+ int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */
+ int top_k; /*!< Per-token expert fan-out. Required. */
+ size_t dispatch_output_per_expert_alignment;
+ /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match
+ * between nvte_ep_register_layer and nvte_ep_prepare. */
+} NVTEEpLayerConfig;
+
+/* ── Bootstrap ──────────────────────────────────────────────────────────── */
+
+/*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90.
+ *
+ * ep_comm is borrowed and must span exactly group_config.ep_size ranks.
+ * The caller retains ownership and must keep ep_comm alive until
+ * nvte_ep_shutdown() returns; destroying it earlier is undefined behavior.
+ * Re-init after shutdown is allowed; double-init throws.
+ *
+ * One EP group per process, bound to the current CUDA device at initialize
+ * time. Multiple GPUs per process are not supported.
+ *
+ * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group.
+ * \param[in] group_config Group-level EP configuration.
+ */
+void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config);
+
+/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */
+void nvte_ep_shutdown(void);
+
+/* ── Layer registration (host-only, eager) ───────────────────────────────── */
+
+/*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer
+ * size the caller must allocate. Host-only.
+ *
+ * Registration is intended to be static (once per layer at model init). There is
+ * no per-layer unregister API; all registrations are released by nvte_ep_shutdown.
+ * Re-registering the same layer config each step is not supported and will
+ * eventually exhaust the handle cache (NVTE_EP_HANDLE_CACHE_SIZE, default 8192).
+ *
+ * \param[in] layer_config Per-layer EP configuration.
+ * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem.
+ * \return uint64_t handle_id (non-zero).
+ */
+uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size);
+
+/*! \brief Per-step handle: the registered handle_id paired with its handle_mem buffer. */
+typedef struct {
+ uint64_t id; /*!< Handle id from nvte_ep_register_layer. */
+ NVTETensor mem; /*!< Caller-allocated handle_mem buffer (size from nvte_ep_register_layer). */
+} NVTEEpHandle;
+
+/* ── Per-step ops (all allocation-free, CUDA graph-capturable) ──────────── */
+
+/*! \brief AllGather the routing map; write per-expert counts and cache routing
+ * metadata in handle.mem for the subsequent dispatch/combine.
+ *
+ * \param[in] handle EP handle (id + mem buffer).
+ * \param[in] topk_idx [T, top_k] int64 routing indices.
+ * \param[out] token_counts [num_local_experts] int32 counts.
+ * \param[in] dispatch_output_per_expert_alignment Must match the handle_mem sizing.
+ * \param[in] stream CUDA stream.
+ */
+void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts,
+ size_t dispatch_output_per_expert_alignment, cudaStream_t stream);
+
+/*! \brief Dispatch tokens (and routing weights) to expert ranks.
+ *
+ * \param[in] handle EP handle (id + mem buffer).
+ * \param[in] topk_idx [T, top_k] int64 sparse routing indices.
+ * \param[in] tokens [T, hidden_dim] input tokens.
+ * \param[in] tokens_win Optional symmem window for ``tokens``.
+ * \param[in] topk_weights [T, top_k] float32 weights, or null in backward.
+ * \param[in] topk_weights_win Optional symmem window for ``topk_weights``.
+ * \param[out] recv_tokens [recv_T, hidden_dim] received tokens.
+ * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``.
+ * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward.
+ * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``.
+ * \param[in] stream CUDA stream.
+ */
+void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens,
+ NVTECommWindow tokens_win, NVTETensor topk_weights,
+ NVTECommWindow topk_weights_win, NVTETensor recv_tokens,
+ NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights,
+ NVTECommWindow recv_topk_weights_win, cudaStream_t stream);
+
+/*! \brief Scatter-sum expert outputs back to originating ranks. Unweighted —
+ * caller must pre-multiply expert_out by recv_topk_weights (and the
+ * valid-slot mask) before calling.
+ *
+ * \param[in] handle EP handle (id + mem buffer).
+ * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs.
+ * \param[in] expert_out_win Optional symmem window for ``expert_out``.
+ * \param[out] result [T, hidden_dim] combined output.
+ * \param[in] stream CUDA stream.
+ */
+void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win,
+ NVTETensor result, cudaStream_t stream);
+
+/*! \brief Backward of dispatch — routes token and weight grads back to source.
+ *
+ * \param[in] handle EP handle (id + mem buffer).
+ * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens.
+ * \param[in] grad_win Optional symmem window for ``grad``.
+ * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights.
+ * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``.
+ * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens.
+ * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights.
+ * \param[in] stream CUDA stream.
+ */
+void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
+ NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win,
+ NVTETensor grad_tokens, NVTETensor grad_topk_weights,
+ cudaStream_t stream);
+
+/*! \brief Backward of combine. Padded slots in grad_expert_out are zeroed.
+ *
+ * \param[in] handle EP handle (id + mem buffer).
+ * \param[in] grad [T, hidden_dim] grad w.r.t. result.
+ * \param[in] grad_win Optional symmem window for ``grad``.
+ * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out.
+ * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``.
+ * \param[in] stream CUDA stream.
+ */
+void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
+ NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win,
+ cudaStream_t stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TRANSFORMER_ENGINE_EP_H_
diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h
index da8b9b377d..3308bd22e4 100644
--- a/transformer_engine/common/util/logging.h
+++ b/transformer_engine/common/util/logging.h
@@ -98,6 +98,14 @@
} \
} while (false)
+#define NVTE_CHECK_NCCL(expr) \
+ do { \
+ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
+ if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
+ NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
+ } \
+ } while (false)
+
#ifdef NVTE_WITH_CUBLASMP
#define NVTE_CHECK_CUBLASMP(expr) \
diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py
index fe1f93dc7a..604da5e1b7 100644
--- a/transformer_engine/jax/cpp_extensions/__init__.py
+++ b/transformer_engine/jax/cpp_extensions/__init__.py
@@ -10,4 +10,5 @@
from .softmax import *
from .gemm import *
from .router import *
+from .ep import *
from .topk import *
diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py
index 6eb588c849..2cdef4bfe7 100644
--- a/transformer_engine/jax/cpp_extensions/base.py
+++ b/transformer_engine/jax/cpp_extensions/base.py
@@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs):
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
+# Register EpInstanceState (no-op when TE is built without NCCL EP).
+if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"):
+ ffi.register_ffi_type(
+ "EpInstanceState",
+ {
+ "type_id": transformer_engine_jax.get_ep_instance_state_type_id(),
+ "type_info": transformer_engine_jax.get_ep_instance_state_type_info(),
+ },
+ platform="CUDA",
+ )
+
def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
"""
diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py
new file mode 100644
index 0000000000..7f8a05dbb8
--- /dev/null
+++ b/transformer_engine/jax/cpp_extensions/ep.py
@@ -0,0 +1,957 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""JAX/TE custom ops for Expert Parallelism (EP).
+
+Sharding model:
+ - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim.
+ Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else
+ ``ep_resource`` alone.
+ - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first
+ dim may be sharded, with axis in {ep, (dp, ep), dp, None}. Trailing dims
+ must be replicated. ``dp`` alone gets ``ep`` folded in locally.
+ - EpCombine output sharding comes from ``out_sharding`` or defaults to the
+ compound ``(dp, ep)`` axis on the leading dim.
+"""
+
+from dataclasses import dataclass
+
+import jax
+import jax.numpy as jnp
+from jax import dtypes, ffi
+from jax.experimental.compute_on import compute_on
+from jax.sharding import NamedSharding, PartitionSpec
+
+import transformer_engine_jax
+from .base import BasePrimitive, register_primitive
+from ..sharding import global_mesh_resource
+
+__all__ = [
+ "EpConfig",
+ "EpHandle",
+ "set_ep_config",
+ "get_ep_config",
+ "get_ep_num_local_experts",
+ "ep_allocate_handle_id",
+ "ep_make_handle",
+ "ep_prepare",
+ "ep_dispatch_fwd",
+ "ep_combine_fwd",
+ "ep_dispatch_bwd",
+ "ep_combine_bwd",
+]
+
+
+# ── Module-level EP config ──────────────────────────────────────────────────
+
+
+@dataclass(frozen=True)
+class EpConfig:
+ """Immutable Python view of the EP bootstrap config (see ep_bootstrap)."""
+
+ world_size: int
+ rank: int
+ ep_size: int
+ num_experts: int
+ num_local_experts: int
+ max_tokens_per_rank: int
+ recv_capacity_per_rank: int
+ hidden_dim: int
+
+
+_ep_config: EpConfig = None
+
+
+def set_ep_config(config: EpConfig) -> None:
+ """Cache the EP config for abstract-eval / sharding helpers. Call once."""
+ global _ep_config
+ _ep_config = config
+
+
+def get_ep_config() -> EpConfig:
+ if _ep_config is None:
+ raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?")
+ return _ep_config
+
+
+def get_ep_num_local_experts() -> int:
+ return get_ep_config().num_local_experts
+
+
+# handle_id -> handle_mem buffer size in bytes.
+_HANDLE_MEM_SIZE_BY_ID: dict = {}
+
+
+def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int:
+ """Low-level: reserve a fresh handle_id. Prefer ``ep_make_handle``."""
+ handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer(
+ int(top_k), int(dispatch_output_per_expert_alignment)
+ )
+ handle_id = int(handle_id)
+ _HANDLE_MEM_SIZE_BY_ID[handle_id] = int(handle_mem_size)
+ return handle_id
+
+
+@dataclass(frozen=True)
+class EpHandle:
+ """Per-layer EP config + routing-slot identity.
+
+ Carries static layer config and a ``handle_id`` that pins the C++ routing
+ slot across re-traces. Allocate via ``ep_make_handle``; distinct layers
+ must hold distinct handles.
+ """
+
+ handle_id: int
+ top_k: int
+ dispatch_output_per_expert_alignment: int = 0
+
+
+def ep_make_handle(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> EpHandle:
+ """Allocate a per-layer EP handle.
+
+ Call once per logical MoE layer at model init (outside ``jax.jit``), then
+ pass the same handle into every ``ep_dispatch`` / ``ep_combine`` for that
+ layer. The handle's ``handle_id`` survives re-traces, ``jax.checkpoint``
+ rematerialization, and separate inference/training compilations.
+ """
+ handle_id = ep_allocate_handle_id(top_k, dispatch_output_per_expert_alignment)
+ return EpHandle(
+ handle_id=handle_id,
+ top_k=int(top_k),
+ dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment),
+ )
+
+
+def _ep_handle_mem_size(handle_id: int) -> int:
+ """Return the handle_mem byte size for an id from ep_allocate_handle_id."""
+ try:
+ return _HANDLE_MEM_SIZE_BY_ID[int(handle_id)]
+ except KeyError as e:
+ raise RuntimeError(
+ f"handle_id={handle_id} not registered; call ep_allocate_handle_id first."
+ ) from e
+
+
+def _leading_axis_ok(spec, ep_axis, outer_axes=()):
+ # Only the first dim may carry sharding; remaining dims must be replicated.
+ # The first dim's axis must be one of:
+ # ``ep_axis`` alone,
+ # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally),
+ # a tuple ending in ``ep_axis`` with dp/fsdp axes before it.
+ # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``,
+ # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``,
+ # ``("dp", None, None)``.
+ if len(spec) < 2 or ep_axis is None:
+ return False
+ if any(ax is not None for ax in spec[1:]):
+ return False # only first dim sharded
+ leading = spec[0]
+ allowed_outers = {a for a in outer_axes if a is not None}
+ allowed = allowed_outers | {ep_axis, None}
+ elts = leading if isinstance(leading, tuple) else (leading,)
+ return all(a in allowed for a in elts)
+
+
+def _canonical_input_spec(spec, ndim):
+ """Canonical input PartitionSpec the primitive demands JAX deliver.
+
+ Sharding lives entirely on the first dim. If ``spec[0]`` already includes
+ ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded
+ into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added
+ ep axis is a local slice (the missing dim was replicated), no cross-device
+ comm.
+ """
+ gsr = global_mesh_resource()
+ ep = gsr.ep_resource
+ leading = spec[0]
+ present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else ()
+ if ep in present:
+ return PartitionSpec(*spec)
+ if leading is None:
+ new_leading = ep
+ elif isinstance(leading, tuple):
+ new_leading = (*leading, ep)
+ else:
+ new_leading = (leading, ep)
+ return PartitionSpec(new_leading, *([None] * (ndim - 1)))
+
+
+def _dispatch_input_outer_axes():
+ """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input."""
+ gsr = global_mesh_resource()
+ return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None)
+
+
+def _ep_outer_axis():
+ """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors.
+
+ When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD
+ sees each DP color's slab as distinct (rather than replicated across DP).
+ """
+ gsr = global_mesh_resource()
+ return gsr.dp_resource or gsr.fsdp_resource
+
+
+def _ep_leading_dims(is_outer):
+ """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when
+ DP is unset) globally; ``(1,)`` per shard."""
+ cfg = get_ep_config()
+ outer = _ep_outer_axis()
+ if not is_outer:
+ return (1,)
+ return (cfg.world_size,) if outer is not None else (cfg.ep_size,)
+
+
+def _ep_output_spec(*trailing):
+ """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when
+ DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``."""
+ gsr = global_mesh_resource()
+ outer = _ep_outer_axis()
+ if outer is None:
+ return PartitionSpec(gsr.ep_resource, *trailing)
+ return PartitionSpec((outer, gsr.ep_resource), *trailing)
+
+
+def _ep_spec_ok(spec, trailing_count):
+ """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` /
+ ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP)
+ on an EP-output tensor's single leading dim. JAX may collapse a size-1
+ mesh axis to ``None`` (matters for dp_size=1 like 1x4)."""
+ gsr = global_mesh_resource()
+ ep_axis = gsr.ep_resource
+ outer = _ep_outer_axis()
+ expected_len = 1 + trailing_count
+ if len(spec) != expected_len:
+ return False
+ if any(ax is not None for ax in spec[1:]):
+ return False
+ leading = spec[0]
+ if outer is None:
+ return leading == ep_axis
+ allowed = {ep_axis, outer, None}
+ elts = leading if isinstance(leading, tuple) else (leading,)
+ return all(a in allowed for a in elts)
+
+
+# ── ep_prepare ──────────────────────────────────────────────────────────────
+
+
+class EpPreparePrimitive(BasePrimitive):
+ name = "te_ep_prepare_ffi"
+ multiple_results = True
+ impl_static_args = (1, 2, 3) # handle_id, dispatch_output_per_expert_alignment, is_outer
+ inner_primitive = None
+ outer_primitive = None
+
+ @staticmethod
+ def abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer):
+ # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with
+ # no DP); False: per-shard = (1,).
+ del dispatch_output_per_expert_alignment
+ cfg = get_ep_config()
+ num_local_experts = cfg.num_local_experts
+ assert (
+ len(topk_idx_aval.shape) >= 2
+ ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}"
+ handle_mem_size = _ep_handle_mem_size(handle_id)
+ leading = _ep_leading_dims(is_outer)
+ token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32)
+ handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8)
+ # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last
+ # dim doubled to keep the int64 byte count without JAX_ENABLE_X64.
+ # TODO(phuong): drop once NCCL EP supports int32 topk_idx.
+ workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,)
+ workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32)
+ return token_counts_aval, handle_mem_aval, workspace_aval
+
+ @staticmethod
+ def outer_abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer):
+ del is_outer
+ avals = EpPreparePrimitive.abstract(
+ topk_idx_aval,
+ handle_id=handle_id,
+ dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment,
+ is_outer=True,
+ )
+ return avals[:2]
+
+ @staticmethod
+ def lowering(ctx, topk_idx, *, handle_id, dispatch_output_per_expert_alignment, is_outer):
+ del is_outer
+ return ffi.ffi_lowering(EpPreparePrimitive.name)(
+ ctx,
+ topk_idx,
+ handle_id=int(handle_id),
+ dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment,
+ )
+
+ @staticmethod
+ def impl(topk_idx, handle_id, dispatch_output_per_expert_alignment, is_outer):
+ assert EpPreparePrimitive.inner_primitive is not None
+ token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind(
+ topk_idx,
+ handle_id=handle_id,
+ dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment,
+ is_outer=is_outer,
+ )
+ return token_counts, handle_mem
+
+ @staticmethod
+ def batcher(
+ batched_args, batch_dims, *, handle_id, dispatch_output_per_expert_alignment, is_outer
+ ):
+ raise NotImplementedError("EpPreparePrimitive does not support vmap")
+
+ @staticmethod
+ def partition(
+ handle_id, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos
+ ):
+ del is_outer, result_infos
+ gsr = global_mesh_resource()
+ ep_axis = gsr.ep_resource
+ outer_axes = _dispatch_input_outer_axes()
+ idx_spec = arg_infos[0].sharding.spec
+ if not _leading_axis_ok(idx_spec, ep_axis, outer_axes):
+ raise NotImplementedError(
+ "EpPrepare: topk_idx leading dims must shard on ep_resource"
+ f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;"
+ f" got spec={idx_spec}."
+ )
+ idx_ndim = len(arg_infos[0].shape)
+ arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),)
+ tc_sharding = NamedSharding(mesh, _ep_output_spec(None))
+ hm_sharding = NamedSharding(mesh, _ep_output_spec(None))
+
+ def sharded_impl(topk_idx):
+ return EpPreparePrimitive.impl(
+ topk_idx, handle_id, dispatch_output_per_expert_alignment, False
+ )
+
+ return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings
+
+ @staticmethod
+ def shardy_sharding_rule(*args):
+ # Signature: (*static_args, mesh, value_types, result_types). Static args
+ # for this primitive are (handle_id, dispatch_alignment, is_outer).
+ value_types = args[-2]
+ topk_idx_rank = len(value_types[0].shape)
+ in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk"
+ return f"{in_axes} -> EPL nle, EPL hm"
+
+
+register_primitive(EpPreparePrimitive)
+
+
+# ── ep_dispatch ─────────────────────────────────────────────────────────────
+
+
+class EpDispatchPrimitive(BasePrimitive):
+ name = "te_ep_dispatch_ffi"
+ multiple_results = True
+ impl_static_args = (4, 5, 6, 7) # handle_id, recv_capacity_per_rank, top_k, is_outer
+ inner_primitive = None
+ outer_primitive = None
+
+ @staticmethod
+ def abstract(
+ handle_mem_aval,
+ topk_idx_aval,
+ tokens_aval,
+ topk_weights_aval,
+ *,
+ handle_id,
+ recv_capacity_per_rank,
+ top_k,
+ is_outer,
+ ):
+ # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with
+ # no DP); False: per-shard = (1,).
+ del handle_id, topk_weights_aval, top_k, handle_mem_aval
+ assert (
+ len(tokens_aval.shape) >= 2
+ ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}"
+ recv_pr = recv_capacity_per_rank
+ tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype)
+ hidden_dim = tokens_aval.shape[-1]
+ leading = _ep_leading_dims(is_outer)
+ recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype)
+ recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32)
+ # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64.
+ workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,)
+ workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32)
+ return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval)
+
+ @staticmethod
+ def outer_abstract(*args, **kwargs):
+ kwargs = dict(kwargs)
+ kwargs["is_outer"] = True
+ avals = EpDispatchPrimitive.abstract(*args, **kwargs)
+ return avals[:2]
+
+ @staticmethod
+ def lowering(
+ ctx,
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ *,
+ handle_id,
+ recv_capacity_per_rank,
+ top_k,
+ is_outer,
+ ):
+ del recv_capacity_per_rank, is_outer
+ return ffi.ffi_lowering(EpDispatchPrimitive.name)(
+ ctx,
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ handle_id=int(handle_id),
+ top_k=top_k,
+ )
+
+ @staticmethod
+ def impl(
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ handle_id,
+ recv_capacity_per_rank,
+ top_k,
+ is_outer,
+ ):
+ assert EpDispatchPrimitive.inner_primitive is not None
+ recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind(
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ handle_id=handle_id,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ top_k=top_k,
+ is_outer=is_outer,
+ )
+ return recv_tokens, recv_topk_weights
+
+ @staticmethod
+ def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, top_k, is_outer):
+ raise NotImplementedError("EpDispatchPrimitive does not support vmap")
+
+ @staticmethod
+ def partition(
+ handle_id, recv_capacity_per_rank, top_k, is_outer, mesh, arg_infos, result_infos
+ ):
+ del is_outer, result_infos
+ gsr = global_mesh_resource()
+ ep_axis = gsr.ep_resource
+ outer_axes = _dispatch_input_outer_axes()
+ tokens_spec = arg_infos[2].sharding.spec
+ if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes):
+ raise NotImplementedError(
+ "EpDispatch: tokens leading dims must shard on ep_resource"
+ f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;"
+ f" got spec={tokens_spec}."
+ )
+ idx_spec = arg_infos[1].sharding.spec
+ tw_spec = arg_infos[3].sharding.spec
+ arg_shardings = (
+ arg_infos[0].sharding,
+ NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))),
+ NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))),
+ NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))),
+ )
+ out_shardings = (
+ NamedSharding(mesh, _ep_output_spec(None, None)),
+ NamedSharding(mesh, _ep_output_spec(None)),
+ )
+
+ def sharded_impl(handle_mem, topk_idx, tokens, topk_weights):
+ return EpDispatchPrimitive.impl(
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ handle_id,
+ recv_capacity_per_rank,
+ top_k,
+ False,
+ )
+
+ return mesh, sharded_impl, out_shardings, arg_shardings
+
+ @staticmethod
+ def shardy_sharding_rule(*args):
+ # Signature: (*static_args, mesh, value_types, result_types). Static args
+ # for this primitive are (handle_id, recv_capacity_per_rank, top_k, is_outer).
+ value_types = args[-2]
+ # Inputs: handle_mem, topk_idx, tokens, topk_weights.
+ idx_rank = len(value_types[1].shape)
+ tok_rank = len(value_types[2].shape)
+ tw_rank = len(value_types[3].shape)
+ idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in"
+ tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H"
+ tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk"
+ return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr"
+
+
+register_primitive(EpDispatchPrimitive)
+
+
+# ── ep_combine ──────────────────────────────────────────────────────────────
+# `expert_out` here is the post-weight buffer; ep.ep_combine applies the
+# hadamard before calling.
+
+
+def _normalize_leading_shape(s):
+ return s if isinstance(s, tuple) else (int(s),)
+
+
+def _prod(seq):
+ p = 1
+ for x in seq:
+ p *= int(x)
+ return p
+
+
+def _resolve_out_partition_spec(out_partition_spec, num_leading):
+ """Pick the combine output PartitionSpec.
+
+ Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a
+ DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``.
+ This matches the input sharding so XLA does not need collective-permutes
+ in the bwd path.
+ """
+ if out_partition_spec is not None:
+ assert len(out_partition_spec) == num_leading + 1, (
+ f"out_partition_spec length {len(out_partition_spec)} must equal num_leading"
+ f" + 1 ({num_leading + 1})"
+ )
+ return tuple(out_partition_spec)
+ gsr = global_mesh_resource()
+ if gsr.ep_resource is None:
+ raise ValueError(
+ "ep_combine: ep_resource is not set on the active MeshResource;"
+ " pass out_sharding=... explicitly."
+ )
+ outer = gsr.dp_resource or gsr.fsdp_resource
+ leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource
+ return (leading,) + (None,) * num_leading
+
+
+def _per_shard_leading(out_leading_shape, resolved_spec, mesh):
+ """Per-shard leading shape given resolved partition spec and mesh."""
+ per_shard = list(out_leading_shape)
+ for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]):
+ if ax is None:
+ continue
+ axes = ax if isinstance(ax, tuple) else (ax,)
+ factor = 1
+ for a in axes:
+ factor *= mesh.shape[a]
+ assert (
+ per_shard[i] % factor == 0
+ ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}"
+ per_shard[i] //= factor
+ return tuple(per_shard)
+
+
+class EpCombinePrimitive(BasePrimitive):
+ name = "te_ep_combine_ffi"
+ multiple_results = False
+ impl_static_args = (2, 3, 4) # handle_id, out_leading_shape, out_partition_spec
+ inner_primitive = None
+ outer_primitive = None
+
+ @staticmethod
+ def abstract(
+ handle_mem_aval,
+ expert_out_aval,
+ *,
+ handle_id,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ del handle_id, out_partition_spec, handle_mem_aval
+ assert (
+ len(expert_out_aval.shape) == 3
+ ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}"
+ eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype)
+ hidden_dim = expert_out_aval.shape[-1]
+ out_shape = tuple(out_leading_shape) + (hidden_dim,)
+ return jax.core.ShapedArray(out_shape, eo_dtype)
+
+ @staticmethod
+ def lowering(
+ ctx,
+ handle_mem,
+ expert_out,
+ *,
+ handle_id,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ del out_partition_spec
+ return ffi.ffi_lowering(EpCombinePrimitive.name)(
+ ctx,
+ handle_mem,
+ expert_out,
+ handle_id=int(handle_id),
+ num_local_tokens=_prod(out_leading_shape),
+ )
+
+ @staticmethod
+ def impl(handle_mem, expert_out, handle_id, out_leading_shape, out_partition_spec):
+ assert EpCombinePrimitive.inner_primitive is not None
+ return EpCombinePrimitive.inner_primitive.bind(
+ handle_mem,
+ expert_out,
+ handle_id=handle_id,
+ out_leading_shape=out_leading_shape,
+ out_partition_spec=out_partition_spec,
+ )
+
+ @staticmethod
+ def batcher(batched_args, batch_dims, *, handle_id, out_leading_shape, out_partition_spec):
+ raise NotImplementedError("EpCombinePrimitive does not support vmap")
+
+ @staticmethod
+ def partition(handle_id, out_leading_shape, out_partition_spec, mesh, arg_infos, result_infos):
+ del result_infos
+ eo_spec = arg_infos[1].sharding.spec
+ if not _ep_spec_ok(eo_spec, trailing_count=2):
+ raise NotImplementedError(
+ "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource,"
+ " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)"
+ f" over [num_procs, recv_pr, H]; got spec={eo_spec}."
+ )
+ resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape))
+ per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh)
+ arg_shardings = tuple(a.sharding for a in arg_infos)
+ out_sharding = NamedSharding(mesh, PartitionSpec(*resolved))
+
+ def sharded_impl(handle_mem, expert_out):
+ return EpCombinePrimitive.impl(
+ handle_mem, expert_out, handle_id, per_shard_leading, out_partition_spec
+ )
+
+ return mesh, sharded_impl, out_sharding, arg_shardings
+
+ @staticmethod
+ def shardy_sharding_rule(*args):
+ # Signature: (*static_args, mesh, value_types, result_types). Static args:
+ # (handle_id, out_leading_shape, out_partition_spec).
+ result_types = args[-1]
+ out_rank = len(result_types[0].shape)
+ out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H"
+ return f"EPL hm, EPL recv_pr H -> {out_axes}"
+
+
+register_primitive(EpCombinePrimitive)
+
+
+# ── ep_dispatch_bwd ─────────────────────────────────────────────────────────
+
+
+class EpDispatchBwdPrimitive(BasePrimitive):
+ name = "te_ep_dispatch_bwd_ffi"
+ multiple_results = True
+ impl_static_args = (3, 4, 5, 6) # handle_id, top_k, out_leading_shape, out_partition_spec
+ inner_primitive = None
+ outer_primitive = None
+
+ @staticmethod
+ def abstract(
+ handle_mem_aval,
+ grad_aval,
+ g_recv_topk_weights_aval,
+ *,
+ handle_id,
+ top_k,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ del handle_id, g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval
+ assert (
+ len(grad_aval.shape) == 3
+ ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}"
+ g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype)
+ hidden_dim = grad_aval.shape[-1]
+ result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype)
+ grad_topk_weights_aval = jax.core.ShapedArray(
+ tuple(out_leading_shape) + (top_k,), jnp.float32
+ )
+ return result_aval, grad_topk_weights_aval
+
+ @staticmethod
+ def lowering(
+ ctx,
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ *,
+ handle_id,
+ top_k,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ del out_partition_spec
+ return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)(
+ ctx,
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ handle_id=int(handle_id),
+ num_local_tokens=_prod(out_leading_shape),
+ top_k=top_k,
+ )
+
+ @staticmethod
+ def impl(
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ handle_id,
+ top_k,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ assert EpDispatchBwdPrimitive.inner_primitive is not None
+ return EpDispatchBwdPrimitive.inner_primitive.bind(
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ handle_id=handle_id,
+ top_k=top_k,
+ out_leading_shape=out_leading_shape,
+ out_partition_spec=out_partition_spec,
+ )
+
+ @staticmethod
+ def batcher(
+ batched_args,
+ batch_dims,
+ *,
+ handle_id,
+ top_k,
+ out_leading_shape,
+ out_partition_spec,
+ ):
+ raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap")
+
+ @staticmethod
+ def partition(
+ handle_id,
+ top_k,
+ out_leading_shape,
+ out_partition_spec,
+ mesh,
+ arg_infos,
+ result_infos,
+ ):
+ del result_infos
+ g_spec = arg_infos[1].sharding.spec
+ if not _ep_spec_ok(g_spec, trailing_count=2):
+ raise NotImplementedError(
+ "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource,"
+ " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)"
+ f" over [num_procs, recv_pr, H]; got spec={g_spec}."
+ )
+ gw_spec = arg_infos[2].sharding.spec
+ if not _ep_spec_ok(gw_spec, trailing_count=1):
+ raise NotImplementedError(
+ "EpDispatchBwd: g_recv_topk_weights must be sharded as"
+ " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)"
+ f" over [num_procs, recv_pr]; got spec={gw_spec}."
+ )
+ resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape))
+ per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh)
+ arg_shardings = tuple(a.sharding for a in arg_infos)
+ out_shardings = [
+ NamedSharding(mesh, PartitionSpec(*resolved)),
+ NamedSharding(mesh, PartitionSpec(*resolved, None)),
+ ]
+
+ def sharded_impl(handle_mem, grad, g_recv_topk_weights):
+ return EpDispatchBwdPrimitive.impl(
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ handle_id,
+ top_k,
+ per_shard_leading,
+ out_partition_spec,
+ )
+
+ return mesh, sharded_impl, out_shardings, arg_shardings
+
+ @staticmethod
+ def shardy_sharding_rule(*args):
+ # Signature: (*static_args, mesh, value_types, result_types). Result rank
+ # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1.
+ result_types = args[-1]
+ out_rank = len(result_types[0].shape)
+ out_axes = " ".join(f"O{i}" for i in range(out_rank - 1))
+ return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k"
+
+
+register_primitive(EpDispatchBwdPrimitive)
+
+
+# ── ep_combine_bwd ──────────────────────────────────────────────────────────
+
+
+class EpCombineBwdPrimitive(BasePrimitive):
+ name = "te_ep_combine_bwd_ffi"
+ multiple_results = False
+ impl_static_args = (2, 3, 4) # handle_id, recv_capacity_per_rank, is_outer
+ inner_primitive = None
+ outer_primitive = None
+
+ @staticmethod
+ def abstract(handle_mem_aval, grad_aval, *, handle_id, recv_capacity_per_rank, is_outer):
+ # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with
+ # no DP); False: per-shard = (1,).
+ del handle_id, handle_mem_aval
+ assert (
+ len(grad_aval.shape) >= 2
+ ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}"
+ g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype)
+ hidden_dim = grad_aval.shape[-1]
+ leading = _ep_leading_dims(is_outer)
+ return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype)
+
+ @staticmethod
+ def outer_abstract(*args, **kwargs):
+ kwargs = dict(kwargs)
+ kwargs["is_outer"] = True
+ return EpCombineBwdPrimitive.abstract(*args, **kwargs)
+
+ @staticmethod
+ def lowering(ctx, handle_mem, grad, *, handle_id, recv_capacity_per_rank, is_outer):
+ del recv_capacity_per_rank, is_outer
+ return ffi.ffi_lowering(EpCombineBwdPrimitive.name)(
+ ctx,
+ handle_mem,
+ grad,
+ handle_id=int(handle_id),
+ )
+
+ @staticmethod
+ def impl(handle_mem, grad, handle_id, recv_capacity_per_rank, is_outer):
+ assert EpCombineBwdPrimitive.inner_primitive is not None
+ return EpCombineBwdPrimitive.inner_primitive.bind(
+ handle_mem,
+ grad,
+ handle_id=handle_id,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ is_outer=is_outer,
+ )
+
+ @staticmethod
+ def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, is_outer):
+ raise NotImplementedError("EpCombineBwdPrimitive does not support vmap")
+
+ @staticmethod
+ def partition(handle_id, recv_capacity_per_rank, is_outer, mesh, arg_infos, result_infos):
+ del is_outer, result_infos
+ arg_shardings = tuple(a.sharding for a in arg_infos)
+ out_sharding = NamedSharding(mesh, _ep_output_spec(None, None))
+
+ def sharded_impl(handle_mem, grad):
+ return EpCombineBwdPrimitive.impl(
+ handle_mem, grad, handle_id, recv_capacity_per_rank, False
+ )
+
+ return mesh, sharded_impl, out_sharding, arg_shardings
+
+ @staticmethod
+ def shardy_sharding_rule(*args):
+ # T axes are dynamic-rank based on the actual cotangent shape.
+ value_types = args[-2]
+ g_rank = len(value_types[1].shape)
+ g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H"
+ return f"EPL hm, {g_axes} -> EPL recv_pr H"
+
+
+register_primitive(EpCombineBwdPrimitive)
+
+
+# ── Public-ish helpers (used by jax/ep.py) ──────────────────────────────────
+
+
+@compute_on("gpu_stream:collective")
+def ep_prepare(topk_idx, handle):
+ """Exchange routing metadata for ``handle``; return ``(token_counts, handle_mem)``."""
+ return EpPreparePrimitive.outer_primitive.bind(
+ topk_idx,
+ handle_id=handle.handle_id,
+ dispatch_output_per_expert_alignment=handle.dispatch_output_per_expert_alignment,
+ is_outer=True,
+ )
+
+
+@compute_on("gpu_stream:collective")
+def ep_dispatch_fwd(handle, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank):
+ """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights)."""
+ top_k = int(topk_weights.shape[-1])
+ return EpDispatchPrimitive.outer_primitive.bind(
+ handle_mem,
+ topk_idx,
+ tokens,
+ topk_weights,
+ handle_id=handle.handle_id,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ top_k=top_k,
+ is_outer=True,
+ )
+
+
+@compute_on("gpu_stream:collective")
+def ep_combine_fwd(handle, handle_mem, expert_out, num_local_tokens, out_partition_spec=None):
+ """Gather expert outputs back to home ranks. expert_out is pre-weighted."""
+ out_leading = _normalize_leading_shape(num_local_tokens)
+ return EpCombinePrimitive.outer_primitive.bind(
+ handle_mem,
+ expert_out,
+ handle_id=handle.handle_id,
+ out_leading_shape=out_leading,
+ out_partition_spec=out_partition_spec,
+ )
+
+
+@compute_on("gpu_stream:collective")
+def ep_dispatch_bwd(
+ handle,
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ top_k,
+ num_local_tokens,
+ out_partition_spec=None,
+):
+ """Backward of dispatch; returns (grad_tokens, grad_topk_weights)."""
+ out_leading = _normalize_leading_shape(num_local_tokens)
+ return EpDispatchBwdPrimitive.outer_primitive.bind(
+ handle_mem,
+ grad,
+ g_recv_topk_weights,
+ handle_id=handle.handle_id,
+ top_k=int(top_k),
+ out_leading_shape=out_leading,
+ out_partition_spec=out_partition_spec,
+ )
+
+
+@compute_on("gpu_stream:collective")
+def ep_combine_bwd(handle, handle_mem, grad, recv_capacity_per_rank):
+ """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H]."""
+ return EpCombineBwdPrimitive.outer_primitive.bind(
+ handle_mem,
+ grad,
+ handle_id=handle.handle_id,
+ recv_capacity_per_rank=recv_capacity_per_rank,
+ is_outer=True,
+ )
diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py
index 3245439689..8cc94fcaaf 100644
--- a/transformer_engine/jax/cpp_extensions/router.py
+++ b/transformer_engine/jax/cpp_extensions/router.py
@@ -412,7 +412,7 @@ def partition(
arg_infos,
result_infos,
):
- del result_infos, routing_map_format
+ del result_infos
grad_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec))
arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding)
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index dcf2b193b1..a0c2dca94b 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -200,6 +200,27 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler);
+// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources.
+void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group,
+ int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank,
+ int hidden_dim, int max_num_sms, int allow_handle_mem_reloc,
+ int max_token_dtype);
+void ReleaseEpResources();
+// Register an EP layer; returns (handle_id, handle_mem_size).
+pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment);
+
+// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type.
+pybind11::capsule GetEpInstanceStateTypeIdCapsule();
+pybind11::capsule GetEpInstanceStateTypeInfoCapsule();
+
+// EP FFI handlers
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler);
+
// TopK
XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler);
pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);
diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp
new file mode 100644
index 0000000000..84f24d75bf
--- /dev/null
+++ b/transformer_engine/jax/csrc/extensions/ep.cpp
@@ -0,0 +1,543 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifdef NVTE_WITH_NCCL_EP
+
+#include "transformer_engine/ep.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "../extensions.h"
+#include "common.h"
+#include "transformer_engine/gemm.h"
+
+namespace transformer_engine {
+namespace jax {
+
+// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI.
+
+struct EpBootstrapParams {
+ std::array uid_bytes{};
+ int ep_size = 0;
+ int rank_within_group = 0;
+ int num_experts = 0;
+ int max_tokens_per_rank = 0;
+ int max_recv_tokens_per_rank = 0;
+ int hidden_dim = 0;
+ int max_num_sms = 0;
+ int allow_handle_mem_reloc = 0;
+ int max_token_dtype = 0;
+};
+
+class EpResources {
+ public:
+ explicit EpResources(const EpBootstrapParams& p) {
+ ncclUniqueId uid;
+ std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid));
+ NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group));
+ NVTEEpGroupConfig cfg{.ep_size = p.ep_size,
+ .num_experts = p.num_experts,
+ .max_tokens_per_rank = p.max_tokens_per_rank,
+ .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank,
+ .hidden_dim = p.hidden_dim,
+ .max_num_sms = p.max_num_sms,
+ .allow_handle_mem_reloc = p.allow_handle_mem_reloc,
+ .max_token_dtype = static_cast(p.max_token_dtype)};
+ try {
+ nvte_ep_initialize(static_cast(comm_), cfg);
+ } catch (...) {
+ ncclCommDestroy(comm_);
+ comm_ = nullptr;
+ throw;
+ }
+ }
+
+ ~EpResources() {
+ if (comm_ == nullptr) return;
+ nvte_ep_shutdown();
+ ncclCommDestroy(comm_);
+ }
+
+ EpResources(const EpResources&) = delete;
+ EpResources& operator=(const EpResources&) = delete;
+
+ ncclComm_t comm() const { return comm_; }
+
+ private:
+ ncclComm_t comm_{nullptr};
+};
+
+struct EpInstanceState {
+ static ::xla::ffi::TypeId id;
+ static ::xla::ffi::TypeInfo info;
+ std::shared_ptr resources;
+};
+
+::xla::ffi::TypeId EpInstanceState::id = {};
+::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo();
+
+namespace {
+
+std::mutex g_ep_mu;
+EpBootstrapParams g_ep_params;
+bool g_ep_params_set = false;
+std::weak_ptr g_ep_resources_weak;
+// Python-held anchor so trace-time ep_register_layer finds EPBackend ready.
+std::shared_ptr g_ep_resources_anchor;
+
+std::shared_ptr AcquireEpResources() {
+ std::lock_guard lock(g_ep_mu);
+ NVTE_CHECK(g_ep_params_set,
+ "EP bootstrap params not set; call transformer_engine_jax."
+ "set_ep_bootstrap_params() (typically via ep_bootstrap) first.");
+ auto sp = g_ep_resources_weak.lock();
+ if (sp) return sp;
+ sp = std::make_shared(g_ep_params);
+ g_ep_resources_weak = sp;
+ return sp;
+}
+
+} // namespace
+
+// handle_id is baked at jit trace time and carried as a static FFI attribute.
+
+struct EpPrepareConfig {
+ int64_t handle_id;
+ int64_t dispatch_output_per_expert_alignment;
+};
+
+struct EpDispatchConfig {
+ int64_t handle_id;
+ int64_t top_k;
+};
+
+struct EpCombineConfig {
+ int64_t handle_id;
+ int64_t num_local_tokens;
+};
+
+struct EpDispatchBwdConfig {
+ int64_t handle_id;
+ int64_t num_local_tokens;
+ int64_t top_k;
+};
+
+struct EpCombineBwdConfig {
+ int64_t handle_id;
+};
+
+// ── Bootstrap helpers ─────────────────────────────────────────────────────────
+
+// Caches uid + group config and eagerly creates the NCCL comm (ranks
+// synchronize via the UID broadcast).
+void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group,
+ int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank,
+ int hidden_dim, int max_num_sms, int allow_handle_mem_reloc,
+ int max_token_dtype) {
+ std::string uid_str = unique_id_bytes_obj;
+ NVTE_CHECK(static_cast(uid_str.size()) >= 128,
+ "unique_id_bytes must be at least 128 bytes (ncclUniqueId size).");
+ std::shared_ptr anchor;
+ {
+ std::lock_guard lock(g_ep_mu);
+ NVTE_CHECK(!g_ep_resources_anchor,
+ "EP bootstrap already initialized; call release_ep_resources() before re-init.");
+ std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128);
+ g_ep_params.ep_size = ep_size;
+ g_ep_params.rank_within_group = rank_within_group;
+ g_ep_params.num_experts = num_experts;
+ g_ep_params.max_tokens_per_rank = max_tokens_per_rank;
+ g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank;
+ g_ep_params.hidden_dim = hidden_dim;
+ g_ep_params.max_num_sms = max_num_sms;
+ g_ep_params.allow_handle_mem_reloc = allow_handle_mem_reloc;
+ g_ep_params.max_token_dtype = max_token_dtype;
+ g_ep_params_set = true;
+ }
+ // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is
+ // a collective and may block on peer ranks.
+ anchor = AcquireEpResources();
+ std::lock_guard lock(g_ep_mu);
+ g_ep_resources_anchor = std::move(anchor);
+}
+
+// Drops the anchor; comm tears down once the last executable also releases.
+void ReleaseEpResources() {
+ std::shared_ptr to_drop;
+ {
+ std::lock_guard lock(g_ep_mu);
+ to_drop = std::move(g_ep_resources_anchor);
+ }
+ // to_drop dtor runs outside the lock.
+}
+
+pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) {
+ NVTEEpLayerConfig layer_cfg{0, top_k, dispatch_output_per_expert_alignment};
+ size_t handle_mem_size = 0;
+ uint64_t handle_id = nvte_ep_register_layer(layer_cfg, &handle_mem_size);
+ return pybind11::make_tuple(handle_id, handle_mem_size);
+}
+
+pybind11::capsule GetEpInstanceStateTypeIdCapsule() {
+ return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id");
+}
+
+pybind11::capsule GetEpInstanceStateTypeInfoCapsule() {
+ return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info");
+}
+
+// ── Instantiate handler ─────────────────────────────────────────────────────
+
+static ::xla::ffi::ErrorOr> EpInstantiateImpl() {
+ auto state = std::make_unique();
+ try {
+ state->resources = AcquireEpResources();
+ } catch (const std::exception& e) {
+ return ::xla::ffi::Unexpected(
+ ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what()));
+ }
+ return state;
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate());
+
+// ── ep_prepare ────────────────────────────────────────────────────────────────
+
+Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx,
+ Result_Type token_counts, Result_Type handle_mem, Result_Type workspace,
+ EpPrepareConfig config) {
+ (void)ep_state; // lifetime only.
+ auto topk_dims = topk_idx.dimensions();
+ NVTE_CHECK(topk_dims.size() >= 2,
+ "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size());
+ auto idx_etype = topk_idx.element_type();
+ NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32,
+ "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype));
+
+ std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1),
+ static_cast(topk_dims.back())};
+ // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream.
+ // TODO(phuong): drop once NCCL EP accepts int32.
+ void* topk_idx_data = topk_idx.untyped_data();
+ if (idx_etype == ::xla::ffi::DataType::S32) {
+ const size_t n = topk_shape[0] * topk_shape[1];
+ NVTE_CHECK(static_cast(workspace->element_count()) >= n,
+ "workspace too small for int32 → int64 upcast: element_count=",
+ workspace->element_count(), " < required ", n);
+ int64_t* ws = reinterpret_cast(workspace->untyped_data());
+ nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream);
+ topk_idx_data = ws;
+ }
+ auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64);
+
+ std::vector tc_shape = {static_cast(token_counts->element_count())};
+ auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32);
+
+ std::vector hm_shape = {static_cast(handle_mem->element_count())};
+ auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte);
+
+ NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()};
+ nvte_ep_prepare(handle, topk_idx_.data(), token_counts_.data(),
+ static_cast(config.dispatch_output_per_expert_alignment), stream);
+ return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI,
+ FFI::Bind()
+ .Ctx() // stream
+ .Ctx<::xla::ffi::State>() // EP state
+ .Arg() // topk_idx
+ .Ret() // token_counts
+ .Ret() // handle_mem
+ .Ret() // workspace (FFI scratch)
+ .Attrs(),
+ FFI_CudaGraph_Traits);
+
+// ── ep_dispatch ───────────────────────────────────────────────────────────────
+
+Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem,
+ Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights,
+ Result_Type recv_tokens, Result_Type recv_topk_weights,
+ Result_Type workspace, EpDispatchConfig config) {
+ (void)ep_state;
+ auto token_dims = tokens.dimensions();
+ NVTE_CHECK(token_dims.size() >= 2,
+ "tokens must be at least 2D [..., H], got ndim=", token_dims.size());
+
+ std::vector hm_shape = {static_cast(handle_mem.element_count())};
+ auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte);
+
+ auto idx_dims = topk_idx.dimensions();
+ NVTE_CHECK(idx_dims.size() >= 2,
+ "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size());
+ auto idx_etype = topk_idx.element_type();
+ NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32,
+ "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype));
+ NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k,
+ ") must match topk_idx last dim (", idx_dims.back(), ")");
+ std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1),
+ static_cast(idx_dims.back())};
+ // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream.
+ // TODO(phuong): drop once NCCL EP accepts int32.
+ void* topk_idx_data = topk_idx.untyped_data();
+ if (idx_etype == ::xla::ffi::DataType::S32) {
+ const size_t n = idx_shape[0] * idx_shape[1];
+ NVTE_CHECK(static_cast(workspace->element_count()) >= n,
+ "workspace too small for int32 → int64 upcast: element_count=",
+ workspace->element_count(), " < required ", n);
+ int64_t* ws = reinterpret_cast(workspace->untyped_data());
+ nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream);
+ topk_idx_data = ws;
+ }
+ auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64);
+
+ const size_t T_flat = product(token_dims, 0, token_dims.size() - 1);
+ const size_t H = static_cast(token_dims.back());
+ std::vector tok_shape = {T_flat, H};
+ auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type());
+ auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype);
+
+ auto tw_dims = topk_weights.dimensions();
+ NVTE_CHECK(tw_dims.size() >= 2,
+ "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size());
+ std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1),
+ static_cast(tw_dims.back())};
+ auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32);
+
+ // recv_tokens: flatten any leading dims into recv_capacity_per_rank.
+ auto recv_dims = recv_tokens->dimensions();
+ NVTE_CHECK(recv_dims.size() >= 2,
+ "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size());
+ const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1);
+ std::vector recv_shape = {recv_capacity_per_rank, H};
+ auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype);
+
+ auto recv_w_dims = recv_topk_weights->dimensions();
+ NVTE_CHECK(recv_w_dims.size() >= 1,
+ "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size());
+ const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size());
+ NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total,
+ ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")");
+ std::vector recv_w_shape = {recv_capacity_per_rank};
+ auto recv_topk_weights_ =
+ TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32);
+
+ NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()};
+ NVTECommWindow no_win{nullptr, 0};
+ nvte_ep_dispatch(handle, topk_idx_.data(), tokens_.data(), no_win, topk_weights_.data(), no_win,
+ recv_tokens_.data(), no_win, recv_topk_weights_.data(), no_win, stream);
+
+ return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI,
+ FFI::Bind()
+ .Ctx() // stream
+ .Ctx<::xla::ffi::State>() // EP state
+ .Arg() // handle_mem
+ .Arg() // topk_idx
+ .Arg() // tokens
+ .Arg() // topk_weights
+ .Ret() // recv_tokens
+ .Ret() // recv_topk_weights
+ .Ret() // workspace (FFI scratch)
+ .Attrs(),
+ FFI_CudaGraph_Traits);
+
+// ── ep_combine ────────────────────────────────────────────────────────────────
+
+Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem,
+ Buffer_Type expert_out, Result_Type result, EpCombineConfig config) {
+ (void)ep_state;
+ auto eo_dims = expert_out.dimensions();
+ NVTE_CHECK(eo_dims.size() >= 2,
+ "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size());
+
+ std::vector hm_shape = {static_cast(handle_mem.element_count())};
+ auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte);
+
+ const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1);
+ const size_t H = static_cast(eo_dims.back());
+ std::vector eo_shape = {recv_capacity_per_rank, H};
+ auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type());
+ auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype);
+
+ auto res_dims = result->dimensions();
+ NVTE_CHECK(res_dims.size() >= 2,
+ "result must be at least 2D [..., H]; got ndim=", res_dims.size());
+ const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1);
+ NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens,
+ "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (",
+ config.num_local_tokens, ")");
+ std::vector