Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1ebffa3
Expert Parallelism: common C API + NCCL EP v0.1 backend
phu0ngng May 22, 2026
0b9bf7e
Expert Parallelism: persistent ncclEpHandle cache with allow_handle_m…
phu0ngng May 23, 2026
ed3d73c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2026
1923180
Build: NCCL_HOME discovery supports Debian/Ubuntu multiarch lib paths
phu0ngng May 27, 2026
3b8aafb
bump NCCL
phu0ngng May 27, 2026
9b225cb
Expert Parallelism: require token_dtype in NVTEEpGroupConfig and enfo…
phu0ngng May 28, 2026
03e56d2
Expert Parallelism: document ep_comm lifetime, v0.1 single-GPU scope,…
phu0ngng May 28, 2026
4cefdcb
Expert Parallelism: drop version label from initialize scope note
phu0ngng May 28, 2026
d101896
Expert Parallelism: JAX bindings (FFI, custom_vjp, multi-process test…
phu0ngng May 22, 2026
b43710e
JAX EP: tie NCCL comm lifetime to JAX executables via XLA stateful FFI
phu0ngng May 23, 2026
cb44374
JAX EP: expose allow_handle_mem_reloc as opt-in ep_bootstrap parameter
phu0ngng May 23, 2026
2012b0a
jax/ep: decorate EP ops with @compute_on("gpu_stream:collective")
phu0ngng May 28, 2026
c04bebb
ep_bootstrap: add XLA-collective fallback for UID allgather
phu0ngng May 28, 2026
1415580
jax/ep: introduce per-layer EpHandle, drop callsite-frame handle_id c…
phu0ngng May 29, 2026
0eee8b8
[JAX] EP: wire NVTEEpGroupConfig.max_token_dtype through bootstrap
tdophung Jun 3, 2026
10f4b1c
[JAX] MoE: enforce (outer_dp, ep) ordering for TE EP compatibility
tdophung Jun 2, 2026
9194fe3
integrate tex.* calls, remove all ragged-a2a + triton/pure jax step b…
tdophung Jun 3, 2026
776c5ef
[JAX] MoE: bootstrap TE EP eagerly outside jit; assert compatibility …
tdophung Jun 3, 2026
acb610f
[JAX] MoE: thread EpHandle + handle_mem through dispatch / combine
tdophung Jun 3, 2026
458d1c4
[JAX] MoE: pass bf16 as max_token_dtype to test fixture's ep_bootstrap
tdophung Jun 3, 2026
3d6825c
patching the sharding stripped by flattening logits input to topk, w…
tdophung Jun 4, 2026
0236467
MoEBlock tutorial
jberchtold-nvidia Jun 4, 2026
e748567
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
3e6f958
Update docs/examples/jax/moe.py
jberchtold-nvidia Jun 4, 2026
566c5fd
Update docs/examples/jax/moe.py
jberchtold-nvidia Jun 4, 2026
96831df
Update docs/examples/jax/test_moe.py
jberchtold-nvidia Jun 4, 2026
313bbf4
Diagrams
jberchtold-nvidia Jun 5, 2026
00145f7
Merge branch 'teddy/te_ep_integration' of github.com:tdophung/Transfo…
jberchtold-nvidia Jun 5, 2026
d8def79
Merge in TE EP and update tutorial accordingly
jberchtold-nvidia Jun 5, 2026
82826a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/nccl"]
path = 3rdparty/nccl
url = https://github.com/NVIDIA/nccl.git
branch = v2.30u1
1 change: 1 addition & 0 deletions 3rdparty/nccl
Submodule nccl added at 146496
42 changes: 40 additions & 2 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,54 @@ def setup_jax_extension(

setup_mpi_flags(include_dirs, cxx_flags)

libraries = []
submod_lib_dir = None
submod_nccl_inc = None

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
cxx_flags.append("-DNVTE_WITH_CUBLASMP")

# NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it.
build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1")))
if build_with_nccl_ep:
cxx_flags.append("-DNVTE_WITH_NCCL_EP")
# Headers + libs come from the in-tree 3rdparty/nccl submodule build
# (auto-produced by setup.py).
libraries = ["nccl", "nccl_ep"]
# NCCL EP requires SM>=90 (Hopper+).
archs_env = os.getenv("NVTE_CUDA_ARCHS", "")
for a in archs_env.split(";"):
a_num = "".join(c for c in a if c.isdigit())
if a_num and int(a_num) < 90:
raise RuntimeError(
f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in"
" NVTE_CUDA_ARCHS."
)
submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve()
submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include"
if not (submod_ep_inc / "nccl_ep.h").exists():
raise RuntimeError(
f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. "
"Run `git submodule update --init --recursive` to checkout 3rdparty/nccl."
)
include_dirs.append(submod_ep_inc)
submod_lib_dir = submod_root / "build" / "lib"
submod_nccl_inc = submod_root / "build" / "include"

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

return Pybind11Extension(
ext = Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
libraries=libraries,
)
if submod_lib_dir is not None:
ext.library_dirs.append(str(submod_lib_dir))
ext.runtime_library_dirs.append(str(submod_lib_dir))
# Prefer submodule's nccl.h when present (matches the C++ side).
if (submod_nccl_inc / "nccl.h").exists():
ext.include_dirs.insert(0, str(submod_nccl_inc))
return ext
43 changes: 43 additions & 0 deletions docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<mxfile host="app.diagrams.net" modified="2026-06-04T00:00:00.000Z" agent="Codex" version="24.7.17" type="device">
<diagram id="jax-moe-native-vs-te-flow" name="JAX MoE native vs TE flow">
<mxGraphModel dx="1500" dy="760" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1500" pageHeight="760" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="title" value="JAX MoE forward data flow, native block vs TE block" style="text;html=1;strokeColor=none;fillColor=none;fontSize=28;fontStyle=1;fontColor=#111827;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="250" y="20" width="1000" height="40" as="geometry"/></mxCell>
<mxCell id="subtitle" value="Simplified view: router -&amp;gt; dispatch -&amp;gt; expert compute -&amp;gt; combine" style="text;html=1;strokeColor=none;fillColor=none;fontSize=15;fontColor=#475569;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="250" y="54" width="1000" height="28" as="geometry"/></mxCell>
<mxCell id="native_lane" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f8fafc;strokeColor=#cbd5e1;" vertex="1" parent="1"><mxGeometry x="50" y="92" width="655" height="590" as="geometry"/></mxCell>
<mxCell id="te_lane" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f8fafc;strokeColor=#cbd5e1;" vertex="1" parent="1"><mxGeometry x="795" y="92" width="655" height="590" as="geometry"/></mxCell>
<mxCell id="native_header" value="Native JAX BF16 EP MoE&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;JAX router, ragged collectives, fused ragged_dot FFN&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;fontSize=23;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="80" y="110" width="595" height="52" as="geometry"/></mxCell>
<mxCell id="te_header" value="TE _MoEBlock BF16&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;TE router, NCCL EP dispatch/combine, grouped GEMM FFN&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;fontSize=23;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="825" y="110" width="595" height="52" as="geometry"/></mxCell>
<mxCell id="te_vjp" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=none;strokeColor=#0f766e;dashed=1;dashPattern=8 6;" vertex="1" parent="1"><mxGeometry x="825" y="214" width="595" height="444" as="geometry"/></mxCell>
<mxCell id="te_vjp_label" value="single TE MoE custom_vjp boundary" style="text;html=1;strokeColor=none;fillColor=none;fontSize=14;fontStyle=1;fontColor=#0f766e;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="970" y="660" width="250" height="24" as="geometry"/></mxCell>

<mxCell id="n0" value="Input shard&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;x [B,S,H], expert weights sharded over ep&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#ffffff;strokeColor=#64748b;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="92" y="170" width="570" height="62" as="geometry"/></mxCell>
<mxCell id="n1" value="Router&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;gate GEMM, softmax, top-k experts and routing weights&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e0f2fe;strokeColor=#0284c7;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="92" y="260" width="570" height="72" as="geometry"/></mxCell>
<mxCell id="n2" value="Dispatch&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;sort routes, gather counts, ragged_all_to_all, local reorder&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fef3c7;strokeColor=#d97706;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="92" y="360" width="570" height="76" as="geometry"/></mxCell>
<mxCell id="n3" value="Expert FFN&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;ragged_dot(wi_0|wi_1), activation, ragged_dot(wo)&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dcfce7;strokeColor=#16a34a;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="92" y="464" width="570" height="76" as="geometry"/></mxCell>
<mxCell id="n4" value="Combine&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;reverse reorder, reverse ragged_all_to_all, unsort and weight&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fce7f3;strokeColor=#db2777;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="92" y="568" width="570" height="76" as="geometry"/></mxCell>

<mxCell id="t0" value="Input shard&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;same x [B,S,H] and same parameter names&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#ffffff;strokeColor=#64748b;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="837" y="170" width="570" height="62" as="geometry"/></mxCell>
<mxCell id="t1" value="Router&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;gate GEMM, tex.fused_topk_with_score_function_fwd&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#ccfbf1;strokeColor=#0f766e;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="837" y="260" width="570" height="72" as="geometry"/></mxCell>
<mxCell id="t2" value="Dispatch&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;tex.ep_dispatch via NCCL EP, TE handle state&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fef3c7;strokeColor=#d97706;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="837" y="360" width="570" height="76" as="geometry"/></mxCell>
<mxCell id="t3" value="Expert FFN&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;grouped_gemm(wi_0|wi_1), activation, grouped_gemm(wo)&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dcfce7;strokeColor=#16a34a;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="837" y="464" width="570" height="76" as="geometry"/></mxCell>
<mxCell id="t4" value="Combine&lt;br/&gt;&lt;font style=&quot;font-size: 14px&quot;&gt;tex.ep_combine via NCCL EP, output reshard&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fce7f3;strokeColor=#db2777;fontSize=18;fontStyle=1;fontColor=#0f172a;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="837" y="568" width="570" height="76" as="geometry"/></mxCell>

<mxCell id="e_n0_n1" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="n0" target="n1"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_n1_n2" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="n1" target="n2"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_n2_n3" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="n2" target="n3"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_n3_n4" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="n3" target="n4"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_t0_t1" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="t0" target="t1"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_t1_t2" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="t1" target="t2"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_t2_t3" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="t2" target="t3"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="e_t3_t4" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#334155;strokeWidth=2;" edge="1" parent="1" source="t3" target="t4"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="diff_router_text" value="fused router" style="text;html=1;strokeColor=none;fillColor=none;fontSize=13;fontColor=#475569;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="690" y="284" width="130" height="24" as="geometry"/></mxCell>
<mxCell id="diff_router_edge" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#0f766e;strokeWidth=2;dashed=1;dashPattern=6 5;" edge="1" parent="1" source="n1" target="t1"><mxGeometry relative="1" as="geometry"/></mxCell>
<mxCell id="diff_compute_text" value="ragged_dot FFN -&amp;gt; grouped GEMM FFN" style="text;html=1;strokeColor=none;fillColor=none;fontSize=13;fontColor=#475569;align=center;verticalAlign=middle;" vertex="1" parent="1"><mxGeometry x="630" y="490" width="240" height="24" as="geometry"/></mxCell>
<mxCell id="diff_compute_edge" value="" style="endArrow=block;html=1;rounded=0;strokeColor=#0f766e;strokeWidth=2;dashed=1;dashPattern=6 5;" edge="1" parent="1" source="n3" target="t3"><mxGeometry relative="1" as="geometry"/></mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>
Loading
Loading