From 1ebffa3ccb2bed9eb79a431d8a08ca902a207a05 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:06 +0000 Subject: [PATCH 01/29] Expert Parallelism: common C API + NCCL EP v0.1 backend Signed-off-by: Phuong Nguyen --- .gitmodules | 4 + 3rdparty/nccl | 1 + qa/L1_cpp_distributed/test.sh | 3 + setup.py | 127 +++ tests/cpp_distributed/CMakeLists.txt | 91 +- tests/cpp_distributed/run_test_ep.sh | 137 +++ tests/cpp_distributed/test_ep_common.h | 308 ++++++ tests/cpp_distributed/test_ep_coverage.cu | 379 ++++++++ tests/cpp_distributed/test_ep_init.cu | 64 ++ tests/cpp_distributed/test_ep_pipeline.cu | 890 ++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 90 ++ transformer_engine/common/ep/ep_api.cpp | 76 ++ transformer_engine/common/ep/ep_api_stub.cpp | 61 ++ transformer_engine/common/ep/ep_backend.cpp | 514 ++++++++++ transformer_engine/common/ep/ep_backend.h | 114 +++ .../include/transformer_engine/comm_window.h | 32 + .../common/include/transformer_engine/ep.h | 161 ++++ 17 files changed, 3050 insertions(+), 2 deletions(-) create mode 160000 3rdparty/nccl create mode 100755 tests/cpp_distributed/run_test_ep.sh create mode 100644 tests/cpp_distributed/test_ep_common.h create mode 100644 tests/cpp_distributed/test_ep_coverage.cu create mode 100644 tests/cpp_distributed/test_ep_init.cu create mode 100644 tests/cpp_distributed/test_ep_pipeline.cu create mode 100644 transformer_engine/common/ep/ep_api.cpp create mode 100644 transformer_engine/common/ep/ep_api_stub.cpp create mode 100644 transformer_engine/common/ep/ep_backend.cpp create mode 100644 transformer_engine/common/ep/ep_backend.h create mode 100644 transformer_engine/common/include/transformer_engine/comm_window.h create mode 100644 transformer_engine/common/include/transformer_engine/ep.h diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..e531c95507 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,7 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git + branch = v2.30u1 diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..6a9bc953ac --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..7e5ce2cf0d 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then cmake -GNinja -S. -Bbuild cmake --build build mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + + # EP suites; runner self-skips on pre-Hopper GPUs. + bash ./run_test_ep.sh 4 ./build fi diff --git a/setup.py b/setup.py index ec277b6349..db360c8a29 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,34 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP: on by default; auto-disabled if no arch >= 90. + # Set NVTE_BUILD_WITH_NCCL_EP=0/1 to force off/on. + nccl_ep_env = os.getenv("NVTE_BUILD_WITH_NCCL_EP") + explicit_nccl_ep = nccl_ep_env is not None + build_with_nccl_ep = bool(int(nccl_ep_env)) if explicit_nccl_ep else True + + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any(t.lower() == "native" for t in arch_tokens) or any( + int(t.rstrip("af")) >= 90 for t in arch_tokens if t.rstrip("af").isdigit() + ) + if not has_hopper_or_newer: + if explicit_nccl_ep: + raise RuntimeError( + "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90 in " + f"NVTE_CUDA_ARCHS (got '{archs}'). Add '90' or unset NVTE_BUILD_WITH_NCCL_EP." + ) + print( + "[NCCL EP] No CUDA arch >= 90 in NVTE_CUDA_ARCHS" + f" ('{archs}'); auto-disabling NCCL EP (nvte_ep_* will throw at runtime)." + ) + build_with_nccl_ep = False + + if build_with_nccl_ep: + build_nccl_ep_submodule() + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -128,6 +156,105 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / "lib" / name).exists() or (p / "lib64" / name).exists() + for name in ("libnccl.so", "libnccl.so.2") + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + root = lib_path.parent.parent + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.so from the 3rdparty/nccl submodule. + + NCCL EP is on by default; the system NCCL core (libnccl.so) supplies the + headers and runtime symbols. Returns the submodule build directory. + """ + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so" + + archs = cuda_archs() or "90" + arch_list = [] + for a in str(archs).split(";"): + a = a.strip().rstrip("af") + if a and a.isdigit() and int(a) >= 90: + arch_list.append(a) + if not arch_list: + arch_list = ["90"] + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. + nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + if not nccl_ep_lib.exists(): + print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + + # TE's CMake expects nccl.h under 3rdparty/nccl/build/include/ for its + # version check. Mirror the top-level host headers from the system NCCL + # install — DON'T mirror nccl_device/ because the submodule ships its own + # newer copy at src/include/nccl_device/ with device-side templates that + # conflict with older system versions, and the JIT include path picks the + # submodule's. + nccl_include = build_dir / "include" + nccl_include.mkdir(parents=True, exist_ok=True) + for cand in (Path(nccl_home) / "include", Path("/usr/include")): + p = Path(cand) + if (p / "nccl.h").exists(): + for name in ("nccl.h", "nccl_net.h", "nccl_tuner.h"): + src = p / name + dst = nccl_include / name + if src.exists() and not dst.exists(): + dst.symlink_to(src) + break + + return str(build_dir) + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 0d7258a81d..3870f57911 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -46,12 +46,99 @@ add_executable(test_comm_gemm find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# ── NCCL library ────────────────────────────────────────────────────────────── +# Search order: NCCL_HOME env → 3rdparty/nccl submodule build → system paths. +set(NCCL_SUBMODULE_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build") find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + HINTS $ENV{NCCL_HOME}/lib ${NCCL_SUBMODULE_BUILD}/lib + PATH_SUFFIXES lib lib64 REQUIRED) + +# NCCL headers: prefer submodule build output (has the handle_init API), +# then submodule src, then system (CUDA toolkit). +set(NCCL_SUBMODULE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +set(NCCL_SUBMODULE_SRC_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/src/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_INCLUDE}") +elseif(EXISTS "${NCCL_SUBMODULE_SRC_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_SRC_INCLUDE}") +elseif(DEFINED ENV{NCCL_HOME}) + set(NCCL_INCLUDE_DIR "$ENV{NCCL_HOME}/include") +endif() target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# ── EP distributed tests (HT mode) ───────────────────────────────────────── +# No MPI dependency — processes are spawned by run_test_ep.sh with +# --rank / --nranks flags. ncclUniqueId exchange uses a +# shared temp file (see test_ep_common.h for details). +# Headers + libs come from the in-tree 3rdparty/nccl submodule build. +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_SUBMODULE_ROOT}/build/lib + NO_DEFAULT_PATH + REQUIRED) + +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "EP test: NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# Collect NCCL include dirs shared by all EP test targets (nccl_ep.h + nccl.h). +set(EP_TEST_NCCL_INCLUDES ${NCCL_EP_INCLUDE_DIR}) +if(DEFINED NCCL_INCLUDE_DIR) + list(APPEND EP_TEST_NCCL_INCLUDES ${NCCL_INCLUDE_DIR}) + message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +endif() + +set(EP_TEST_COMMON_INCLUDES + ${EP_TEST_NCCL_INCLUDES} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + CUDA::nvrtc + GTest::gtest + ${TE_LIB} + ${NCCL_LIB} + ${NCCL_EP_LIB}) + +# nvrtc symbols are referenced from libtransformer_engine.so but not in its +# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc +# explicitly with --no-as-needed so the linker keeps the dependency. +set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") + +# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── +add_executable(test_ep_init test_ep_init.cu) +target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── +add_executable(test_ep_pipeline test_ep_pipeline.cu) +target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── +add_executable(test_ep_coverage test_ep_coverage.cu) +target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + +# Do NOT use gtest_discover_tests — these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled: ${NCCL_EP_LIB}") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..017d3f807b --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests across multiple GPUs. +# +# Spawns one background bash process per GPU (no MPI dependency), matching the +# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared +# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and +# passes it to nvte_ep_initialize. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") +# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +OVERALL_FAIL=0 + +# --------------------------------------------------------------------------- +# run_suite BINARY SUITE_NAME MIN_GPUS +# --------------------------------------------------------------------------- +run_suite() { + local BINARY="$1" + local SUITE_NAME="$2" + local MIN_GPUS="${3:-2}" + + local TEST_BIN="${BUILD_DIR}/${BINARY}" + + if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + OVERALL_FAIL=1 + return + fi + + if (( NUM_GPUS < MIN_GPUS )); then + echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." + return + fi + + local TMPDIR_L="${TMPDIR:-/tmp}" + local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" + rm -f "${UID_FILE}" + + local LOG_DIR + LOG_DIR=$(mktemp -d) + local FAIL=0 + + echo "=== ${SUITE_NAME} ===" + echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" + echo + + # Spawn one background process per GPU. ncclUniqueId is exchanged via the + # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. + local PIDS=() + for i in $(seq 0 $((NUM_GPUS - 1))); do + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + "${TEST_BIN}" \ + --rank="${i}" \ + --nranks="${NUM_GPUS}" \ + --uid-file="${UID_FILE}" \ + ${GTEST_ARGS} \ + > "${LOG_DIR}/rank_${i}.log" 2>&1 & + PIDS+=($!) + done + for i in $(seq 0 $((NUM_GPUS - 1))); do + if ! wait "${PIDS[$i]}"; then + local rc=$? + FAIL=1 + if [[ $rc -eq 137 || $rc -eq 124 ]]; then + echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" + fi + fi + done + + echo "--- Rank 0 output ---" + cat "${LOG_DIR}/rank_0.log" + + if (( FAIL )); then + for i in $(seq 1 $((NUM_GPUS - 1))); do + echo "--- Rank ${i} output ---" + cat "${LOG_DIR}/rank_${i}.log" + done + echo "=== ${SUITE_NAME}: FAILED ===" + OVERALL_FAIL=1 + else + echo "=== ${SUITE_NAME}: ALL PASSED ===" + fi + + rm -rf "${LOG_DIR}" + rm -f "${UID_FILE}" +} + +# --------------------------------------------------------------------------- +# Cleanup on abort +# --------------------------------------------------------------------------- +cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } +trap cleanup EXIT INT TERM + +# --------------------------------------------------------------------------- +# Run all suites +# --------------------------------------------------------------------------- +run_suite "test_ep_init" "EP Init Tests" 2 +run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 +run_suite "test_ep_coverage" "EP Coverage Tests" 2 + +echo +if (( OVERALL_FAIL )); then + echo "=== SOME SUITES FAILED ===" +else + echo "=== ALL SUITES PASSED ===" +fi + +exit "${OVERALL_FAIL}" diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..77baa92b0c --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// ── Error-checking macros ───────────────────────────────────────────────────── + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) \ + FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ + } while (false) + +#define CHECK_CUDA(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) \ + FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ + } while (false) + +#define ASSERT_CUDA_OK(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +#define ASSERT_NCCL_OK(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) { \ + fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// ── Process-level state ─────────────────────────────────────────────────────── + +static int g_process_id = -1; +static int g_num_processes = -1; +static std::string g_uid_file; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── + +// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. +struct TensorHandle { + NVTETensor tensor = nullptr; + void* dev_ptr = nullptr; + + ~TensorHandle() { + if (tensor) nvte_destroy_tensor(tensor); + } + + TensorHandle() = default; + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { + o.tensor = nullptr; o.dev_ptr = nullptr; + } + TensorHandle& operator=(TensorHandle&& o) noexcept { + if (this != &o) { + if (tensor) nvte_destroy_tensor(tensor); + tensor = o.tensor; dev_ptr = o.dev_ptr; + o.tensor = nullptr; o.dev_ptr = nullptr; + } + return *this; + } +}; + +static TensorHandle make_nvte_tensor(void* dev_ptr, + const std::vector& shape, + NVTEDType dtype) { + TensorHandle h; + h.dev_ptr = dev_ptr; + h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); + + NVTEShape s; + s.ndim = shape.size(); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; + + NVTEBasicTensor bt; + bt.data_ptr = dev_ptr; + bt.dtype = dtype; + bt.shape = s; + nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); + + return h; +} + +// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +template +struct DevBuf { + T* ptr = nullptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + ~DevBuf() { reset(); } + + DevBuf(const DevBuf&) = delete; + DevBuf& operator=(const DevBuf&) = delete; + DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } + DevBuf& operator=(DevBuf&& o) noexcept { + if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } + return *this; + } + + void alloc(size_t n) { + reset(); + count = n; + if (n > 0) { + cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); + if (e != cudaSuccess) { + fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), + cudaGetErrorString(e)); + ptr = nullptr; + count = 0; + } + } + } + + void reset() { + if (ptr) { cudaFree(ptr); ptr = nullptr; } + count = 0; + } + + T* get() const { return ptr; } + size_t bytes() const { return count * sizeof(T); } +}; + +// ── Shared routing helper ───────────────────────────────────────────────────── + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_local_experts + t * top_k + k) % num_experts +static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts; + return idx; +} + +// ── File-based ncclUniqueId exchange ───────────────────────────────────────── + +static void exchange_unique_id(ncclUniqueId* uid) { + const size_t sz = sizeof(ncclUniqueId); + + if (g_process_id == 0) { + ASSERT_NCCL_OK(ncclGetUniqueId(uid)); + FILE* f = fopen(g_uid_file.c_str(), "wb"); + if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } + fwrite(uid, 1, sz, f); + fclose(f); + } else { + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); + while (true) { + FILE* f = fopen(g_uid_file.c_str(), "rb"); + if (f) { + fseek(f, 0, SEEK_END); + if (static_cast(ftell(f)) >= sz) { + fseek(f, 0, SEEK_SET); + size_t n = fread(uid, 1, sz, f); + fclose(f); + if (n == sz) break; + } else { + fclose(f); + } + } + if (std::chrono::steady_clock::now() > deadline) { + fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); + exit(EXIT_FAILURE); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ── CLI parsing ─────────────────────────────────────────────────────────────── + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); + else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); + else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); + else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); + else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + } + + if (g_process_id < 0 || g_num_processes <= 0) { + fprintf(stderr, + "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" + " Aliases: --process-id=N, --num-processes=N\n", + argc > 0 ? argv[0] : "test_ep"); + exit(EXIT_FAILURE); + } + + if (g_uid_file.empty()) { + const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; + g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + } +} + +// ── Bootstrap / teardown ────────────────────────────────────────────────────── + +// Returns false if the binary should exit without running tests (wrong SM, etc.). +static bool ep_bootstrap(int argc, char* argv[]) { + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + + ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + if (g_process_id == 0) remove(g_uid_file.c_str()); +} diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu new file mode 100644 index 0000000000..ef7941905d --- /dev/null +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -0,0 +1,379 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP C-API coverage tests (paths not exercised by the pipeline suite). + * + * MultiHandleAllocTest — distinct handle ids; each works end-to-end. + * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. + * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. + * NegativeTests — alignment mismatch and null handle_mem must throw. + */ + +#include "test_ep_common.h" + +#include +#include + +// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two +// full experts. Requires top_k >= 2 and num_experts >= 3. +static std::vector routing_skip_middle(int num_tokens, int top_k) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) { + idx[t * top_k + 0] = 0; + if (top_k >= 2) idx[t * top_k + 1] = 2; + for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers + } + return idx; +} + +static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { + std::vector v(num_tokens * hidden_dim); + nv_bfloat16 b = __float2bfloat16(val); + std::fill(v.begin(), v.end(), b); + return v; +} + +namespace { + +class EpCoverageBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + } + + // Helper: allocate buffers + tensor views for a single dispatch+combine. + struct Bundle { + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + }; + + Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, + size_t alignment) { + Bundle b; + b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; + b.topk_idx.alloc(num_tokens * top_k); + b.topk_weights.alloc(num_tokens * top_k); + b.tokens.alloc(num_tokens * hidden_dim_); + b.token_counts.alloc(num_local_experts); + b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); + b.recv_topk_weights.alloc(b.recv_capacity); + b.result.alloc(num_tokens * hidden_dim_); + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); + b.handle_mem.alloc(b.handle_mem_size); + return b; + } +}; + +} // namespace + +// ============================================================================= +// MultiHandleAllocTest: ids are distinct and each is independently usable. +// ============================================================================= + +class MultiHandleAllocTest : public EpCoverageBase {}; + +TEST_F(MultiHandleAllocTest, IdsAreDistinct) { + NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; + const int kN = 8; + std::vector ids(kN); + for (int i = 0; i < kN; ++i) { + size_t sz = 0; + ids[i] = nvte_ep_register_layer(cfg, &sz); + } + for (int i = 0; i < kN; ++i) { + EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; + for (int j = i + 1; j < kN; ++j) + EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; + } +} + +TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { + const int num_tokens = 16, top_k = 2; + Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + for (Bundle* x : {&a, &b}) { + CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NE(a.handle_id, b.handle_id); + + auto run_one = [&](Bundle& x) { + auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); + auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); + auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + NVTEEpHandle h{x.handle_id, handle_mem.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, + NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, + recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, + NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, + result.tensor, stream)); + }; + run_one(a); + run_one(b); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Both round-trips must produce result == top_k * 0.5 = 1.0. + for (Bundle* x : {&a, &b}) { + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + } + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. +// ============================================================================= + +class TopK1Test : public EpCoverageBase {}; + +TEST_F(TopK1Test, RoundTrip) { + const int num_tokens = 16, top_k = 1; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) + << "tok " << t << " hidden " << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives +// tokens. Round-trip must produce result == top_k * tokens regardless of the +// per-expert padding choice. +// ============================================================================= + +class EmptyExpertsTest : public EpCoverageBase, + public ::testing::WithParamInterface {}; + +TEST_P(EmptyExpertsTest, RoundTripCorrect) { + // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. + ASSERT_GE(num_experts_, 3); + const size_t alignment = GetParam(); + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); + + // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 + // tokens between two non-empty experts. + std::vector h_idx = routing_skip_middle(num_tokens, top_k); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); + + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + alignment, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Identity expert + uniform weights: result[t] == top_k * tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float expected = static_cast(top_k) * 0.3f; + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) + << "alignment=" << alignment << " tok=" << t << " hidden=" << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, + ::testing::Values(0, 2, 8, 16)); + +// ============================================================================= +// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. +// ============================================================================= + +class NegativeTests : public EpCoverageBase {}; + +TEST_F(NegativeTests, AlignmentMismatchThrows) { + const int num_tokens = 8, top_k = 2; + // Allocate handle for alignment=0, then call prepare with alignment=16. + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/16, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(NegativeTests, NullHandleMemThrows) { + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + // Construct a tensor view backed by a null device pointer. + auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu new file mode 100644 index 0000000000..08744dfee5 --- /dev/null +++ b/tests/cpp_distributed/test_ep_init.cu @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Unit tests for EP initialization paths. + * + * Tests: + * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 + * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values + * + * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). + */ + +#include "test_ep_common.h" + +// ── Fixture ─────────────────────────────────────────────────────────────────── + +class EPInitTest : public ::testing::Test { + protected: + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; + ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; + } +}; + +// ── Tests ───────────────────────────────────────────────────────────────────── + +TEST_F(EPInitTest, InitPath) { + int nle = g_num_experts / g_ep_size; + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; + + if (g_process_id == 0) { + printf(" handle_mem : %zu bytes\n", sz); + } +} + +TEST_F(EPInitTest, NumLocalExperts) { + // handle_mem_size should be > 0 for any valid num_local_experts value. + for (int nle : {1, g_num_experts / g_ep_size}) { + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; + if (g_process_id == 0) + printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); + } +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep_pipeline.cu new file mode 100644 index 0000000000..41f83a6d11 --- /dev/null +++ b/tests/cpp_distributed/test_ep_pipeline.cu @@ -0,0 +1,890 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. + * + * EPDispatchTest/PrepareAndDispatch — exact recv values + per-expert counts + * EPCombineTest/Combine — round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck — exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck — exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip — exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward — fwd + bwd NaN/Inf check + * + * Routing: token t on rank r → expert (r * num_local_experts + t * top_k + k) % num_experts + * Token values: rank r, token t → all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// ── Deterministic routing helpers ───────────────────────────────────────────── + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(__bfloat162float(__float2bfloat16(raw))); + } + } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for +// accumulation noise inside dispatch/combine. +static float bf16_tol(float magnitude) { + return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); +} + +static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = __bfloat162float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// ── Forward buffer set with RAII ────────────────────────────────────────────── + +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + handle_id = nvte_ep_register_layer(cfg, &handle_mem_size); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers — one place to update the shape +// conventions when the C-API evolves. +struct EPTensors { + TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorHandle recv_tokens, recv_topk_weights, result; + TensorHandle grad_result, grad_expert, grad_tokens; + TensorHandle g_recv_topk_weights, grad_topk_weights; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + topk_idx = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + topk_weights = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + token_counts = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + handle_mem = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + tokens = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + recv_tokens = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + result = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_result = make_nvte_tensor(b.grad_result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_expert = make_nvte_tensor(b.grad_expert.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + grad_tokens = make_nvte_tensor(b.grad_tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + } +}; + +// ── Shared fixture base ─────────────────────────────────────────────────────── + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + NVTEEpLayerConfig layer_config(size_t alignment = 0) const { + return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; + } + + // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +class EPDispatchTest : public EpOpTestBase {}; + +TEST_F(EPDispatchTest, PrepareAndDispatch) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity — overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer — avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i])) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). + std::vector h_w(buf.recv_capacity); + CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert → result == top_k * tokens. +// ============================================================================= + +class EPCombineTest : public EpOpTestBase {}; + +TEST_F(EPCombineTest, Combine) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + // Spot-check 3 hidden-dim positions per token to catch partial-row writes. + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p : probes) { + float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +class EPCombineBwdTest : public EpOpTestBase {}; + +TEST_F(EPCombineBwdTest, CombineBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = read_total_recv(buf); + + std::vector cnt(num_local_experts_); + CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = 0.1f; + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + float v = __bfloat162float(h_ge[slot * hidden_dim_]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. +// ============================================================================= + +class EPDispatchBwdTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + + 0.0001f * (g_process_id + 1); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). + auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), + {buf.recv_capacity}, kNVTEFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. +// ============================================================================= + +class EPPipelineTest : public EpOpTestBase {}; + +TEST_F(EPPipelineTest, FullForwardBackward) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. +struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); + CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +class EPZeroCopyTest : public EpOpTestBase {}; + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TEST_F(EPZeroCopyTest, IdentityAllSymm) { + // HBM reference run. + EPBuffers ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(ref_buf); + EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t ref_hid = ref_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, + ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, + NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + // Stage same tokens into the symm-mem input. + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + uint64_t sym_hid = sym_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, + sym_t.tokens.tensor, symm_window(sym_tokens), + sym_t.topk_weights.tensor, NVTECommWindow{}, + sym_t.recv_tokens.tensor, symm_window(sym_recv), + sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, + symm_window(sym_recv), sym_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. + int total_recv = read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Same buffers, 2 iterations — catches window-lifecycle regressions where the +// symm-mem registration goes stale between calls. +TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + float got = __bfloat162float(h_res[tok * hidden_dim_]); + ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; + } + } + + if (g_process_id == 0) + printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Full forward+backward with symm-mem on every spec-mandated buffer: +// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. +// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior +// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once +// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is +// understood. Tracked separately. +TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND + // combine input), grad_result (combine_bwd input), grad_expert + // (combine_bwd output AND dispatch_bwd input). + SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; + sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_result = make_nvte_tensor(sym_grad_result.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, + symm_window(sym_grad_result), t.grad_expert.tensor, + symm_window(sym_grad_expert), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + symm_window(sym_grad_expert), + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..7c93f0e1da 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -391,6 +391,96 @@ if (NVTE_WITH_CUSOLVERMP) message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") endif() +# ── NCCL EP (on by default, HT mode only) ───────────────────────────────── +# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to +# skip NCCL EP entirely — useful on older images whose system NCCL is below +# the 2.30.4 EP minimum. +option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# ── NCCL EP headers ──────────────────────────────────────────────────────── +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# ── libnccl_ep.so ────────────────────────────────────────────────────────── +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# ── NCCL + GIN headers ───────────────────────────────────────────────────── +# libnccl.so and all GIN headers (ncclGin.h, ncclWindow_t, ncclDevComm_t) +# ship with the base CUDA Toolkit OR the 3rdparty/nccl submodule build +# (preferred when present; auto-built by setup.py via build_nccl_ep_submodule). +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + HINTS ${NCCL_EP_LIB_DIR} ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +set(NCCL_SUBMODULE_INCLUDE + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIRS_FOR_TE ${NCCL_SUBMODULE_INCLUDE}) +else() + set(NCCL_INCLUDE_DIRS_FOR_TE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +endif() + +# Diagnostic: log detected NCCL header version (minimum enforced at runtime). +find_file(_nvte_nccl_header_path nccl.h + PATHS ${NCCL_INCLUDE_DIRS_FOR_TE} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + NO_DEFAULT_PATH) +if(_nvte_nccl_header_path) + file(READ "${_nvte_nccl_header_path}" _nvte_nccl_h) + string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_major "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_minor "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_patch "${CMAKE_MATCH_1}") + if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch) + message(STATUS "NCCL header: ${_nvte_nccl_header_path} (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})") + endif() +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIRS_FOR_TE}) # covers nccl.h + nccl_device/ + +target_link_libraries(transformer_engine PUBLIC + ${NCCL_EP_LIB} + ${NCCL_LIB}) + +# Embed rpath so the installed wheel finds libnccl_ep.so at runtime. +# libnccl.so is already on the system via the Toolkit — no rpath needed for it. +set_target_properties(transformer_engine PROPERTIES + INSTALL_RPATH "$ORIGIN;${NCCL_EP_LIB_DIR}") + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp) + +message(STATUS "NCCL EP enabled: ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: export throwing nvte_ep_* stubs so framework bindings link. + target_sources(transformer_engine PRIVATE ep/ep_api_stub.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) — using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..89d8b38607 --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + */ + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + return EPBackend::get().register_layer(layer_config, handle_mem_size); +} + +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().prepare(handle.id, topk_idx, token_counts, mem_ptr, + dispatch_output_per_expert_alignment, stream); +} + +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch(handle.id, mem_ptr, topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine(handle.id, mem_ptr, expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch_bwd(handle.id, mem_ptr, grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine_bwd(handle.id, mem_ptr, grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} diff --git a/transformer_engine/common/ep/ep_api_stub.cpp b/transformer_engine/common/ep/ep_api_stub.cpp new file mode 100644 index 0000000000..fe4127d87d --- /dev/null +++ b/transformer_engine/common/ep/ep_api_stub.cpp @@ -0,0 +1,61 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api_stub.cpp + * \brief Throwing nvte_ep_* stubs compiled when NVTE_WITH_NCCL_EP=OFF. + */ + +#include + +#include "../util/logging.h" + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig /*layer_config*/, size_t* /*handle_mem_size*/) { + ep_not_built(); +} + +void nvte_ep_prepare(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*token_counts*/, + size_t /*dispatch_output_per_expert_alignment*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTEEpHandle /*handle*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*grad_expert_out*/, NVTECommWindow /*grad_expert_out_win*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..ae0f3ab888 --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,514 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +// Build a by-value ncclEpTensor_t descriptor. `sizes` is caller-owned and must +// outlive any NCCL EP call that consumes the descriptor. +inline ncclEpTensor_t make_tensor(void* data, unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t t = NCCL_EP_TENSOR_INIT; + t.ndim = ndim; + t.datatype = datatype; + t.data = data; + t.sizes = sizes; + return t; +} + +// Payload descriptor: prefer the symmem window when set, else fall back to the +// NVTETensor's raw device pointer. +inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWindow& win, + unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT; + desc.ndim = ndim; + desc.datatype = datatype; + desc.sizes = sizes; + if (win.window != nullptr) { + desc.win_hdl = win.window; + desc.win_offset = win.offset; + } else { + desc.data = nvte_tensor_data(t); + NVTE_CHECK(desc.data != nullptr, "payload tensor data must not be null"); + } + return desc; +} + +// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. +class ScopedEpHandle { + public: + ScopedEpHandle() = default; + explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} + ~ScopedEpHandle() { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + } + ScopedEpHandle(const ScopedEpHandle&) = delete; + ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; + ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } + ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { + if (this != &other) { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + h_ = other.h_; + other.h_ = nullptr; + } + return *this; + } + operator ncclEpHandle_t() const { return h_; } + ncclEpHandle_t get() const { return h_; } + + private: + ncclEpHandle_t h_ = nullptr; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// Singleton + bootstrap +// --------------------------------------------------------------------------- + +EPBackend& EPBackend::instance() { + static EPBackend inst; + return inst; +} + +EPBackend& EPBackend::get() { + EPBackend& inst = instance(); + NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first."); + return inst; +} + +void EPBackend::validate_config(const NVTEEpGroupConfig& config) { + NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size); + NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts); + NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ", + config.max_tokens_per_rank); + NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", + config.max_recv_tokens_per_rank); + NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, + "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", + config.hidden_dim); + NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, + ") must be divisible by ep_size (", config.ep_size, ")"); + NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", + config.max_num_sms); + + int device, major; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + NVTE_CHECK(major >= 9, + "NCCL EP requires SM_90+ (Hopper or later), " + "but current device has compute capability ", + major, ".x"); + + // NCCL EP needs CUDA multicast (NVLS); init hangs without it. + NVTE_CHECK(cuda::supports_multicast(device), + "NCCL EP requires CUDA multicast (NVLS) support on device ", device, + " but CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED reports 0."); +} + +void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process."); + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + + // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin). + constexpr int kMinNcclVersion = 23004; + int nccl_version = 0; + NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version)); + NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ", + nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100, + " at runtime."); + + validate_config(config); + + int comm_size = 0; + NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size)); + NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (", + config.ep_size, "). Pass the EP sub-communicator, not the world comm."); + + inst.init(ep_comm, config); +} + +void EPBackend::shutdown() { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + if (!inst.initialized_) return; + inst.handles_.clear(); + // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. + if (inst.ep_group_ != nullptr) { + ncclEpGroupDestroy(inst.ep_group_); + inst.ep_group_ = nullptr; + } + inst.ep_comm_ = nullptr; // borrowed — caller destroys + inst.initialized_ = false; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment) { + size_t hm_sizes[1] = {handle_mem_size}; + ncclEpTensor_t routing_desc = make_tensor(handle_mem, 1, ncclUint8, hm_sizes); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment; + ncclEpHandle_t handle; + NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk, + &routing_desc)); + return handle; +} + +// --------------------------------------------------------------------------- +// Lifecycle +// --------------------------------------------------------------------------- + +// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may +// already be gone) and release in-memory state only. +EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + handles_.clear(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + + NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Per-handle_id config cache +// --------------------------------------------------------------------------- + +uint64_t EPBackend::insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment) { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + handle_cache_cap_ = (cap_env != nullptr) ? std::max(1, std::atoi(cap_env)) : 8192; + } + NVTE_CHECK(handles_.size() < handle_cache_cap_, "EP handle cache full (", handle_cache_cap_, + " entries). Raise via NVTE_EP_HANDLE_CACHE_SIZE."); + uint64_t id = next_handle_id_.fetch_add(1, std::memory_order_relaxed); + handles_.emplace(id, HandleEntry{handle_mem_size, alignment, top_k}); + return id; +} + +EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { + auto it = handles_.find(handle_id); + NVTE_CHECK(it != handles_.end(), "ep op on handle_id=", handle_id, + " with no cached config — call ep_prepare first."); + return it->second; +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +uint64_t EPBackend::register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(layer_config.top_k > 0, "NVTEEpLayerConfig.top_k must be > 0"); + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_config.top_k)); + *handle_mem_size = hm_size; + std::lock_guard lock(mutex_); + return insert_new_entry(hm_size, layer_config.top_k, + layer_config.dispatch_output_per_expert_alignment); +} + +void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + void* idx_data = nvte_tensor_data(topk_idx); + NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null"); + + const size_t num_tokens = idx_shape.data[0]; + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + const size_t num_local_experts = + static_cast(group_config_.num_experts / group_config_.ep_size); + + size_t idx_sizes[2] = {num_tokens, top_k}; + ncclEpTensor_t nccl_topk_idx = make_tensor(idx_data, 2, ncclInt64, idx_sizes); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + size_t cnt_sizes[1] = {num_local_experts}; + ncclEpTensor_t token_counts_desc; + void* token_counts_data = (token_counts != nullptr) ? nvte_tensor_data(token_counts) : nullptr; + if (token_counts_data != nullptr) { + token_counts_desc = make_tensor(token_counts_data, 1, ncclInt32, cnt_sizes); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape tok_shape = nvte_tensor_shape(tokens); + NVTEDType tok_dtype = nvte_tensor_type(tokens); + + const size_t num_tokens = tok_shape.data[0]; + const size_t hidden_dim = tok_shape.data[1]; + + size_t tok_sizes[2] = {num_tokens, hidden_dim}; + ncclEpTensor_t nccl_tokens_in = + make_payload_tensor(tokens, tokens_win, 2, nvte_dtype_to_nccl(tok_dtype), tok_sizes); + + const bool is_forward = (topk_weights != nullptr); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + size_t weights_in_sizes[2] = {0, 0}; + ncclEpTensor_t nccl_topk_weights_in; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + weights_in_sizes[0] = num_tokens; + weights_in_sizes[1] = top_k; + nccl_topk_weights_in = + make_payload_tensor(topk_weights, topk_weights_win, 2, ncclFloat32, weights_in_sizes); + } + + NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + + size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; + ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, + nvte_dtype_to_nccl(recv_dtype), recv_sizes); + + size_t weights_out_sizes[1] = {recv_shape.data[0]}; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTEShape recv_w_shape = nvte_tensor_shape(recv_topk_weights); + NVTE_CHECK(recv_w_shape.ndim == 1, "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_out = make_payload_tensor(recv_topk_weights, recv_topk_weights_win, 1, + ncclFloat32, weights_out_sizes); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape exp_shape = nvte_tensor_shape(expert_out); + NVTEDType exp_dtype = nvte_tensor_type(expert_out); + + size_t exp_sizes[2] = {exp_shape.data[0], exp_shape.data[1]}; + ncclEpTensor_t nccl_expert_in = + make_payload_tensor(expert_out, expert_out_win, 2, nvte_dtype_to_nccl(exp_dtype), exp_sizes); + + NVTEShape res_shape = nvte_tensor_shape(result); + void* res_data = nvte_tensor_data(result); + NVTEDType res_dtype = nvte_tensor_type(result); + NVTE_CHECK(res_data != nullptr, "result data must not be null"); + + size_t res_sizes[2] = {res_shape.data[0], res_shape.data[1]}; + ncclEpTensor_t nccl_result_out = + make_tensor(res_data, 2, nvte_dtype_to_nccl(res_dtype), res_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape g_shape = nvte_tensor_shape(grad); + NVTEDType g_dtype = nvte_tensor_type(grad); + size_t g_sizes[2] = {g_shape.data[0], g_shape.data[1]}; + ncclEpTensor_t nccl_tok_in = + make_payload_tensor(grad, grad_win, 2, nvte_dtype_to_nccl(g_dtype), g_sizes); + + // g_recv_topk_weights must be 1D [recv_capacity] — caller flattens. + NVTEShape gw_shape = nvte_tensor_shape(g_recv_topk_weights); + NVTE_CHECK(gw_shape.ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + size_t gw_sizes[1] = {gw_shape.data[0]}; + ncclEpTensor_t nccl_w_in = + make_payload_tensor(g_recv_topk_weights, g_recv_topk_weights_win, 1, ncclFloat32, gw_sizes); + + NVTEShape gt_shape = nvte_tensor_shape(grad_tokens); + void* gt_data = nvte_tensor_data(grad_tokens); + NVTE_CHECK(gt_data != nullptr, "grad_tokens data must not be null"); + size_t gt_sizes[2] = {gt_shape.data[0], gt_shape.data[1]}; + ncclEpTensor_t nccl_tok_out = make_tensor(gt_data, 2, nvte_dtype_to_nccl(g_dtype), gt_sizes); + + NVTEShape gtw_shape = nvte_tensor_shape(grad_topk_weights); + void* gtw_data = nvte_tensor_data(grad_topk_weights); + NVTE_CHECK(gtw_data != nullptr, "grad_topk_weights data must not be null"); + NVTE_CHECK(gtw_shape.ndim == 2, "grad_topk_weights must be 2D [T, top_k]"); + size_t gtw_sizes[2] = {gtw_shape.data[0], gtw_shape.data[1]}; + ncclEpTensor_t nccl_w_out = make_tensor(gtw_data, 2, ncclFloat32, gtw_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + transient = ScopedEpHandle( + open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. + dispatch(handle_id, handle_mem, /*topk_idx=*/nullptr, grad, grad_win, /*topk_weights=*/nullptr, + /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..18307ebb4f --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. + * + * Per handle_id the cache stores config only (no device pointers), so + * handle_mem may be relocated between ops. Cap: NVTE_EP_HANDLE_CACHE_SIZE + * (default 8192); overflow throws. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: reserve a fresh handle_id, cache the layer config, and report + // the handle_mem buffer size the caller must allocate. + uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + + void prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + + void dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream); + + void combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed — caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + static ncclDataType_t nvte_dtype_to_nccl(NVTEDType dtype); + // Open a transient ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + bool initialized_{false}; + std::mutex mutex_; + struct HandleEntry { + size_t handle_mem_size; + size_t alignment; + int top_k; + }; + std::unordered_map handles_; + std::atomic next_handle_id_{1}; // 0 reserved as "no id" + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + + // Caller must hold mutex_. Throws on cap overflow. + uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); + HandleEntry& lookup_config(uint64_t handle_id); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..088ea7f0c3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops. + * Pass ``{NULL, 0}`` to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */ +typedef struct { + ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within ``window``. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..8c3a06b5f0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are allocation-free + * and CUDA graph-capturable. + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* ── Config structs ─────────────────────────────────────────────────────── */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + int ep_size; /*!< EP world size. */ + int num_experts; /*!< Total experts across all ranks. */ + int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */ + /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */ + int max_recv_tokens_per_rank; + int hidden_dim; /*!< Token hidden dimension. */ + int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ + /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ + int allow_handle_mem_reloc; +} NVTEEpGroupConfig; + +/*! \brief Per-layer EP configuration. */ +typedef struct { + int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */ + int top_k; /*!< Per-token expert fan-out. Required. */ + size_t dispatch_output_per_expert_alignment; + /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match + * between nvte_ep_register_layer and nvte_ep_prepare. */ +} NVTEEpLayerConfig; + +/* ── Bootstrap ──────────────────────────────────────────────────────────── */ + +/*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * Re-init after shutdown is allowed; double-init throws. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ +void nvte_ep_shutdown(void); + +/* ── Layer registration (host-only, eager) ───────────────────────────────── */ + +/*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer + * size the caller must allocate. Host-only. + * + * \param[in] layer_config Per-layer EP configuration. + * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. + * \return uint64_t handle_id (non-zero). + */ +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + +/*! \brief Per-step handle: the registered handle_id paired with its handle_mem buffer. */ +typedef struct { + uint64_t id; /*!< Handle id from nvte_ep_register_layer. */ + NVTETensor mem; /*!< Caller-allocated handle_mem buffer (size from nvte_ep_register_layer). */ +} NVTEEpHandle; + +/* ── Per-step ops (all allocation-free, CUDA graph-capturable) ──────────── */ + +/*! \brief AllGather the routing map; write per-expert counts and cache routing + * metadata in handle.mem for the subsequent dispatch/combine. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] dispatch_output_per_expert_alignment Must match the handle_mem sizing. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for ``tokens``. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for ``topk_weights``. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. Unweighted — + * caller must pre-multiply expert_out by recv_topk_weights (and the + * valid-slot mask) before calling. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for ``expert_out``. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch — routes token and weight grads back to source. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine. Padded slots in grad_expert_out are zeroed. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_ From 0b9bf7ec367d3642428469f7e94538d1784ec204 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 19:36:55 +0000 Subject: [PATCH 02/29] Expert Parallelism: persistent ncclEpHandle cache with allow_handle_mem_reloc gating Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep_coverage.cu | 183 ++++++++++++++++++++ transformer_engine/common/ep/ep_backend.cpp | 109 +++++------- transformer_engine/common/ep/ep_backend.h | 8 + 3 files changed, 238 insertions(+), 62 deletions(-) diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu index ef7941905d..e9e532386c 100644 --- a/tests/cpp_distributed/test_ep_coverage.cu +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -369,6 +369,189 @@ TEST_F(NegativeTests, NullHandleMemThrows) { CHECK_CUDA(cudaStreamDestroy(stream)); } +// ============================================================================= +// HandleCacheTest: persistent ncclEpHandle is reused across ops on the same +// handle_mem ptr; relocation triggers throw by default and rebuild when +// NVTEEpGroupConfig.allow_handle_mem_reloc=1. +// ============================================================================= + +class HandleCacheTest : public EpCoverageBase {}; + +// Run prepare → dispatch → combine on bundle b. handle_mem_data overrides the +// device ptr used for handle_mem (must be the buffer owned by b unless +// reloc-allowed mode is active). Templated on Bundle because EpCoverageBase:: +// Bundle is declared in a protected section. +template +static void run_round_trip(B& b, void* handle_mem_data, + int num_tokens, int top_k, int num_local_experts, + int hidden_dim, size_t alignment, + cudaStream_t stream) { + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(handle_mem_data, + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, alignment, stream); + nvte_ep_dispatch(h, topk_idx_t.tensor, tokens_t.tensor, NVTECommWindow{}, + topk_weights_t.tensor, NVTECommWindow{}, + recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream); + nvte_ep_combine(h, recv_tokens_t.tensor, NVTECommWindow{}, result_t.tensor, stream); +} + +// Re-bootstrap EP backend with a different allow_handle_mem_reloc setting. +// Reuses the existing g_ep_comm; caller is responsible for restoring defaults. +static void reinit_ep_with_reloc(int allow_reloc) { + nvte_ep_shutdown(); + NVTEEpGroupConfig cfg{}; + cfg.ep_size = g_ep_size; + cfg.num_experts = g_num_experts; + cfg.max_tokens_per_rank = g_max_tokens_per_rank; + cfg.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + cfg.hidden_dim = g_hidden_dim; + cfg.allow_handle_mem_reloc = allow_reloc; + nvte_ep_initialize(static_cast(g_ep_comm), cfg); +} + +TEST_F(HandleCacheTest, ReuseSameMemSucceeds) { + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // Two consecutive round-trips on the same handle_mem ptr: first opens the + // cached handle, second hits the cache. Both must succeed and be correct. + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + } + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(HandleCacheTest, RelocDefaultThrows) { + // Default bootstrap has allow_handle_mem_reloc=0: a second prepare call on + // the same handle_id with a different handle_mem ptr must throw. + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + DevBuf second_hm(b.handle_mem_size); // distinct device buffer + ASSERT_NE(b.handle_mem.get(), second_hm.get()); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto hm1_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto hm2_t = make_nvte_tensor(second_hm.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // First prepare seeds the cache. + NVTEEpHandle h1{b.handle_id, hm1_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h1, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + // Same handle_id with a different handle_mem ptr must throw. + NVTEEpHandle h2{b.handle_id, hm2_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h2, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(HandleCacheTest, RelocAllowedRebuilds) { + // Re-init EP backend with allow_handle_mem_reloc=1, run two round-trips with + // distinct handle_mem buffers, verify both succeed numerically, restore. + reinit_ep_with_reloc(/*allow_reloc=*/1); + + struct Restore { ~Restore() { reinit_ep_with_reloc(/*allow_reloc=*/0); } } restore; + + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + DevBuf alt_hm(b.handle_mem_size); + ASSERT_NE(b.handle_mem.get(), alt_hm.get()); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // First on the original handle_mem. + ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + // Then on the relocated handle_mem — must trigger silent rebuild, not throw. + ASSERT_NO_THROW(run_round_trip(b, alt_hm.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + // ── main ────────────────────────────────────────────────────────────────────── int main(int argc, char* argv[]) { diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index ae0f3ab888..6494a86817 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -57,32 +57,6 @@ inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWind return desc; } -// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. -class ScopedEpHandle { - public: - ScopedEpHandle() = default; - explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} - ~ScopedEpHandle() { - if (h_ != nullptr) ncclEpHandleDestroy(h_); - } - ScopedEpHandle(const ScopedEpHandle&) = delete; - ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; - ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } - ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { - if (this != &other) { - if (h_ != nullptr) ncclEpHandleDestroy(h_); - h_ = other.h_; - other.h_ = nullptr; - } - return *this; - } - operator ncclEpHandle_t() const { return h_; } - ncclEpHandle_t get() const { return h_; } - - private: - ncclEpHandle_t h_ = nullptr; -}; - } // namespace // --------------------------------------------------------------------------- @@ -158,6 +132,13 @@ void EPBackend::shutdown() { EPBackend& inst = instance(); std::lock_guard lock(inst.mutex_); if (!inst.initialized_) return; + for (auto& kv : inst.handles_) { + if (kv.second.cached_handle != nullptr) { + ncclEpHandleDestroy(kv.second.cached_handle); + kv.second.cached_handle = nullptr; + kv.second.cached_handle_mem = nullptr; + } + } inst.handles_.clear(); // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. if (inst.ep_group_ != nullptr) { @@ -196,7 +177,7 @@ ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { return ncclFloat32; // unreachable } -// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +// Open a fresh ncclEpHandle over handle_mem. Caller (or cache) owns the result. ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, size_t dispatch_output_per_expert_alignment) { size_t hm_sizes[1] = {handle_mem_size}; @@ -273,6 +254,26 @@ EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { return it->second; } +ncclEpHandle_t EPBackend::get_or_open_handle(HandleEntry& cfg, void* handle_mem) { + if (cfg.cached_handle != nullptr && cfg.cached_handle_mem == handle_mem) { + return cfg.cached_handle; + } + if (cfg.cached_handle != nullptr) { + NVTE_CHECK(group_config_.allow_handle_mem_reloc != 0, + "EP handle_mem relocated for cached handle (old=", + reinterpret_cast(cfg.cached_handle_mem), + ", new=", reinterpret_cast(handle_mem), + "). Set NVTEEpGroupConfig.allow_handle_mem_reloc=1 to allow rebuild."); + ncclEpHandleDestroy(cfg.cached_handle); + cfg.cached_handle = nullptr; + cfg.cached_handle_mem = nullptr; + } + ncclEpHandle_t h = open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment); + cfg.cached_handle = h; + cfg.cached_handle_mem = handle_mem; + return h; +} + // --------------------------------------------------------------------------- // Per-step operations // --------------------------------------------------------------------------- @@ -320,17 +321,13 @@ void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETenso ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, - "ep_prepare: alignment mismatch for handle_id=", handle_id, - " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); } void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, @@ -397,14 +394,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpDispatch(h, &in_struct, &out_struct, /*layout_info=*/nullptr, &dispatch_cfg, stream)); } @@ -436,14 +429,10 @@ void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor e ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; out_struct.tokens = &nccl_result_out; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, /*config=*/nullptr, stream)); } void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, @@ -491,14 +480,10 @@ void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETen ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; cfg.pass_direction = NCCL_EP_BWD_PASS; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& entry = lookup_config(handle_id); - transient = ScopedEpHandle( - open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); - } - NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(entry, handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, &cfg, stream)); } void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h index 18307ebb4f..e82c974c3f 100644 --- a/transformer_engine/common/ep/ep_backend.h +++ b/transformer_engine/common/ep/ep_backend.h @@ -98,6 +98,10 @@ class EPBackend { size_t handle_mem_size; size_t alignment; int top_k; + // Persistent ncclEpHandle bound to cached_handle_mem. Lazily opened on first + // op; reused while handle_mem ptr is unchanged. Destroyed in shutdown(). + ncclEpHandle_t cached_handle{nullptr}; + void* cached_handle_mem{nullptr}; }; std::unordered_map handles_; std::atomic next_handle_id_{1}; // 0 reserved as "no id" @@ -106,6 +110,10 @@ class EPBackend { // Caller must hold mutex_. Throws on cap overflow. uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); HandleEntry& lookup_config(uint64_t handle_id); + // Caller must hold mutex_. Returns the cached handle if handle_mem matches. + // On mismatch: if group_config_.allow_handle_mem_reloc != 0, destroys the + // stale handle and opens a fresh one; otherwise throws. + ncclEpHandle_t get_or_open_handle(HandleEntry& cfg, void* handle_mem); }; } // namespace ep From ed3d73cc84a215cd4d7c2c87db8eb6eae7e44b5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 May 2026 23:09:15 +0000 Subject: [PATCH 03/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/ep/ep_backend.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 6494a86817..83657943a4 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -324,8 +324,8 @@ void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETenso std::lock_guard lock(mutex_); HandleEntry& cfg = lookup_config(handle_id); NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, - "ep_prepare: alignment mismatch for handle_id=", handle_id, - " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + "ep_prepare: alignment mismatch for handle_id=", handle_id, " (cached=", cfg.alignment, + ", got=", dispatch_output_per_expert_alignment, ")"); ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); } From 1923180bff3c02f0c95a91d997ce3a1301d414ec Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 14:12:53 -0700 Subject: [PATCH 04/29] Build: NCCL_HOME discovery supports Debian/Ubuntu multiarch lib paths Signed-off-by: Phuong Nguyen --- setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index db360c8a29..34a3abfd99 100644 --- a/setup.py +++ b/setup.py @@ -167,11 +167,13 @@ def _discover_nccl_home() -> str: f"'{env_home}/include/nccl.h' was not found; falling back to system probes." ) + lib_names = ("libnccl.so", "libnccl.so.2") + # Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu). + lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu") for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): p = Path(cand) if (p / "include" / "nccl.h").exists() and any( - (p / "lib" / name).exists() or (p / "lib64" / name).exists() - for name in ("libnccl.so", "libnccl.so.2") + (p / sub / name).exists() for sub in lib_subdirs for name in lib_names ): return str(p) @@ -180,9 +182,11 @@ def _discover_nccl_home() -> str: for line in out.splitlines(): if "libnccl.so" in line and "=>" in line: lib_path = Path(line.split("=>")[-1].strip()) - root = lib_path.parent.parent - if (root / "include" / "nccl.h").exists(): - return str(root) + # Walk upward so multiarch layouts (.../lib//libnccl.so) + # resolve to the prefix that contains include/nccl.h. + for root in (lib_path.parent.parent, lib_path.parent.parent.parent): + if (root / "include" / "nccl.h").exists(): + return str(root) except (subprocess.CalledProcessError, FileNotFoundError): pass From 3b8aafb0bd81f5d0d18bd633ac679905c7b47673 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 14:26:39 -0700 Subject: [PATCH 05/29] bump NCCL Signed-off-by: Phuong Nguyen --- 3rdparty/nccl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/nccl b/3rdparty/nccl index 6a9bc953ac..146496ac88 160000 --- a/3rdparty/nccl +++ b/3rdparty/nccl @@ -1 +1 @@ -Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 +Subproject commit 146496ac881bc504ed1a52be0ae7b707ce41e706 From 9b225cbed1834d235234d9850ee6ee20f1f64c15 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:25:16 -0700 Subject: [PATCH 06/29] Expert Parallelism: require token_dtype in NVTEEpGroupConfig and enforce at dispatch Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep_common.h | 4 ++++ transformer_engine/common/ep/ep_backend.cpp | 21 +++++++++++++++---- .../common/include/transformer_engine/ep.h | 3 +++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index 77baa92b0c..ccb20ee3a0 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -74,6 +74,7 @@ static int g_ep_size = -1; static int g_num_experts = -1; static int g_hidden_dim = 256; static int g_max_tokens_per_rank = 64; +static NVTEDType g_token_dtype = kNVTEBFloat16; static bool g_ep_initialized = false; static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown @@ -224,6 +225,8 @@ static void ep_parse_args(int argc, char* argv[]) { else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + else if (a.rfind("--token-dtype=", 0) == 0) + g_token_dtype = static_cast(std::stoi(a.substr(14))); } if (g_process_id < 0 || g_num_processes <= 0) { @@ -279,6 +282,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; group_config.hidden_dim = g_hidden_dim; + group_config.token_dtype = g_token_dtype; ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); nvte_ep_initialize(static_cast(g_ep_comm), group_config); diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 83657943a4..1e08cb55df 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -82,9 +82,13 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, - "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", - config.hidden_dim); + NVTE_CHECK(config.token_dtype >= 0 && config.token_dtype < kNVTENumTypes, + "token_dtype out of range, got ", static_cast(config.token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(config.token_dtype)); + NVTE_CHECK(config.hidden_dim * elem_bytes >= 16, + "hidden_dim * sizeof(token_dtype) must be >= 16 (NCCL EP 16B row alignment); " + "got hidden_dim=", + config.hidden_dim, ", element_bytes=", elem_bytes); NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, ") must be divisible by ep_size (", config.ep_size, ")"); NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", @@ -214,7 +218,8 @@ void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; cfg.num_experts = static_cast(group_config.num_experts); cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); - cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + const size_t elem_bytes = typeToSize(static_cast(group_config.token_dtype)); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); cfg.rdma_buffer_size = NCCL_EP_AUTO; cfg.num_qp_per_rank = NCCL_EP_AUTO; cfg.num_channels = NCCL_EP_AUTO; @@ -341,6 +346,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape tok_shape = nvte_tensor_shape(tokens); NVTEDType tok_dtype = nvte_tensor_type(tokens); + NVTE_CHECK(tok_dtype == group_config_.token_dtype, + "tokens dtype (", static_cast(tok_dtype), + ") does not match group token_dtype (", + static_cast(group_config_.token_dtype), ")"); const size_t num_tokens = tok_shape.data[0]; const size_t hidden_dim = tok_shape.data[1]; @@ -367,6 +376,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + NVTE_CHECK(recv_dtype == group_config_.token_dtype, + "recv_tokens dtype (", static_cast(recv_dtype), + ") does not match group token_dtype (", + static_cast(group_config_.token_dtype), ")"); size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 8c3a06b5f0..ac7f1dbf07 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -35,6 +35,9 @@ typedef struct { int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ int allow_handle_mem_reloc; + /*! Token dtype for this EP group. Sizes NCCL EP staging buffers at group + * create and is enforced against tensors passed to nvte_ep_dispatch. */ + NVTEDType token_dtype; } NVTEEpGroupConfig; /*! \brief Per-layer EP configuration. */ From 03e56d221d28fe129eeede510168723ee2d26d68 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:31:47 -0700 Subject: [PATCH 07/29] Expert Parallelism: document ep_comm lifetime, v0.1 single-GPU scope, static layer registration Signed-off-by: Phuong Nguyen --- .../common/include/transformer_engine/ep.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index ac7f1dbf07..a1c9305e9b 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -54,8 +54,13 @@ typedef struct { /*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. * * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * The caller retains ownership and must keep ep_comm alive until + * nvte_ep_shutdown() returns; destroying it earlier is undefined behavior. * Re-init after shutdown is allowed; double-init throws. * + * v0.1 scope: one EP group per process, bound to the current CUDA device at + * initialize time. Multiple GPUs per process are not supported. + * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. * \param[in] group_config Group-level EP configuration. */ @@ -69,6 +74,11 @@ void nvte_ep_shutdown(void); /*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer * size the caller must allocate. Host-only. * + * Registration is intended to be static (once per layer at model init). There is + * no per-layer unregister API; all registrations are released by nvte_ep_shutdown. + * Re-registering the same layer config each step is not supported and will + * eventually exhaust the handle cache (NVTE_EP_HANDLE_CACHE_SIZE, default 8192). + * * \param[in] layer_config Per-layer EP configuration. * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. * \return uint64_t handle_id (non-zero). From 4cefdcb2ad71f95be154516a4234a44c50eef641 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:32:48 -0700 Subject: [PATCH 08/29] Expert Parallelism: drop version label from initialize scope note Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/CMakeLists.txt | 45 +- tests/cpp_distributed/run_test_ep.sh | 123 +--- .../{test_ep_pipeline.cu => test_ep.cu} | 643 ++++++++---------- tests/cpp_distributed/test_ep_common.h | 194 +----- tests/cpp_distributed/test_ep_coverage.cu | 562 --------------- tests/cpp_distributed/test_ep_init.cu | 64 -- transformer_engine/common/ep/ep_backend.cpp | 25 +- .../common/include/transformer_engine/ep.h | 13 +- transformer_engine/common/util/logging.h | 8 + 9 files changed, 376 insertions(+), 1301 deletions(-) rename tests/cpp_distributed/{test_ep_pipeline.cu => test_ep.cu} (51%) delete mode 100644 tests/cpp_distributed/test_ep_coverage.cu delete mode 100644 tests/cpp_distributed/test_ep_init.cu diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 3870f57911..7dd8ea33e7 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -73,10 +73,8 @@ target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) -# ── EP distributed tests (HT mode) ───────────────────────────────────────── -# No MPI dependency — processes are spawned by run_test_ep.sh with -# --rank / --nranks flags. ncclUniqueId exchange uses a -# shared temp file (see test_ep_common.h for details). +# ── EP distributed tests ────────────────────────────────────────────────────── +# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h). # Headers + libs come from the in-tree 3rdparty/nccl submodule build. set(NCCL_EP_SUBMODULE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") @@ -103,41 +101,28 @@ endif() set(EP_TEST_COMMON_INCLUDES ${EP_TEST_NCCL_INCLUDES} + ${MPI_CXX_INCLUDE_PATH} ../../transformer_engine/common/include ../../transformer_engine/common ${CMAKE_CURRENT_SOURCE_DIR}) +# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so +# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc. set(EP_TEST_COMMON_LIBS CUDA::cuda_driver CUDA::cudart - CUDA::nvrtc GTest::gtest ${TE_LIB} + CUDA::nvrtc ${NCCL_LIB} - ${NCCL_EP_LIB}) - -# nvrtc symbols are referenced from libtransformer_engine.so but not in its -# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc -# explicitly with --no-as-needed so the linker keeps the dependency. -set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") - -# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── -add_executable(test_ep_init test_ep_init.cu) -target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) - -# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── -add_executable(test_ep_pipeline test_ep_pipeline.cu) -target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) - -# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── -add_executable(test_ep_coverage test_ep_coverage.cu) -target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + ${NCCL_EP_LIB} + MPI::MPI_CXX + OpenMP::OpenMP_CXX) + +# ── EP distributed tests (per-op + full pipeline + zero-copy symm) ─────────── +add_executable(test_ep test_ep.cu ../cpp/test_common.cu) +target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS}) # Do NOT use gtest_discover_tests — these binaries require multi-process # launch via run_test_ep.sh, not direct single-process execution. diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh index 017d3f807b..13e86fa02d 100755 --- a/tests/cpp_distributed/run_test_ep.sh +++ b/tests/cpp_distributed/run_test_ep.sh @@ -3,12 +3,8 @@ # # See LICENSE for license information. # -# Run TE EP distributed unit tests across multiple GPUs. -# -# Spawns one background bash process per GPU (no MPI dependency), matching the -# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared -# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and -# passes it to nvte_ep_initialize. +# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU +# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast. # # Usage: # bash run_test_ep.sh [num_gpus] [build_dir] @@ -18,15 +14,16 @@ # build_dir = /build # # Environment variables: -# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") -# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) +# GTEST_FILTER — forwarded to all processes (e.g., "EPPipelineTest.*") +# MPIRUN — override the mpirun binary (default: mpirun) +# MPIRUN_EXTRA — extra flags forwarded to mpirun set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_DIR="${2:-${SCRIPT_DIR}/build}" NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" -TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" +MPIRUN="${MPIRUN:-mpirun}" # Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ @@ -36,102 +33,22 @@ if (( MIN_SM > 0 && MIN_SM < 90 )); then exit 0 fi -GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" -OVERALL_FAIL=0 - -# --------------------------------------------------------------------------- -# run_suite BINARY SUITE_NAME MIN_GPUS -# --------------------------------------------------------------------------- -run_suite() { - local BINARY="$1" - local SUITE_NAME="$2" - local MIN_GPUS="${3:-2}" - - local TEST_BIN="${BUILD_DIR}/${BINARY}" - - if [[ ! -x "${TEST_BIN}" ]]; then - echo "ERROR: binary not found: ${TEST_BIN}" - echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" - OVERALL_FAIL=1 - return - fi - - if (( NUM_GPUS < MIN_GPUS )); then - echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." - return - fi - - local TMPDIR_L="${TMPDIR:-/tmp}" - local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" - rm -f "${UID_FILE}" - - local LOG_DIR - LOG_DIR=$(mktemp -d) - local FAIL=0 - - echo "=== ${SUITE_NAME} ===" - echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" - echo - - # Spawn one background process per GPU. ncclUniqueId is exchanged via the - # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. - local PIDS=() - for i in $(seq 0 $((NUM_GPUS - 1))); do - timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ - "${TEST_BIN}" \ - --rank="${i}" \ - --nranks="${NUM_GPUS}" \ - --uid-file="${UID_FILE}" \ - ${GTEST_ARGS} \ - > "${LOG_DIR}/rank_${i}.log" 2>&1 & - PIDS+=($!) - done - for i in $(seq 0 $((NUM_GPUS - 1))); do - if ! wait "${PIDS[$i]}"; then - local rc=$? - FAIL=1 - if [[ $rc -eq 137 || $rc -eq 124 ]]; then - echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" - fi - fi - done - - echo "--- Rank 0 output ---" - cat "${LOG_DIR}/rank_0.log" - - if (( FAIL )); then - for i in $(seq 1 $((NUM_GPUS - 1))); do - echo "--- Rank ${i} output ---" - cat "${LOG_DIR}/rank_${i}.log" - done - echo "=== ${SUITE_NAME}: FAILED ===" - OVERALL_FAIL=1 - else - echo "=== ${SUITE_NAME}: ALL PASSED ===" - fi - - rm -rf "${LOG_DIR}" - rm -f "${UID_FILE}" -} +TEST_BIN="${BUILD_DIR}/test_ep" +if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + exit 1 +fi -# --------------------------------------------------------------------------- -# Cleanup on abort -# --------------------------------------------------------------------------- -cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } -trap cleanup EXIT INT TERM +if (( NUM_GPUS < 2 )); then + echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping." + exit 0 +fi -# --------------------------------------------------------------------------- -# Run all suites -# --------------------------------------------------------------------------- -run_suite "test_ep_init" "EP Init Tests" 2 -run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 -run_suite "test_ep_coverage" "EP Coverage Tests" 2 +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +echo "=== EP Tests ===" +echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" echo -if (( OVERALL_FAIL )); then - echo "=== SOME SUITES FAILED ===" -else - echo "=== ALL SUITES PASSED ===" -fi -exit "${OVERALL_FAIL}" +"${MPIRUN}" -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} "${TEST_BIN}" ${GTEST_ARGS} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep.cu similarity index 51% rename from tests/cpp_distributed/test_ep_pipeline.cu rename to tests/cpp_distributed/test_ep.cu index 41f83a6d11..bcf4ca3c98 100644 --- a/tests/cpp_distributed/test_ep_pipeline.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -39,10 +39,21 @@ static inline float token_value(int rank, int t, int num_tokens) { return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); } -static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { - std::vector v(num_tokens * hidden_dim); +// Per-element host-side conversion helpers used by templated test code. +inline float tok_to_float(nv_bfloat16 v) { return __bfloat162float(v); } +inline float tok_to_float(__half v) { return __half2float(v); } +inline float tok_to_float(float v) { return v; } + +template T tok_from_float(float v); +template <> inline nv_bfloat16 tok_from_float(float v) { return __float2bfloat16(v); } +template <> inline __half tok_from_float<__half> (float v) { return __float2half(v); } +template <> inline float tok_from_float (float v) { return v; } + +template +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); for (int t = 0; t < num_tokens; ++t) { - nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + T val = tok_from_float(token_value(rank, t, num_tokens)); for (int h = 0; h < hidden_dim; ++h) v[t * hidden_dim + h] = val; } @@ -85,17 +96,20 @@ static std::vector expected_recv_values_sorted( return vals; } -// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for -// accumulation noise inside dispatch/combine. +// 2^-5 relative tolerance for BF16 (matches mantissa precision with margin), +// plus a small atol floor for near-zero expected values. +static constexpr float kBf16Rtol = 1.0f / 32.0f; +static constexpr float kBf16Atol = 1e-3f; static float bf16_tol(float magnitude) { - return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); + return kBf16Atol + kBf16Rtol * std::fabs(magnitude); } -static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { - std::vector h(count); - cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); +template +static bool check_no_nan_inf(const T* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(T), cudaMemcpyDeviceToHost); for (int i = 0; i < count; ++i) { - float v = __bfloat162float(h[i]); + float v = tok_to_float(h[i]); if (std::isnan(v) || std::isinf(v)) { fprintf(stderr, "Rank %d: %s in %s[%d]\n", g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); @@ -107,20 +121,21 @@ static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name // ── Forward buffer set with RAII ────────────────────────────────────────────── +template struct EPBuffers { // Forward DevBuf topk_idx; DevBuf topk_weights; - DevBuf tokens; + DevBuf tokens; DevBuf token_counts; DevBuf handle_mem; - DevBuf recv_tokens; + DevBuf recv_tokens; DevBuf recv_topk_weights; - DevBuf result; + DevBuf result; // Backward - DevBuf grad_result; - DevBuf grad_expert; - DevBuf grad_tokens; + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; DevBuf g_recv_topk_weights; DevBuf grad_topk_weights; @@ -154,42 +169,45 @@ struct EPBuffers { } }; -// Bundled NVTETensor views over an EPBuffers — one place to update the shape -// conventions when the C-API evolves. +// Bundled NVTETensor views over an EPBuffers, with the shapes the EP C API +// expects. +template struct EPTensors { - TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; - TensorHandle recv_tokens, recv_topk_weights, result; - TensorHandle grad_result, grad_expert, grad_tokens; - TensorHandle g_recv_topk_weights, grad_topk_weights; + TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper recv_tokens, recv_topk_weights, result; + TensorWrapper grad_result, grad_expert, grad_tokens; + TensorWrapper g_recv_topk_weights, grad_topk_weights; - EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, int num_local_experts) { - topk_idx = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - topk_weights = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - token_counts = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts}, kNVTEInt32); - handle_mem = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - tokens = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - recv_tokens = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - result = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - grad_result = make_nvte_tensor(b.grad_result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - grad_expert = make_nvte_tensor(b.grad_expert.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - grad_tokens = make_nvte_tensor(b.grad_tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + constexpr DType kTokDType = test::TypeInfo::dtype; + using Shape = std::vector; + topk_idx = TensorWrapper(b.topk_idx.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); + topk_weights = TensorWrapper(b.topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + token_counts = TensorWrapper(b.token_counts.get(), + Shape{(size_t)num_local_experts}, DType::kInt32); + handle_mem = TensorWrapper(b.handle_mem.get(), + Shape{b.handle_mem_size}, DType::kByte); + tokens = TensorWrapper(b.tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + recv_tokens = TensorWrapper(b.recv_tokens.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + recv_topk_weights = TensorWrapper(b.recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + result = TensorWrapper(b.result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_result = TensorWrapper(b.grad_result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_expert = TensorWrapper(b.grad_expert.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + grad_tokens = TensorWrapper(b.grad_tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + g_recv_topk_weights = TensorWrapper(b.g_recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + grad_topk_weights = TensorWrapper(b.grad_topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); } }; @@ -215,29 +233,31 @@ class EpOpTestBase : public ::testing::Test { num_tokens_ = 32; } - void upload_inputs(EPBuffers& buf, int rank = -1) { + template + void upload_inputs(EPBuffers& buf, int rank = -1) { if (rank < 0) rank = g_process_id; auto h_idx = routing_balanced(rank, num_tokens_, top_k_, num_experts_, num_local_experts_); std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); - auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); - - CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(T), cudaMemcpyHostToDevice)); } NVTEEpLayerConfig layer_config(size_t alignment = 0) const { return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; } - // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. - int read_total_recv(const EPBuffers& buf) const { + // NVTE_CHECK_CUDA (fprintf+exit) so this non-void helper stays legal. + template + int read_total_recv(const EPBuffers& buf) const { std::vector cnt(num_local_experts_); - ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); int total = 0; for (int c : cnt) total += c; @@ -252,28 +272,28 @@ class EpOpTestBase : public ::testing::Test { class EPDispatchTest : public EpOpTestBase {}; TEST_F(EPDispatchTest, PrepareAndDispatch) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + NVTE_CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); // 1. Per-expert counts. std::vector got_counts(num_local_experts_); - CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, num_experts_, num_local_experts_); @@ -288,7 +308,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { // 2. Recv values: read only the filled prefix per local-expert zone, not the // whole recv buffer — avoids false positives from legitimate-zero token values. std::vector h_recv(buf.recv_capacity * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); std::vector got_vals; @@ -312,7 +332,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). std::vector h_w(buf.recv_capacity); - CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); const float exp_w = 1.0f / static_cast(top_k_); for (int i = 0; i < total_recv; ++i) @@ -321,7 +341,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { if (g_process_id == 0) printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -331,34 +351,32 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { class EPCombineTest : public EpOpTestBase {}; TEST_F(EPCombineTest, Combine) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - // Spot-check 3 hidden-dim positions per token to catch partial-row writes. - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; for (int tok = 0; tok < num_tokens_; ++tok) { float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); - for (int p : probes) { + for (int p = 0; p < hidden_dim_; ++p) { float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); EXPECT_NEAR(got, exp, bf16_tol(exp)) << "token " << tok << " rank " << g_process_id << " hidden " << p; @@ -368,7 +386,7 @@ TEST_F(EPCombineTest, Combine) { if (g_process_id == 0) printf(" Combine: passed (result == top_k * tokens)\n"); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -378,41 +396,41 @@ TEST_F(EPCombineTest, Combine) { class EPCombineBwdTest : public EpOpTestBase {}; TEST_F(EPCombineBwdTest, CombineBwdCheck) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), h_grad_r.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); int total_recv = read_total_recv(buf); std::vector cnt(num_local_experts_); - CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); std::vector h_ge(buf.recv_capacity * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Walk filled slots by per-expert zone (no v != 0 heuristic). @@ -421,9 +439,12 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { int filled = 0; for (int e = 0; e < num_local_experts_; ++e) { for (int i = 0; i < cnt[e]; ++i) { - float v = __bfloat162float(h_ge[slot * hidden_dim_]); - EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) - << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + for (int p = 0; p < hidden_dim_; ++p) { + float v = __bfloat162float(h_ge[slot * hidden_dim_ + p]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i + << " (linear " << slot << ") hidden " << p; + } ++filled; ++slot; } } @@ -432,7 +453,7 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { if (g_process_id == 0) printf(" CombineBwdCheck: passed (filled=%d)\n", filled); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -442,51 +463,53 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { class EPDispatchBwdTest : public EpOpTestBase {}; TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), h_grad.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_gt(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); const float kExpGrad = static_cast(top_k_) * 0.1f; for (int tok = 0; tok < num_tokens_; ++tok) - EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) - << "grad_tokens token " << tok; + for (int p = 0; p < hidden_dim_; ++p) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_ + p]), kExpGrad, + bf16_tol(kExpGrad)) + << "grad_tokens token " << tok << " hidden " << p; if (g_process_id == 0) printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -496,11 +519,11 @@ TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); // Distinct per-(rank, t, k) weights so each slot carries a unique value. std::vector h_w(num_tokens_ * top_k_); @@ -508,39 +531,39 @@ TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { for (int k = 0; k < top_k_; ++k) h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + 0.0001f * (g_process_id + 1); - CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. std::vector h_nan(num_tokens_ * top_k_, std::numeric_limits::quiet_NaN()); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), h_nan.size() * sizeof(float), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); // g_recv_topk_weights := recv_topk_weights (the round-trip input). - auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), - {buf.recv_capacity}, kNVTEFloat32); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, - NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + auto g_recv_t = TensorWrapper(buf.recv_topk_weights.get(), + std::vector{buf.recv_capacity}, DType::kFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), + NVTECommWindow{}, g_recv_t.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_grad_w(num_tokens_ * top_k_); - CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); const float kTol = 1e-5f; @@ -566,57 +589,81 @@ TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= // Integrated FwdBwd: NaN/Inf check end-to-end. // ============================================================================= -class EPPipelineTest : public EpOpTestBase {}; - -TEST_F(EPPipelineTest, FullForwardBackward) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); - - std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), - h_grad.size() * sizeof(nv_bfloat16), - cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); - ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); - - if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); +class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface { + protected: + template + void run_full_forward_backward() { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(Tok), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + } +}; - CHECK_CUDA(cudaStreamDestroy(stream)); +TEST_P(EPPipelineTest, FullForwardBackward) { + const DType dtype = GetParam(); + // NCCL EP backend currently asserts ncclBfloat16 in ncclEpDispatch + // (contrib/nccl_ep/nccl_ep.cc); skip FP16/FP32 until the backend supports them. + if (dtype != DType::kBFloat16) { + GTEST_SKIP() << test::typeName(dtype) << " not yet supported by NCCL EP backend"; + } + switch (dtype) { + case DType::kBFloat16: run_full_forward_backward(); break; + case DType::kFloat16: run_full_forward_backward<__half> (); break; + case DType::kFloat32: run_full_forward_backward (); break; + default: FAIL() << "unsupported token dtype " << static_cast(dtype); + } + if (g_process_id == 0) + printf(" FullForwardBackward[%s]: passed\n", test::typeName(dtype).c_str()); } +INSTANTIATE_TEST_SUITE_P( + Dtypes, EPPipelineTest, + ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32), + [](const ::testing::TestParamInfo& info) { + return test::typeName(info.param); + }); + // ============================================================================= // EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached // to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). @@ -646,9 +693,9 @@ struct SymmBuf { void alloc(size_t n_bytes) { bytes = n_bytes; - ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); - CHECK_CUDA(cudaMemset(ptr, 0, bytes)); - ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NVTE_CHECK_NCCL(ncclMemAlloc(&ptr, bytes)); + NVTE_CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + NVTE_CHECK_NCCL(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, NCCL_WIN_COLL_SYMMETRIC)); } }; @@ -666,34 +713,34 @@ class EPZeroCopyTest : public EpOpTestBase {}; // vs HBM reference (same routing, same input). TEST_F(EPZeroCopyTest, IdentityAllSymm) { // HBM reference run. - EPBuffers ref_buf; + EPBuffers<> ref_buf; ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(ref_buf); - EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t ref_hid = ref_buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, - ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, - NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, - ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, - ref_t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(), ref_t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(), + ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), + NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); std::vector ref_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. - EPBuffers sym_buf; // alloc all buffers except the symm ones. + EPBuffers<> sym_buf; // alloc all buffers except the symm ones. sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(sym_buf); @@ -704,32 +751,32 @@ TEST_F(EPZeroCopyTest, IdentityAllSymm) { // Stage same tokens into the symm-mem input. auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + NVTE_CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. - sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.tokens = TensorWrapper(sym_tokens.ptr, + std::vector{(size_t)num_tokens_, (size_t)hidden_dim_}, DType::kBFloat16); + sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, + std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, DType::kBFloat16); uint64_t sym_hid = sym_buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, - sym_t.tokens.tensor, symm_window(sym_tokens), - sym_t.topk_weights.tensor, NVTECommWindow{}, - sym_t.recv_tokens.tensor, symm_window(sym_recv), - sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, - symm_window(sym_recv), sym_t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(), sym_t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(), + sym_t.tokens.data(), symm_window(sym_tokens), + sym_t.topk_weights.data(), NVTECommWindow{}, + sym_t.recv_tokens.data(), symm_window(sym_recv), + sym_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.recv_tokens.data(), + symm_window(sym_recv), sym_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); std::vector sym_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + NVTE_CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Compare per filled recv slot (HBM ref vs symm) and full result. @@ -744,141 +791,9 @@ TEST_F(EPZeroCopyTest, IdentityAllSymm) { if (g_process_id == 0) printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// Same buffers, 2 iterations — catches window-lifecycle regressions where the -// symm-mem registration goes stale between calls. -TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - - SymmBuf sym_tokens, sym_recv; - sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - for (int iter = 0; iter < 2; ++iter) { - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, symm_window(sym_tokens), - t.topk_weights.tensor, NVTECommWindow{}, - t.recv_tokens.tensor, symm_window(sym_recv), - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, - symm_window(sym_recv), t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - for (int tok = 0; tok < num_tokens_; ++tok) { - float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); - float got = __bfloat162float(h_res[tok * hidden_dim_]); - ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; - } - } - - if (g_process_id == 0) - printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); - - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } -// Full forward+backward with symm-mem on every spec-mandated buffer: -// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. -// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior -// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once -// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is -// understood. Tracked separately. -TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - - // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND - // combine input), grad_result (combine_bwd input), grad_expert - // (combine_bwd output AND dispatch_bwd input). - SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; - sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - - auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - t.grad_result = make_nvte_tensor(sym_grad_result.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, symm_window(sym_tokens), - t.topk_weights.tensor, NVTECommWindow{}, - t.recv_tokens.tensor, symm_window(sym_recv), - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, - symm_window(sym_recv), t.result.tensor, stream)); - - std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), - h_grad.size() * sizeof(nv_bfloat16), - cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, - symm_window(sym_grad_result), t.grad_expert.tensor, - symm_window(sym_grad_expert), stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, - symm_window(sym_grad_expert), - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); - ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); - - std::vector h_gt(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), - h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const float kExpGrad = static_cast(top_k_) * 0.1f; - for (int tok = 0; tok < num_tokens_; ++tok) - EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) - << "grad_tokens token " << tok; - - if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} // ── main ────────────────────────────────────────────────────────────────────── diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index ccb20ee3a0..135a39416e 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -13,157 +13,67 @@ #include #include +#include #include #include +#include #include -#include #include #include #include #include -#include #include #include #include #include +#include "../cpp/test_common.h" +#include "util/logging.h" -// ── Error-checking macros ───────────────────────────────────────────────────── +using transformer_engine::DType; +using transformer_engine::TensorWrapper; -#define CHECK_NCCL(expr) \ - do { \ - ncclResult_t _err = (expr); \ - if (_err != ncclSuccess) \ - FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ - } while (false) - -#define CHECK_CUDA(expr) \ - do { \ - cudaError_t _err = (expr); \ - if (_err != cudaSuccess) \ - FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ - } while (false) - -#define ASSERT_CUDA_OK(expr) \ - do { \ - cudaError_t _err = (expr); \ - if (_err != cudaSuccess) { \ - fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ - exit(EXIT_FAILURE); \ - } \ - } while (false) - -#define ASSERT_NCCL_OK(expr) \ - do { \ - ncclResult_t _err = (expr); \ - if (_err != ncclSuccess) { \ - fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ - exit(EXIT_FAILURE); \ - } \ +#define CHECK_MPI(expr) \ + do { \ + int _err_mpi = (expr); \ + NVTE_CHECK(_err_mpi == MPI_SUCCESS, "MPI error: ", _err_mpi); \ } while (false) // ── Process-level state ─────────────────────────────────────────────────────── static int g_process_id = -1; static int g_num_processes = -1; -static std::string g_uid_file; static int g_sm_major = -1; // set by ep_bootstrap; -1 until then static int g_ep_size = -1; static int g_num_experts = -1; static int g_hidden_dim = 256; static int g_max_tokens_per_rank = 64; -static NVTEDType g_token_dtype = kNVTEBFloat16; +static NVTEDType g_max_token_dtype = kNVTEFloat32; // staging-buffer sizing static bool g_ep_initialized = false; static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown -// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── - -// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. -struct TensorHandle { - NVTETensor tensor = nullptr; - void* dev_ptr = nullptr; - - ~TensorHandle() { - if (tensor) nvte_destroy_tensor(tensor); - } - - TensorHandle() = default; - TensorHandle(const TensorHandle&) = delete; - TensorHandle& operator=(const TensorHandle&) = delete; - - TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { - o.tensor = nullptr; o.dev_ptr = nullptr; - } - TensorHandle& operator=(TensorHandle&& o) noexcept { - if (this != &o) { - if (tensor) nvte_destroy_tensor(tensor); - tensor = o.tensor; dev_ptr = o.dev_ptr; - o.tensor = nullptr; o.dev_ptr = nullptr; - } - return *this; - } -}; - -static TensorHandle make_nvte_tensor(void* dev_ptr, - const std::vector& shape, - NVTEDType dtype) { - TensorHandle h; - h.dev_ptr = dev_ptr; - h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); - - NVTEShape s; - s.ndim = shape.size(); - for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; - - NVTEBasicTensor bt; - bt.data_ptr = dev_ptr; - bt.dtype = dtype; - bt.shape = s; - nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); - - return h; -} - -// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +// RAII owner for a cudaMalloc'd device buffer; element-count API on top of +// test::CudaPtr. template struct DevBuf { - T* ptr = nullptr; + test::CudaPtr ptr; size_t count = 0; DevBuf() = default; explicit DevBuf(size_t n) { alloc(n); } - ~DevBuf() { reset(); } - - DevBuf(const DevBuf&) = delete; - DevBuf& operator=(const DevBuf&) = delete; - DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } - DevBuf& operator=(DevBuf&& o) noexcept { - if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } - return *this; - } void alloc(size_t n) { - reset(); count = n; - if (n > 0) { - cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); - if (e != cudaSuccess) { - fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), - cudaGetErrorString(e)); - ptr = nullptr; - count = 0; - } - } + ptr = (n > 0) ? test::cuda_alloc(n * sizeof(T)) : test::CudaPtr{}; } - void reset() { - if (ptr) { cudaFree(ptr); ptr = nullptr; } + ptr.reset(); count = 0; } - T* get() const { return ptr; } + T* get() const { return ptr.get(); } size_t bytes() const { return count * sizeof(T); } }; @@ -180,39 +90,11 @@ static inline std::vector routing_balanced( return idx; } -// ── File-based ncclUniqueId exchange ───────────────────────────────────────── +// ── ncclUniqueId exchange via MPI ───────────────────────────────────────────── static void exchange_unique_id(ncclUniqueId* uid) { - const size_t sz = sizeof(ncclUniqueId); - - if (g_process_id == 0) { - ASSERT_NCCL_OK(ncclGetUniqueId(uid)); - FILE* f = fopen(g_uid_file.c_str(), "wb"); - if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } - fwrite(uid, 1, sz, f); - fclose(f); - } else { - auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); - while (true) { - FILE* f = fopen(g_uid_file.c_str(), "rb"); - if (f) { - fseek(f, 0, SEEK_END); - if (static_cast(ftell(f)) >= sz) { - fseek(f, 0, SEEK_SET); - size_t n = fread(uid, 1, sz, f); - fclose(f); - if (n == sz) break; - } else { - fclose(f); - } - } - if (std::chrono::steady_clock::now() > deadline) { - fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); - exit(EXIT_FAILURE); - } - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - } + if (g_process_id == 0) NVTE_CHECK_NCCL(ncclGetUniqueId(uid)); + CHECK_MPI(MPI_Bcast(uid, sizeof(*uid), MPI_BYTE, 0, MPI_COMM_WORLD)); } // ── CLI parsing ─────────────────────────────────────────────────────────────── @@ -220,26 +102,8 @@ static void exchange_unique_id(ncclUniqueId* uid) { static void ep_parse_args(int argc, char* argv[]) { for (int i = 1; i < argc; ++i) { std::string a(argv[i]); - if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); - else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); - else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); - else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); - else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); - else if (a.rfind("--token-dtype=", 0) == 0) - g_token_dtype = static_cast(std::stoi(a.substr(14))); - } - - if (g_process_id < 0 || g_num_processes <= 0) { - fprintf(stderr, - "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" - " Aliases: --process-id=N, --num-processes=N\n", - argc > 0 ? argv[0] : "test_ep"); - exit(EXIT_FAILURE); - } - - if (g_uid_file.empty()) { - const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; - g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + if (a.rfind("--max-token-dtype=", 0) == 0) + g_max_token_dtype = static_cast(std::stoi(a.substr(18))); } } @@ -247,6 +111,12 @@ static void ep_parse_args(int argc, char* argv[]) { // Returns false if the binary should exit without running tests (wrong SM, etc.). static bool ep_bootstrap(int argc, char* argv[]) { + int mpi_initialized = 0; + MPI_Initialized(&mpi_initialized); + if (!mpi_initialized) CHECK_MPI(MPI_Init(&argc, &argv)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &g_process_id)); + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &g_num_processes)); + ep_parse_args(argc, argv); ::testing::InitGoogleTest(&argc, argv); @@ -282,9 +152,9 @@ static bool ep_bootstrap(int argc, char* argv[]) { // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; group_config.hidden_dim = g_hidden_dim; - group_config.token_dtype = g_token_dtype; + group_config.max_token_dtype = g_max_token_dtype; - ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); nvte_ep_initialize(static_cast(g_ep_comm), group_config); if (g_process_id == 0) { @@ -308,5 +178,7 @@ static void ep_teardown() { } g_ep_initialized = false; } - if (g_process_id == 0) remove(g_uid_file.c_str()); + int finalized = 0; + MPI_Finalized(&finalized); + if (!finalized) MPI_Finalize(); } diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu deleted file mode 100644 index e9e532386c..0000000000 --- a/tests/cpp_distributed/test_ep_coverage.cu +++ /dev/null @@ -1,562 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/* - * EP C-API coverage tests (paths not exercised by the pipeline suite). - * - * MultiHandleAllocTest — distinct handle ids; each works end-to-end. - * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. - * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. - * NegativeTests — alignment mismatch and null handle_mem must throw. - */ - -#include "test_ep_common.h" - -#include -#include - -// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two -// full experts. Requires top_k >= 2 and num_experts >= 3. -static std::vector routing_skip_middle(int num_tokens, int top_k) { - std::vector idx(num_tokens * top_k); - for (int t = 0; t < num_tokens; ++t) { - idx[t * top_k + 0] = 0; - if (top_k >= 2) idx[t * top_k + 1] = 2; - for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers - } - return idx; -} - -static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { - std::vector v(num_tokens * hidden_dim); - nv_bfloat16 b = __float2bfloat16(val); - std::fill(v.begin(), v.end(), b); - return v; -} - -namespace { - -class EpCoverageBase : public ::testing::Test { - protected: - int ep_size_, num_experts_, num_local_experts_, hidden_dim_; - int max_tokens_per_rank_; - - void SetUp() override { - if (g_sm_major < 9) - GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; - ASSERT_GE(g_num_processes, 2); - ASSERT_TRUE(g_ep_initialized); - ep_size_ = g_ep_size; - num_experts_ = g_num_experts; - num_local_experts_ = num_experts_ / ep_size_; - hidden_dim_ = g_hidden_dim; - max_tokens_per_rank_ = g_max_tokens_per_rank; - } - - // Helper: allocate buffers + tensor views for a single dispatch+combine. - struct Bundle { - DevBuf topk_idx; - DevBuf topk_weights; - DevBuf tokens; - DevBuf token_counts; - DevBuf handle_mem; - DevBuf recv_tokens; - DevBuf recv_topk_weights; - DevBuf result; - uint64_t handle_id = 0; - size_t handle_mem_size = 0; - size_t recv_capacity = 0; - }; - - Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, - size_t alignment) { - Bundle b; - b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; - b.topk_idx.alloc(num_tokens * top_k); - b.topk_weights.alloc(num_tokens * top_k); - b.tokens.alloc(num_tokens * hidden_dim_); - b.token_counts.alloc(num_local_experts); - b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); - b.recv_topk_weights.alloc(b.recv_capacity); - b.result.alloc(num_tokens * hidden_dim_); - NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; - b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); - b.handle_mem.alloc(b.handle_mem_size); - return b; - } -}; - -} // namespace - -// ============================================================================= -// MultiHandleAllocTest: ids are distinct and each is independently usable. -// ============================================================================= - -class MultiHandleAllocTest : public EpCoverageBase {}; - -TEST_F(MultiHandleAllocTest, IdsAreDistinct) { - NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; - const int kN = 8; - std::vector ids(kN); - for (int i = 0; i < kN; ++i) { - size_t sz = 0; - ids[i] = nvte_ep_register_layer(cfg, &sz); - } - for (int i = 0; i < kN; ++i) { - EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; - for (int j = i + 1; j < kN; ++j) - EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; - } -} - -TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { - const int num_tokens = 16, top_k = 2; - Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - for (Bundle* x : {&a, &b}) { - CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - } - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - ASSERT_NE(a.handle_id, b.handle_id); - - auto run_one = [&](Bundle& x) { - auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); - auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); - auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - NVTEEpHandle h{x.handle_id, handle_mem.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, - /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, - NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, - recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, - NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, - result.tensor, stream)); - }; - run_one(a); - run_one(b); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // Both round-trips must produce result == top_k * 0.5 = 1.0. - for (Bundle* x : {&a, &b}) { - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - } - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. -// ============================================================================= - -class TopK1Test : public EpCoverageBase {}; - -TEST_F(TopK1Test, RoundTrip) { - const int num_tokens = 16, top_k = 1; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, - tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, - NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, - NVTECommWindow{}, result_t.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) - << "tok " << t << " hidden " << p; - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives -// tokens. Round-trip must produce result == top_k * tokens regardless of the -// per-expert padding choice. -// ============================================================================= - -class EmptyExpertsTest : public EpCoverageBase, - public ::testing::WithParamInterface {}; - -TEST_P(EmptyExpertsTest, RoundTripCorrect) { - // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. - ASSERT_GE(num_experts_, 3); - const size_t alignment = GetParam(); - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); - - // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 - // tokens between two non-empty experts. - std::vector h_idx = routing_skip_middle(num_tokens, top_k); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); - - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - alignment, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, - tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, - NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, - NVTECommWindow{}, result_t.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // Identity expert + uniform weights: result[t] == top_k * tokens[t]. - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const float expected = static_cast(top_k) * 0.3f; - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) - << "alignment=" << alignment << " tok=" << t << " hidden=" << p; - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, - ::testing::Values(0, 2, 8, 16)); - -// ============================================================================= -// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. -// ============================================================================= - -class NegativeTests : public EpCoverageBase {}; - -TEST_F(NegativeTests, AlignmentMismatchThrows) { - const int num_tokens = 8, top_k = 2; - // Allocate handle for alignment=0, then call prepare with alignment=16. - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/16, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(NegativeTests, NullHandleMemThrows) { - const int num_tokens = 8, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - // Construct a tensor view backed by a null device pointer. - auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// HandleCacheTest: persistent ncclEpHandle is reused across ops on the same -// handle_mem ptr; relocation triggers throw by default and rebuild when -// NVTEEpGroupConfig.allow_handle_mem_reloc=1. -// ============================================================================= - -class HandleCacheTest : public EpCoverageBase {}; - -// Run prepare → dispatch → combine on bundle b. handle_mem_data overrides the -// device ptr used for handle_mem (must be the buffer owned by b unless -// reloc-allowed mode is active). Templated on Bundle because EpCoverageBase:: -// Bundle is declared in a protected section. -template -static void run_round_trip(B& b, void* handle_mem_data, - int num_tokens, int top_k, int num_local_experts, - int hidden_dim, size_t alignment, - cudaStream_t stream) { - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(handle_mem_data, - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, alignment, stream); - nvte_ep_dispatch(h, topk_idx_t.tensor, tokens_t.tensor, NVTECommWindow{}, - topk_weights_t.tensor, NVTECommWindow{}, - recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream); - nvte_ep_combine(h, recv_tokens_t.tensor, NVTECommWindow{}, result_t.tensor, stream); -} - -// Re-bootstrap EP backend with a different allow_handle_mem_reloc setting. -// Reuses the existing g_ep_comm; caller is responsible for restoring defaults. -static void reinit_ep_with_reloc(int allow_reloc) { - nvte_ep_shutdown(); - NVTEEpGroupConfig cfg{}; - cfg.ep_size = g_ep_size; - cfg.num_experts = g_num_experts; - cfg.max_tokens_per_rank = g_max_tokens_per_rank; - cfg.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; - cfg.hidden_dim = g_hidden_dim; - cfg.allow_handle_mem_reloc = allow_reloc; - nvte_ep_initialize(static_cast(g_ep_comm), cfg); -} - -TEST_F(HandleCacheTest, ReuseSameMemSucceeds) { - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // Two consecutive round-trips on the same handle_mem ptr: first opens the - // cached handle, second hits the cache. Both must succeed and be correct. - for (int iter = 0; iter < 2; ++iter) { - ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - } - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(HandleCacheTest, RelocDefaultThrows) { - // Default bootstrap has allow_handle_mem_reloc=0: a second prepare call on - // the same handle_id with a different handle_mem ptr must throw. - const int num_tokens = 8, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - DevBuf second_hm(b.handle_mem_size); // distinct device buffer - ASSERT_NE(b.handle_mem.get(), second_hm.get()); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto hm1_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto hm2_t = make_nvte_tensor(second_hm.get(), - {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // First prepare seeds the cache. - NVTEEpHandle h1{b.handle_id, hm1_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h1, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - // Same handle_id with a different handle_mem ptr must throw. - NVTEEpHandle h2{b.handle_id, hm2_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h2, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(HandleCacheTest, RelocAllowedRebuilds) { - // Re-init EP backend with allow_handle_mem_reloc=1, run two round-trips with - // distinct handle_mem buffers, verify both succeed numerically, restore. - reinit_ep_with_reloc(/*allow_reloc=*/1); - - struct Restore { ~Restore() { reinit_ep_with_reloc(/*allow_reloc=*/0); } } restore; - - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - DevBuf alt_hm(b.handle_mem_size); - ASSERT_NE(b.handle_mem.get(), alt_hm.get()); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // First on the original handle_mem. - ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - // Then on the relocated handle_mem — must trigger silent rebuild, not throw. - ASSERT_NO_THROW(run_round_trip(b, alt_hm.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ── main ────────────────────────────────────────────────────────────────────── - -int main(int argc, char* argv[]) { - if (!ep_bootstrap(argc, argv)) return 0; - int ret = RUN_ALL_TESTS(); - ep_teardown(); - return ret; -} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu deleted file mode 100644 index 08744dfee5..0000000000 --- a/tests/cpp_distributed/test_ep_init.cu +++ /dev/null @@ -1,64 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/* - * Unit tests for EP initialization paths. - * - * Tests: - * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 - * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values - * - * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). - */ - -#include "test_ep_common.h" - -// ── Fixture ─────────────────────────────────────────────────────────────────── - -class EPInitTest : public ::testing::Test { - protected: - void SetUp() override { - if (g_sm_major < 9) - GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; - ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; - ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; - } -}; - -// ── Tests ───────────────────────────────────────────────────────────────────── - -TEST_F(EPInitTest, InitPath) { - int nle = g_num_experts / g_ep_size; - NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; - size_t sz = 0; - (void)nvte_ep_register_layer(cfg, &sz); - ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; - - if (g_process_id == 0) { - printf(" handle_mem : %zu bytes\n", sz); - } -} - -TEST_F(EPInitTest, NumLocalExperts) { - // handle_mem_size should be > 0 for any valid num_local_experts value. - for (int nle : {1, g_num_experts / g_ep_size}) { - NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; - size_t sz = 0; - (void)nvte_ep_register_layer(cfg, &sz); - ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; - if (g_process_id == 0) - printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); - } -} - -// ── main ────────────────────────────────────────────────────────────────────── - -int main(int argc, char* argv[]) { - if (!ep_bootstrap(argc, argv)) return 0; - int ret = RUN_ALL_TESTS(); - ep_teardown(); - return ret; -} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 1e08cb55df..a5ae99b089 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -82,11 +82,11 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.token_dtype >= 0 && config.token_dtype < kNVTENumTypes, - "token_dtype out of range, got ", static_cast(config.token_dtype)); - const size_t elem_bytes = typeToSize(static_cast(config.token_dtype)); + NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, + "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); NVTE_CHECK(config.hidden_dim * elem_bytes >= 16, - "hidden_dim * sizeof(token_dtype) must be >= 16 (NCCL EP 16B row alignment); " + "hidden_dim * sizeof(max_token_dtype) must be >= 16 (NCCL EP 16B row alignment); " "got hidden_dim=", config.hidden_dim, ", element_bytes=", elem_bytes); NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, @@ -218,7 +218,7 @@ void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; cfg.num_experts = static_cast(group_config.num_experts); cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); - const size_t elem_bytes = typeToSize(static_cast(group_config.token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(group_config.max_token_dtype)); cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); cfg.rdma_buffer_size = NCCL_EP_AUTO; cfg.num_qp_per_rank = NCCL_EP_AUTO; @@ -346,10 +346,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape tok_shape = nvte_tensor_shape(tokens); NVTEDType tok_dtype = nvte_tensor_type(tokens); - NVTE_CHECK(tok_dtype == group_config_.token_dtype, - "tokens dtype (", static_cast(tok_dtype), - ") does not match group token_dtype (", - static_cast(group_config_.token_dtype), ")"); + NVTE_CHECK(typeToSize(static_cast(tok_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "tokens dtype (", static_cast(tok_dtype), ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); const size_t num_tokens = tok_shape.data[0]; const size_t hidden_dim = tok_shape.data[1]; @@ -376,10 +376,11 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); - NVTE_CHECK(recv_dtype == group_config_.token_dtype, + NVTE_CHECK(typeToSize(static_cast(recv_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), "recv_tokens dtype (", static_cast(recv_dtype), - ") does not match group token_dtype (", - static_cast(group_config_.token_dtype), ")"); + ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index a1c9305e9b..22e7ec48ac 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -23,6 +23,8 @@ extern "C" { #endif /* ── Config structs ─────────────────────────────────────────────────────── */ +/* TODO: add a struct_size/version field to these configs (and align with other + * TE public structs) once a TE-wide convention for ABI versioning lands. */ /*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ typedef struct { @@ -35,9 +37,10 @@ typedef struct { int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ int allow_handle_mem_reloc; - /*! Token dtype for this EP group. Sizes NCCL EP staging buffers at group - * create and is enforced against tensors passed to nvte_ep_dispatch. */ - NVTEDType token_dtype; + /*! Widest token dtype the group will dispatch. Sizes NCCL EP staging buffers + * at group create. Tensors passed to nvte_ep_dispatch may use any dtype whose + * element size is <= sizeof(max_token_dtype). */ + NVTEDType max_token_dtype; } NVTEEpGroupConfig; /*! \brief Per-layer EP configuration. */ @@ -58,8 +61,8 @@ typedef struct { * nvte_ep_shutdown() returns; destroying it earlier is undefined behavior. * Re-init after shutdown is allowed; double-init throws. * - * v0.1 scope: one EP group per process, bound to the current CUDA device at - * initialize time. Multiple GPUs per process are not supported. + * One EP group per process, bound to the current CUDA device at initialize + * time. Multiple GPUs per process are not supported. * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. * \param[in] group_config Group-level EP configuration. diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index da8b9b377d..3308bd22e4 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -98,6 +98,14 @@ } \ } while (false) +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) + #ifdef NVTE_WITH_CUBLASMP #define NVTE_CHECK_CUBLASMP(expr) \ From d10189603939c54304429649ae54969633835eb4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:43 +0000 Subject: [PATCH 09/29] Expert Parallelism: JAX bindings (FFI, custom_vjp, multi-process tests, MoE example) Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 41 +- examples/jax/ep/ep_moe.py | 396 ++++++++ examples/jax/ep/run_test_ep.sh | 85 ++ tests/jax/multi_process_launch_ep.sh | 67 ++ tests/jax/test_multi_process_ep.py | 690 +++++++++++++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/ep.py | 955 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 19 + transformer_engine/jax/csrc/extensions/ep.cpp | 457 +++++++++ .../jax/csrc/extensions/pybind.cpp | 18 + transformer_engine/jax/ep.py | 303 ++++++ transformer_engine/jax/sharding.py | 12 +- 12 files changed, 3041 insertions(+), 3 deletions(-) create mode 100644 examples/jax/ep/ep_moe.py create mode 100755 examples/jax/ep/run_test_ep.sh create mode 100755 tests/jax/multi_process_launch_ep.sh create mode 100644 tests/jax/test_multi_process_ep.py create mode 100644 transformer_engine/jax/cpp_extensions/ep.py create mode 100644 transformer_engine/jax/csrc/extensions/ep.cpp create mode 100644 transformer_engine/jax/ep.py diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..49c5001d18 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -103,13 +103,50 @@ def setup_jax_extension( setup_mpi_flags(include_dirs, cxx_flags) + # NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it. + build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))) + libraries = [] + submod_lib_dir = None + submod_nccl_inc = None + if build_with_nccl_ep: + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # Headers + libs come from the in-tree 3rdparty/nccl submodule build + # (auto-produced by setup.py). + libraries = ["nccl", "nccl_ep"] + # NCCL EP requires SM>=90 (Hopper+). + archs_env = os.getenv("NVTE_CUDA_ARCHS", "") + for a in archs_env.split(";"): + a_num = "".join(c for c in a if c.isdigit()) + if a_num and int(a_num) < 90: + raise RuntimeError( + f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in" + " NVTE_CUDA_ARCHS." + ) + submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve() + submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include" + if not (submod_ep_inc / "nccl_ep.h").exists(): + raise RuntimeError( + f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl." + ) + include_dirs.append(submod_ep_inc) + submod_lib_dir = submod_root / "build" / "lib" + submod_nccl_inc = submod_root / "build" / "include" + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension - return Pybind11Extension( + ext = Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"], + libraries=libraries, ) + if submod_lib_dir is not None: + ext.library_dirs.append(str(submod_lib_dir)) + ext.runtime_library_dirs.append(str(submod_lib_dir)) + # Prefer submodule's nccl.h when present (matches the C++ side). + if (submod_nccl_inc / "nccl.h").exists(): + ext.include_dirs.insert(0, str(submod_nccl_inc)) + return ext diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py new file mode 100644 index 0000000000..8dcac02a04 --- /dev/null +++ b/examples/jax/ep/ep_moe.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU. Run via run_test_ep.sh. +""" + +import argparse +import sys + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ── Setup ─────────────────────────────────────────────────────────────────── + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-experts", + type=int, + default=None, + help="Total experts across the EP group. Default: num_processes.", + ) + p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.") + p.add_argument( + "--check", + action="store_true", + default=True, + help="Verify fwd+bwd against a single-rank numpy reference.", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + assert ( + jax.local_device_count() == 1 + ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}" + + +def _build_mesh_and_resource(args): + """Pick a (2, 2) mesh by default. Override via --dp-size.""" + n = args.num_processes + if n < 4: + raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP") + if args.dp_size is None: + if n != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override" + ) + args.dp_size = 2 + assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}" + args.ep_size = n // args.dp_size + if args.num_experts is None: + args.num_experts = args.num_processes + assert args.num_experts % args.ep_size == 0 + args.num_local_experts = args.num_experts // args.ep_size + args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + + devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) + mesh = Mesh(devs, ("dp", "ep")) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + return mesh, mr + + +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" + topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) + for t in range(num_tokens): + for k in range(top_k): + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + return topk_idx + + +def _make_inputs(args): + """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. + + B = num_processes (sharded across the compound (dp,ep) axis so each rank + holds one slot); S = args.num_tokens. Global numpy views (rank-0 + reference) are kept 2D for the legacy reference implementation. + """ + T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out + E = args.num_experts + dp_size = args.dp_size + ep_size = args.ep_size + num_procs = args.num_processes + dp_color = args.process_id // ep_size + + rng_dp = np.random.default_rng(seed=42 + dp_color) + tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) + w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + + tokens_global_np = np.concatenate( + [ + ( + np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5 + ).astype(np.float32) + for c in range(dp_size) + ], + axis=0, + ) + topk_idx_global_np = np.concatenate( + [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 + ) + w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) + + # Same seed on every rank → identical kernel array everywhere. + rng = np.random.default_rng(seed=42) + kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + + # Each rank contributes one [1, T, ...] slab; the global shape is + # [num_procs, T, ...] sharded on the first dim across (dp, ep). + mesh = args.mesh + dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None)) + tokens = jax.make_array_from_process_local_data( + dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) + ).astype(jnp.bfloat16) + topk_idx = jax.make_array_from_process_local_data( + dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) + ) + topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) + kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + return ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) + + +# ── MoE step ──────────────────────────────────────────────────────────────── + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size): + """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]`` + (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``, + broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout + with ``H_out`` in place of ``H``.""" + num_procs, recv_pr, H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H] + grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H) + # Contract H; batch over (ep, NLE) which are present on both sides. + out = jax.lax.dot_general( + grouped, + kernels.astype(grouped.dtype), + dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))), + ) + # Output dim order from dot_general: batch dims first, then remaining lhs, rhs. + # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,) + # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out]. + out = jnp.transpose(out, (2, 0, 1, 3, 4)) + return out.reshape(num_procs, recv_pr, H_out) + + +def _moe_step(args, topk_idx, tokens, topk_w, kernels): + """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. + + Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across + ``("dp","ep")``. Combine returns the same 3D shape. + """ + B = args.num_processes + S = args.num_tokens + NLE = args.num_local_experts + dp_size, ep_size = args.dp_size, args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] + ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] + ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] + # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. + kernel_spec = PartitionSpec("ep", None, None, None) + + kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + + @jax.jit + def step(topk_idx, tokens, topk_w, local_kernels): + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) + slots_per_expert = args.recv_capacity_per_rank // NLE + recv_tokens, recv_topk_w, handle, _tc = ep_dispatch( + topk_idx, + tokens, + topk_w, + args.recv_capacity_per_rank, + dispatch_output_per_expert_alignment=slots_per_expert, + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + return ep_combine( + handle, + _tc, + expert_out, + recv_topk_w, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + return step(topk_idx, tokens, topk_w, kernels) + + +# ── Reference (numerical check) ───────────────────────────────────────────── + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +# ── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + args = _parse_args() + _distributed_init(args) + + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is not None: + major, minor = (int(x) for x in str(cap).split(".")) + if major * 10 + minor < 90: + print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})") + return + + args.mesh, args.mr = _build_mesh_and_resource(args) + + with args.mesh, global_shard_guard(args.mr): + ep_bootstrap( + world_size=args.num_processes, + rank=args.process_id, + ep_size=args.ep_size, + num_experts=args.num_experts, + max_tokens_per_rank=args.num_tokens, + recv_capacity_per_rank=args.recv_capacity_per_rank, + hidden_dim=args.hidden, + ) + + ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) = _make_inputs(args) + + def loss_fn(toks, idx, w, kern): + out = _moe_step(args, idx, toks, w, kern) + return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out + + (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))( + tokens, topk_idx, topk_w, kernels + ) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + + if args.process_id == 0: + print( + f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} " + f"dp={args.dp_size} ep={args.ep_size} " + f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) + + if args.check: + + def _norm(spec, ndim): + return tuple(spec) + (None,) * (ndim - len(spec)) + + # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can + # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both. + if args.dp_size > 1: + acceptable_specs = ((("dp", "ep"), None, None),) + else: + acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None)) + assert ( + _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs + ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})" + assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, ( + f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}" + f" (expected one of {acceptable_specs})" + ) + + replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec()) + out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd) + grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))( + grad_tokens + ) + out_global.block_until_ready() + grad_global.block_until_ready() + + ref_out, ref_grad = _reference_grad( + tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + ) + ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum()) + # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP + # column in a DP color sees identical inputs (and produces identical + # outputs), so collapse the ep dim to one replica before flattening + # to 2D against the dp-only reference. + dp_size, ep_size = args.dp_size, args.ep_size + global_out = ( + np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0] + .reshape(-1, ref_out.shape[-1]) + ) + global_grad = ( + np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] + .reshape(-1, ref_grad.shape[-1]) + ) + if args.process_id == 0: + fwd_diff = np.abs(global_out - ref_out) + grad_diff = np.abs(global_grad - ref_grad) + print( + f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} " + f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)" + ) + print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}") + print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}") + np.testing.assert_allclose( + global_out, + ref_out, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: fwd mismatch", + ) + np.testing.assert_allclose( + global_grad, + ref_grad, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: bwd mismatch", + ) + if args.process_id == 0: + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh new file mode 100755 index 0000000000..55b958f146 --- /dev/null +++ b/examples/jax/ep/run_test_ep.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +#!/bin/bash + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}" + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# NCCL EP requires NVLink P2P among ranks on the node. +echo "*** Checking NVLINK support ***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \ + || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING" + exit 0 +fi +echo "NVLINK support detected" + +SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" +COORD="${COORD:-127.0.0.1:12345}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo +echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***" + +PIDS=() +cleanup() { + for pid in "${PIDS[@]}"; do + kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true + done +} +trap cleanup EXIT INT TERM + +EXTRA_ARGS=${EXTRA_ARGS:-"--check"} + +for ((i=1; i "stdout_rank_${i}.txt" 2>&1 & + PIDS+=($!) +done +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python -u "$SCRIPT" \ + --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \ + $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt +wait + +HAS_FAILURE=0 +if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then + echo "... ep_moe FAILED" + HAS_FAILURE=1 +elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then + echo "... ep_moe INVALID (rank 0 produced no summary line)" + for ((i=1; i/dev/null + done + HAS_FAILURE=1 +else + echo "... ep_moe PASSED" +fi +rm -f stdout_rank_*.txt +exit $HAS_FAILURE diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh new file mode 100755 index 0000000000..a37ffc2952 --- /dev/null +++ b/tests/jax/multi_process_launch_ep.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) + +if [ "${NUM_RUNS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING." + exit 0 +fi +# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" + +OVERALL_RET=0 + +for SCRIPT_NAME in $SCRIPT_NAMES; do + echo "=== Running ${SCRIPT_NAME} ===" + for ((i=1; i stdout_rank_${i}.txt 2>&1 & + done + + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + + wait + + RET=0 + if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 + fi + # Treat missing test summary on rank 0 as hang/crash rather than silent success. + if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then + echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash." + echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output." + RET=1 + fi + if [ "$RET" -ne 0 ]; then + for ((i=1; i/dev/null || echo "(no log)" + done + fi + + rm -f stdout_multi_process.txt stdout_rank_*.txt + if [ "$RET" -ne 0 ]; then + OVERALL_RET=1 + fi +done + +exit "$OVERALL_RET" diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py new file mode 100644 index 0000000000..0658ad9750 --- /dev/null +++ b/tests/jax/test_multi_process_ep.py @@ -0,0 +1,690 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives. + +Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``. +Coverage: + + - ``ep_bootstrap`` rejects when ``ep_resource`` is unset. + - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``) + round-trip an identity expert → output ≈ tokens. + - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form). + - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form). + - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under + skewed upstream gradients (no k-axis averaging). + - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI. + +Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU). +""" + +import os +import sys +import unittest + +import jax +import jax.experimental.multihost_utils as jmu +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.cpp_extensions.ep import ( + ep_prepare, + ep_dispatch_fwd, + ep_combine_fwd, +) + + +# ── Test config ───────────────────────────────────────────────────────────── +# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in +# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even. + +NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_DP_SHARD = 4 # per device along dp + + +def _factor_dp_ep(num_procs): + """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``. + + NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment. + """ + override = os.environ.get("NVTE_TEST_EP_MESH") + if override: + dp_str, ep_str = override.lower().split("x") + dp, ep = int(dp_str), int(ep_str) + if dp * ep != num_procs: + raise ValueError( + f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}" + ) + if (NUM_LOCAL_EXPERTS * ep) % 4 != 0: + raise ValueError( + f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 " + "for NCCL EP TMA alignment" + ) + return dp, ep + if num_procs != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {num_procs}); set " + "NVTE_TEST_EP_MESH=DPxEP to override" + ) + return 2, 2 + + +def _build_mesh(dp, ep): + devs = np.asarray(jax.devices()).reshape(dp, ep) + return Mesh(devs, ("dp", "ep")) + + +def _local_device_sm(): + """Return SM major*10+minor of the first local CUDA device, or None.""" + try: + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is None: + return None + major, minor = (int(x) for x in str(cap).split(".")) + return major * 10 + minor + except Exception: + return None + + +class TestEP(unittest.TestCase): + @classmethod + def setUpClass(cls): + sm = _local_device_sm() + if sm is not None and sm < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})") + cls.num_procs = jax.process_count() + cls.rank = jax.process_index() + cls.dp, cls.ep = _factor_dp_ep(cls.num_procs) + cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep + # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color). + # Under PartitionSpec(("dp","ep"), None) each EP group sees + # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew. + T_per_ep_group = TOKENS_PER_DP_SHARD + active_experts = min(cls.num_experts, T_per_ep_group * TOP_K) + overconc = cls.num_experts // active_experts + cls.recv_capacity_per_rank = ( + NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2 + ) + cls.mesh = _build_mesh(cls.dp, cls.ep) + cls.mr = MeshResource(dp_resource="dp", ep_resource="ep") + with cls.mesh, global_shard_guard(cls.mr): + ep_bootstrap( + world_size=cls.num_procs, + rank=cls.rank, + ep_size=cls.ep, + num_experts=cls.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=cls.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Bootstrap precondition ──────────────────────────────────────────── + + def test_bootstrap_rejects_missing_ep_axis(self): + """ep_bootstrap raises when MeshResource has no ep_resource.""" + with self.mesh, global_shard_guard(MeshResource()): + with self.assertRaisesRegex(ValueError, "ep_resource"): + ep_bootstrap( + world_size=self.num_procs, + rank=self.rank, + ep_size=self.ep, + num_experts=self.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=self.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _make_identity_inputs(self, nonuniform=False): + """Identity routing + uniform weights — combined output ≈ tokens. + + ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced). + ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` — + expert 0 absorbs the entire batch while the others split the second + slot evenly. Exercises a skewed per-expert load. + """ + T_global = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + topk_idx = np.empty((T_global, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_global): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_global): + for k in range(TOP_K): + topk_idx[t, k] = (t * TOP_K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32) + tokens = jnp.asarray( + np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape( + T_global, HIDDEN_DIM + ), + dtype=jnp.bfloat16, + ) + return T_global, topk_idx, tokens, topk_weights + + def _make_random_inputs(self, seed=42, nonuniform=True): + """Random tokens + skewed top-2 routing (top1=0 always; top2 varies). + + Non-uniform load by default — guarantees expert 0 receives every token + while the rest of the experts split the second slot. Use + ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern. + """ + T_dp = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + rng = np.random.default_rng(seed=seed) + tokens = jnp.asarray( + rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_dp): + topk_idx_np[t, 0] = 0 + topk_idx_np[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_dp): + a, b = t % E, (t + 1) % E + topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a) + topk_idx = jnp.asarray(topk_idx_np) + topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) + return T_dp, tokens, topk_idx, topk_weights + + # ── Individual primitives (cpp_extensions level) ────────────────────── + + def test_two_prepares_distinct_handle_ids(self): + """Two ep_prepare sites with matching (top_k, alignment) must produce + distinct handle_ids — distinct logical layers cannot share a + HandleEntry. Verified by tracing through jit so the primitive's + outer_primitive.bind path is exercised.""" + _T, topk_idx, _tokens, _w = self._make_identity_inputs() + captured: list = [] + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + _tc_a, ha = ep_prepare(idx) + _tc_b, hb = ep_prepare(idx) + captured.append((ha.handle_id, hb.handle_id)) + return ha.handle_mem, hb.handle_mem + + hm_a, hm_b = run(idx_s) + hm_a.block_until_ready() + hm_b.block_until_ready() + id_a, id_b = captured[0] + self.assertNotEqual(id_a, id_b, "two ep_prepare calls returned the same handle_id") + + def test_primitive_prepare(self): + """ep_prepare returns the expected shapes and a valid handle id.""" + T_global, topk_idx, _tokens, _w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + tc, handle = ep_prepare(idx) + return tc, handle.handle_mem + + tc, hm = run(idx_s) + tc.block_until_ready() + self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS)) + self.assertEqual(hm.shape[0], self.dp * self.ep) + self.assertGreater(hm.shape[1], 0) + + def _run_identity_round_trip(self, nonuniform): + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + _tc, handle = ep_prepare(idx) + recv_t, recv_w, handle = ep_dispatch_fwd( + handle, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + # Apply the weighted hadamard inline (combine FFI is unweighted). + mask = (recv_w != 0).astype(jnp.float32)[..., None] + weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype( + recv_t.dtype + ) + weighted = jax.lax.with_sharding_constraint( + weighted, NamedSharding(self.mesh, ep_spec_3d) + ) + out = ep_combine_fwd( + handle, weighted, T_global, out_partition_spec=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + # Allgather so the rank-0 numpy comparison sees the full global tensor. + out_global = jmu.process_allgather(out, tiled=True) + + # Identity expert + uniform weights → out ≈ tokens (rank-0 check). + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_dispatch_combine_identity_uniform(self): + """Round-robin routing → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + """Skewed routing (top1=0 always) → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=True) + + def test_primitive_dispatch_combine_identity_bwd_uniform(self): + """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens. + + Identity routing + uniform top-k weights ⇒ dispatch∘combine is the + identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + return 0.5 * (out.astype(jnp.float32) ** 2).sum() + + grad = jax.jit(jax.grad(loss_fn))(tokens) + grad.block_until_ready() + grad_global = jmu.process_allgather(grad, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_3d_input_output(self): + """3D input ``[B, S, H]`` sharded on the first dim only — + ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape + and combine returns a matching 3D ``[B, S, H]`` output. End-to-end + round trip recovers the original tokens under identity routing + + uniform top-k weights.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + # B is sharded across all (dp*ep) ranks; S held in one piece per rank. + B, S, H = T_global, 1, tokens.shape[-1] + tokens_3d = tokens.reshape(B, S, H) + topk_idx_3d = topk_idx.reshape(B, S, -1) + topk_w_3d = topk_w.reshape(B, S, -1) + spec_3d = PartitionSpec(("dp", "ep"), None, None) + out_spec_3d = (("dp", "ep"), None, None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d)) + tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d)) + w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=(B, S), + out_sharding=out_spec_3d, + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + self.assertEqual(out_global.shape, (B, S, H)) + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens_3d.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_dp_only_first_dim(self): + """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must + accept it. JAX SPMD slices the missing ep axis locally so the kernel + still sees ``T/(dp*ep)`` tokens per rank.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_only = PartitionSpec("dp", None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=T_global, + out_sharding=(("dp", "ep"), None), + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + # ── Custom-VJP tests ───────────────────────────────────────────────── + + def test_dispatch_vjp_fwd_bwd(self): + """ep_dispatch fwd + jax.grad w.r.t. tokens. + + Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears + TOP_K times in recv_tokens (all routes fit recv_capacity), so + grad_tokens = TOP_K * tokens (closed form). + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_tokens, _recv_w, _handle, _tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint( + recv_tokens, NamedSharding(self.mesh, ep_spec_3d) + ) + return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum() + + loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) + grad_tokens.block_until_ready() + grad_global = jmu.process_allgather(grad_tokens, tiled=True) + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_tokens.shape, tokens.shape) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)) * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_combine_vjp_fwd_bwd(self): + """ep_combine fwd + jax.grad w.r.t. expert_out. + + Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c + (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled + slots — so max|grad_eo| ≈ c / TOP_K. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + eo_const = 0.5 + expert_out = jnp.full( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(eo): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + _recv_tokens, recv_w, handle, tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) + ) + combined = ep_combine(handle, tc, eo, recv_w, T_global) + # Pin combined to dp-sharded so autodiff transpose feeds + # ep_combine_bwd a per-shard cotangent. + combined = jax.lax.with_sharding_constraint( + combined, NamedSharding(self.mesh, dp_spec) + ) + return 0.5 * (combined.astype(jnp.float32) ** 2).sum() + + loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out) + grad_eo.block_until_ready() + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_eo.shape, expert_out.shape) + for shard in grad_eo.addressable_shards: + arr = np.asarray(shard.data.astype(jnp.float32)) + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots") + np.testing.assert_allclose( + arr.max(), + eo_const / float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_bwd_exact_per_k_topk_weights(self): + """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t. + + Guards against a regression where the bwd would average across the k + axis (per-token mean instead of per-slot exact recovery). + """ + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(idx_in, tok_in, w_in): + idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec)) + tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) + w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) + _recv_t, recv_w, _h, _tc = ep_dispatch( + idx_in, tok_in, w_in, self.recv_capacity_per_rank + ) + # Per-slot index scale ⇒ each slot's contribution differs. + scale = jnp.asarray( + np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0 + ) + return jnp.sum(recv_w * scale) + + grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w) + grad_topk_w.block_until_ready() + grad_global = jmu.process_allgather(grad_topk_w, tiled=True) + + if self.rank == 0: + grad_np = np.asarray(grad_global).astype(np.float32) + mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp)) + self.assertEqual( + mismatch, + 0, + f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed " + f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].", + ) + + # ── HLO reshard guard ──────────────────────────────────────────────── + # Compile-only: assert XLA inserts no cross-device collectives outside + # the EP FFI. EP-axis flux is carried by the FFI itself. + + def test_z_no_unexpected_reshard_in_hlo_fwd(self): + """Compiled fwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + @jax.jit + def run(idx, toks, w): + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + compiled = run.lower(topk_idx, tokens, topk_w).compile() + hlo = compiled.as_text() + # Match instruction names; "all-gather-start" and "all-gather-done" + # bracket a single async all-gather. + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}") + # XLA drops trailing-None entries from the spec; compare as a tuple. + # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep". + expected = (("dp", "ep"),) if self.dp > 1 else ("ep",) + self.assertEqual(tuple(compiled.output_shardings.spec), expected) + + def test_z_no_unexpected_reshard_in_hlo_bwd(self): + """Compiled bwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + rng = np.random.default_rng(seed=44) + expert_out = jnp.asarray( + rng.standard_normal( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32 + ) + * 0.5, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def fwd(eo, toks, idx, w): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + _rt, rw, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) + combined = ep_combine(handle, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) + + # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd + # the expected sharding without relying on XLA-transpose propagation. + def bwd_only(eo, toks, idx, w, g): + _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w) + g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec)) + grads = vjp_fn(g) + return ( + jax.lax.with_sharding_constraint( + grads[0], NamedSharding(self.mesh, ep_spec_3d) + ), + jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)), + ) + + g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16) + compiled = ( + jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile() + ) + hlo = compiled.as_text() + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}") + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +if __name__ == "__main__": + if len(sys.argv) < 4: + print("Usage: python test_multi_process_ep.py ") + sys.exit(1) + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id, + local_device_ids=[proc_id], + ) + + loader = unittest.TestLoader() + target = os.environ.get("TARGET_TEST") + if target: + name = target.split(".")[-1] + suite = loader.loadTestsFromName(name, TestEP) + else: + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index fe1f93dc7a..604da5e1b7 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,4 +10,5 @@ from .softmax import * from .gemm import * from .router import * +from .ep import * from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py new file mode 100644 index 0000000000..7d112ad5f4 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -0,0 +1,955 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for Expert Parallelism (EP). + +Sharding model: + - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim. + Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else + ``ep_resource`` alone. + - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first + dim may be sharded, with axis ∈ {ep, (dp, ep), dp, None}. Trailing dims + must be replicated. ``dp`` alone gets ``ep`` folded in locally. + - EpCombine output sharding comes from ``out_sharding`` or defaults to the + compound ``(dp, ep)`` axis on the leading dim. +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec +import jax.tree_util as jtu + +import transformer_engine_jax +from .base import BasePrimitive, register_primitive +from ..sharding import global_mesh_resource + +__all__ = [ + "EpConfig", + "EpHandle", + "set_ep_config", + "get_ep_config", + "get_ep_num_local_experts", + "ep_allocate_handle_id", + "ep_prepare", + "ep_dispatch_fwd", + "ep_combine_fwd", + "ep_dispatch_bwd", + "ep_combine_bwd", +] + + +# Routing-state container threaded through dispatch/combine/*_bwd. +@jtu.register_pytree_node_class +class EpHandle: + def __init__(self, handle_mem, handle_id): + self.handle_mem = handle_mem + self.handle_id = int(handle_id) + + def tree_flatten(self): + return (self.handle_mem,), (self.handle_id,) + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0], aux[0]) + + def __repr__(self): + return f"EpHandle(handle_id={self.handle_id})" + + +# ── Module-level EP config ────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class EpConfig: + """Immutable Python view of the EP bootstrap config (see ep_bootstrap).""" + + world_size: int + rank: int + ep_size: int + num_experts: int + num_local_experts: int + max_tokens_per_rank: int + recv_capacity_per_rank: int + hidden_dim: int + + +_ep_config: EpConfig = None + + +def set_ep_config(config: EpConfig) -> None: + """Cache the EP config for abstract-eval / sharding helpers. Call once.""" + global _ep_config + _ep_config = config + + +def get_ep_config() -> EpConfig: + if _ep_config is None: + raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") + return _ep_config + + +def get_ep_num_local_experts() -> int: + return get_ep_config().num_local_experts + + +# handle_id -> handle_mem buffer size in bytes. +_HANDLE_MEM_SIZE_BY_ID: dict = {} + + +def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int: + """Reserve a fresh handle_id for an EP layer. + + Distinct logical layers must each call this — sharing a handle_id across + layers corrupts the routing state, even when (top_k, alignment) match. + """ + handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer( + int(top_k), int(dispatch_output_per_expert_alignment) + ) + handle_id = int(handle_id) + _HANDLE_MEM_SIZE_BY_ID[handle_id] = int(handle_mem_size) + return handle_id + + +def _ep_handle_mem_size(handle_id: int) -> int: + """Return the handle_mem byte size for an id from ep_allocate_handle_id.""" + try: + return _HANDLE_MEM_SIZE_BY_ID[int(handle_id)] + except KeyError as e: + raise RuntimeError( + f"handle_id={handle_id} not registered; call ep_allocate_handle_id first." + ) from e + + +def _leading_axis_ok(spec, ep_axis, outer_axes=()): + # Only the first dim may carry sharding; remaining dims must be replicated. + # The first dim's axis must be one of: + # ``ep_axis`` alone, + # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally), + # a tuple ending in ``ep_axis`` with dp/fsdp axes before it. + # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``, + # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``, + # ``("dp", None, None)``. + if len(spec) < 2 or ep_axis is None: + return False + if any(ax is not None for ax in spec[1:]): + return False # only first dim sharded + leading = spec[0] + allowed_outers = {a for a in outer_axes if a is not None} + allowed = allowed_outers | {ep_axis, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +def _canonical_input_spec(spec, ndim): + """Canonical input PartitionSpec the primitive demands JAX deliver. + + Sharding lives entirely on the first dim. If ``spec[0]`` already includes + ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded + into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added + ep axis is a local slice (the missing dim was replicated), no cross-device + comm. + """ + gsr = global_mesh_resource() + ep = gsr.ep_resource + leading = spec[0] + present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else () + if ep in present: + return PartitionSpec(*spec) + if leading is None: + new_leading = ep + elif isinstance(leading, tuple): + new_leading = (*leading, ep) + else: + new_leading = (leading, ep) + return PartitionSpec(new_leading, *([None] * (ndim - 1))) + + +def _dispatch_input_outer_axes(): + """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input.""" + gsr = global_mesh_resource() + return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + + +def _ep_outer_axis(): + """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors. + + When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD + sees each DP color's slab as distinct (rather than replicated across DP). + """ + gsr = global_mesh_resource() + return gsr.dp_resource or gsr.fsdp_resource + + +def _ep_leading_dims(is_outer): + """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when + DP is unset) globally; ``(1,)`` per shard.""" + cfg = get_ep_config() + outer = _ep_outer_axis() + if not is_outer: + return (1,) + return (cfg.world_size,) if outer is not None else (cfg.ep_size,) + + +def _ep_output_spec(*trailing): + """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when + DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``.""" + gsr = global_mesh_resource() + outer = _ep_outer_axis() + if outer is None: + return PartitionSpec(gsr.ep_resource, *trailing) + return PartitionSpec((outer, gsr.ep_resource), *trailing) + + +def _ep_spec_ok(spec, trailing_count): + """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` / + ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP) + on an EP-output tensor's single leading dim. JAX may collapse a size-1 + mesh axis to ``None`` (matters for dp_size=1 like 1x4).""" + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = _ep_outer_axis() + expected_len = 1 + trailing_count + if len(spec) != expected_len: + return False + if any(ax is not None for ax in spec[1:]): + return False + leading = spec[0] + if outer is None: + return leading == ep_axis + allowed = {ep_axis, outer, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +# ── ep_prepare ────────────────────────────────────────────────────────────── + + +class EpPreparePrimitive(BasePrimitive): + name = "te_ep_prepare_ffi" + multiple_results = True + impl_static_args = (1, 2, 3) # handle_id, dispatch_output_per_expert_alignment, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del dispatch_output_per_expert_alignment + cfg = get_ep_config() + num_local_experts = cfg.num_local_experts + assert ( + len(topk_idx_aval.shape) >= 2 + ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}" + handle_mem_size = _ep_handle_mem_size(handle_id) + leading = _ep_leading_dims(is_outer) + token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32) + handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) + # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last + # dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + # TODO(phuong): drop once NCCL EP supports int32 topk_idx. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return token_counts_aval, handle_mem_aval, workspace_aval + + @staticmethod + def outer_abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + avals = EpPreparePrimitive.abstract( + topk_idx_aval, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=True, + ) + return avals[:2] + + @staticmethod + def lowering(ctx, topk_idx, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + return ffi.ffi_lowering(EpPreparePrimitive.name)( + ctx, + topk_idx, + handle_id=int(handle_id), + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + ) + + @staticmethod + def impl(topk_idx, handle_id, dispatch_output_per_expert_alignment, is_outer): + assert EpPreparePrimitive.inner_primitive is not None + token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=is_outer, + ) + return token_counts, handle_mem + + @staticmethod + def batcher( + batched_args, batch_dims, *, handle_id, dispatch_output_per_expert_alignment, is_outer + ): + raise NotImplementedError("EpPreparePrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + idx_spec = arg_infos[0].sharding.spec + if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpPrepare: topk_idx leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" + f" got spec={idx_spec}." + ) + idx_ndim = len(arg_infos[0].shape) + arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),) + tc_sharding = NamedSharding(mesh, _ep_output_spec(None)) + hm_sharding = NamedSharding(mesh, _ep_output_spec(None)) + + def sharded_impl(topk_idx): + return EpPreparePrimitive.impl( + topk_idx, handle_id, dispatch_output_per_expert_alignment, False + ) + + return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, dispatch_alignment, is_outer). + value_types = args[-2] + topk_idx_rank = len(value_types[0].shape) + in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk" + return f"{in_axes} -> EPL nle, EPL hm" + + +register_primitive(EpPreparePrimitive) + + +# ── ep_dispatch ───────────────────────────────────────────────────────────── + + +class EpDispatchPrimitive(BasePrimitive): + name = "te_ep_dispatch_ffi" + multiple_results = True + impl_static_args = (4, 5, 6, 7) # handle_id, recv_capacity_per_rank, top_k, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + topk_idx_aval, + tokens_aval, + topk_weights_aval, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, topk_weights_aval, top_k, handle_mem_aval + assert ( + len(tokens_aval.shape) >= 2 + ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}" + recv_pr = recv_capacity_per_rank + tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype) + hidden_dim = tokens_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) + recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) + # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + avals = EpDispatchPrimitive.abstract(*args, **kwargs) + return avals[:2] + + @staticmethod + def lowering( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpDispatchPrimitive.name)( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=int(handle_id), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + assert EpDispatchPrimitive.inner_primitive is not None + recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=is_outer, + ) + return recv_tokens, recv_topk_weights + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, top_k, is_outer): + raise NotImplementedError("EpDispatchPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, recv_capacity_per_rank, top_k, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + tokens_spec = arg_infos[2].sharding.spec + if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpDispatch: tokens leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" + f" got spec={tokens_spec}." + ) + idx_spec = arg_infos[1].sharding.spec + tw_spec = arg_infos[3].sharding.spec + arg_shardings = ( + arg_infos[0].sharding, + NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))), + NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))), + NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))), + ) + out_shardings = ( + NamedSharding(mesh, _ep_output_spec(None, None)), + NamedSharding(mesh, _ep_output_spec(None)), + ) + + def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): + return EpDispatchPrimitive.impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + False, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, recv_capacity_per_rank, top_k, is_outer). + value_types = args[-2] + # Inputs: handle_mem, topk_idx, tokens, topk_weights. + idx_rank = len(value_types[1].shape) + tok_rank = len(value_types[2].shape) + tw_rank = len(value_types[3].shape) + idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in" + tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H" + tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk" + return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr" + + +register_primitive(EpDispatchPrimitive) + + +# ── ep_combine ────────────────────────────────────────────────────────────── +# `expert_out` here is the post-weight buffer; ep.ep_combine applies the +# hadamard before calling. + + +def _normalize_leading_shape(s): + return s if isinstance(s, tuple) else (int(s),) + + +def _prod(seq): + p = 1 + for x in seq: + p *= int(x) + return p + + +def _resolve_out_partition_spec(out_partition_spec, num_leading): + """Pick the combine output PartitionSpec. + + Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a + DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``. + This matches the input sharding so XLA does not need collective-permutes + in the bwd path. + """ + if out_partition_spec is not None: + assert len(out_partition_spec) == num_leading + 1, ( + f"out_partition_spec length {len(out_partition_spec)} must equal num_leading" + f" + 1 ({num_leading + 1})" + ) + return tuple(out_partition_spec) + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_combine: ep_resource is not set on the active MeshResource;" + " pass out_sharding=... explicitly." + ) + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + (None,) * num_leading + + +def _per_shard_leading(out_leading_shape, resolved_spec, mesh): + """Per-shard leading shape given resolved partition spec and mesh.""" + per_shard = list(out_leading_shape) + for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]): + if ax is None: + continue + axes = ax if isinstance(ax, tuple) else (ax,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert ( + per_shard[i] % factor == 0 + ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}" + per_shard[i] //= factor + return tuple(per_shard) + + +class EpCombinePrimitive(BasePrimitive): + name = "te_ep_combine_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + expert_out_aval, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del handle_id, out_partition_spec, handle_mem_aval + assert ( + len(expert_out_aval.shape) == 3 + ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}" + eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype) + hidden_dim = expert_out_aval.shape[-1] + out_shape = tuple(out_leading_shape) + (hidden_dim,) + return jax.core.ShapedArray(out_shape, eo_dtype) + + @staticmethod + def lowering( + ctx, + handle_mem, + expert_out, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpCombinePrimitive.name)( + ctx, + handle_mem, + expert_out, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + ) + + @staticmethod + def impl(handle_mem, expert_out, handle_id, out_leading_shape, out_partition_spec): + assert EpCombinePrimitive.inner_primitive is not None + return EpCombinePrimitive.inner_primitive.bind( + handle_mem, + expert_out, + handle_id=handle_id, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, out_leading_shape, out_partition_spec): + raise NotImplementedError("EpCombinePrimitive does not support vmap") + + @staticmethod + def partition(handle_id, out_leading_shape, out_partition_spec, mesh, arg_infos, result_infos): + del result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*resolved)) + + def sharded_impl(handle_mem, expert_out): + return EpCombinePrimitive.impl( + handle_mem, expert_out, handle_id, per_shard_leading, out_partition_spec + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args: + # (handle_id, out_leading_shape, out_partition_spec). + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H" + return f"EPL hm, EPL recv_pr H -> {out_axes}" + + +register_primitive(EpCombinePrimitive) + + +# ── ep_dispatch_bwd ───────────────────────────────────────────────────────── + + +class EpDispatchBwdPrimitive(BasePrimitive): + name = "te_ep_dispatch_bwd_ffi" + multiple_results = True + impl_static_args = (3, 4, 5, 6) # handle_id, top_k, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + g_recv_topk_weights_aval, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del handle_id, g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval + assert ( + len(grad_aval.shape) == 3 + ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype) + grad_topk_weights_aval = jax.core.ShapedArray( + tuple(out_leading_shape) + (top_k,), jnp.float32 + ) + return result_aval, grad_topk_weights_aval + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + assert EpDispatchBwdPrimitive.inner_primitive is not None + return EpDispatchBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle_id, + top_k=top_k, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_shardings = [ + NamedSharding(mesh, PartitionSpec(*resolved)), + NamedSharding(mesh, PartitionSpec(*resolved, None)), + ] + + def sharded_impl(handle_mem, grad, g_recv_topk_weights): + return EpDispatchBwdPrimitive.impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Result rank + # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1. + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k" + + +register_primitive(EpDispatchBwdPrimitive) + + +# ── ep_combine_bwd ────────────────────────────────────────────────────────── + + +class EpCombineBwdPrimitive(BasePrimitive): + name = "te_ep_combine_bwd_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(handle_mem_aval, grad_aval, *, handle_id, recv_capacity_per_rank, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, handle_mem_aval + assert ( + len(grad_aval.shape) >= 2 + ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + return EpCombineBwdPrimitive.abstract(*args, **kwargs) + + @staticmethod + def lowering(ctx, handle_mem, grad, *, handle_id, recv_capacity_per_rank, is_outer): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpCombineBwdPrimitive.name)( + ctx, + handle_mem, + grad, + handle_id=int(handle_id), + ) + + @staticmethod + def impl(handle_mem, grad, handle_id, recv_capacity_per_rank, is_outer): + assert EpCombineBwdPrimitive.inner_primitive is not None + return EpCombineBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, is_outer): + raise NotImplementedError("EpCombineBwdPrimitive does not support vmap") + + @staticmethod + def partition(handle_id, recv_capacity_per_rank, is_outer, mesh, arg_infos, result_infos): + del is_outer, result_infos + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, _ep_output_spec(None, None)) + + def sharded_impl(handle_mem, grad): + return EpCombineBwdPrimitive.impl( + handle_mem, grad, handle_id, recv_capacity_per_rank, False + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # T axes are dynamic-rank based on the actual cotangent shape. + value_types = args[-2] + g_rank = len(value_types[1].shape) + g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H" + return f"EPL hm, {g_axes} -> EPL recv_pr H" + + +register_primitive(EpCombineBwdPrimitive) + + +# ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── + + +_HANDLE_ID_CALLSITE_CACHE = {} + + +def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): + """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" + import sys as _sys + + top_k = int(topk_idx.shape[-1]) + alignment = int(dispatch_output_per_expert_alignment) + # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces + # the same call site (e.g. custom_vjp fwd vs primal) and the resulting + # EpHandles must share the same id to compare equal in pytree aux. + f = _sys._getframe(1) + cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) + handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) + if handle_id is None: + handle_id = ep_allocate_handle_id(top_k, alignment) + _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id + token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=alignment, + is_outer=True, + ) + return token_counts, EpHandle(handle_mem, handle_id) + + +def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" + top_k = int(topk_weights.shape[-1]) + recv_tokens, recv_topk_weights = EpDispatchPrimitive.outer_primitive.bind( + handle.handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=True, + ) + return recv_tokens, recv_topk_weights, handle + + +def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): + """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpCombinePrimitive.outer_primitive.bind( + handle.handle_mem, + expert_out, + handle_id=handle.handle_id, + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_dispatch_bwd( + handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None +): + """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpDispatchBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle.handle_id, + top_k=int(top_k), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_combine_bwd(handle, grad, recv_capacity_per_rank): + """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" + return EpCombineBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 416b18ada0..4d8b097f27 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -200,6 +200,25 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// EP bootstrap (called once per process) +void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +// EP shutdown — registered as a Python atexit hook so it runs before +// C++ static destructors of the JAX extension and libtransformer_engine.so. +void EpShutdown(); +// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where +// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync +// per op) and handle_mem_size sizes the caller's handle_mem buffer. +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); + +// EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler); + // TopK XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..e2c50135aa --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -0,0 +1,457 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include + +#include +#include +#include + +#include "../extensions.h" +#include "common.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine { +namespace jax { + +namespace { + +// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast +// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a +// Python atexit hook from ep.py so it runs before C++ static destructors). +class EpCommManager { + public: + static EpCommManager& get() { + static EpCommManager inst; + return inst; + } + + void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { + std::lock_guard lock(mutex_); + NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + ncclUniqueId uid; + std::memcpy(&uid, uid_bytes, sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + } + + ncclComm_t comm() const { return comm_; } + + void shutdown() { + std::lock_guard lock(mutex_); + if (comm_ == nullptr) return; + ncclCommDestroy(comm_); + comm_ = nullptr; + } + + private: + EpCommManager() = default; + // Intentionally no NCCL teardown in the destructor: this runs at static-dtor + // time, after Python has finalized and possibly after the CUDA driver + // detaches the context. Calling ncclCommDestroy there has been observed to + // hang or report cudartUnloading. Normal teardown goes through the Python + // atexit hook (shutdown_ep_communicator) registered from ep.py; any path + // that skips that (os._exit, fatal signal) leaks the comm, which the OS + // reaps on process exit. + ~EpCommManager() = default; + EpCommManager(const EpCommManager&) = delete; + EpCommManager& operator=(const EpCommManager&) = delete; + + std::mutex mutex_; + ncclComm_t comm_{nullptr}; +}; + +} // namespace + +// handle_id is baked at jit trace time and carried as a static FFI attribute. + +struct EpPrepareConfig { + int64_t handle_id; + int64_t dispatch_output_per_expert_alignment; +}; + +struct EpDispatchConfig { + int64_t handle_id; + int64_t top_k; +}; + +struct EpCombineConfig { + int64_t handle_id; + int64_t num_local_tokens; +}; + +struct EpDispatchBwdConfig { + int64_t handle_id; + int64_t num_local_tokens; + int64_t top_k; +}; + +struct EpCombineBwdConfig { + int64_t handle_id; +}; + +// ── Bootstrap helpers ───────────────────────────────────────────────────────── + +void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { + std::string uid_str = unique_id_bytes_obj; + NVTE_CHECK(static_cast(uid_str.size()) >= 128, + "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); + EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, + rank_within_group); + NVTEEpGroupConfig cfg{.ep_size = ep_size, + .num_experts = num_experts, + .max_tokens_per_rank = max_tokens_per_rank, + .max_recv_tokens_per_rank = max_recv_tokens_per_rank, + .hidden_dim = hidden_dim, + .max_num_sms = max_num_sms}; + // If common rejects the config (validate_config / ncclEpCreateGroup), roll + // the comm back so the two singletons don't end up in inconsistent states + // and the comm doesn't strand until process exit. + try { + nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); + } catch (...) { + EpCommManager::get().shutdown(); + throw; + } +} + +void EpShutdown() { + // Order matters: ep_group_ in common reads from the comm, so tear it down + // first, then destroy the comm. + nvte_ep_shutdown(); + EpCommManager::get().shutdown(); +} + +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{0, top_k, dispatch_output_per_expert_alignment}; + size_t handle_mem_size = 0; + uint64_t handle_id = nvte_ep_register_layer(layer_cfg, &handle_mem_size); + return pybind11::make_tuple(handle_id, handle_mem_size); +} + +// ── ep_prepare ──────────────────────────────────────────────────────────────── + +Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, + Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { + auto topk_dims = topk_idx.dimensions(); + NVTE_CHECK(topk_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + + std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1), + static_cast(topk_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = topk_shape[0] * topk_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); + + std::vector tc_shape = {static_cast(token_counts->element_count())}; + auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + + std::vector hm_shape = {static_cast(handle_mem->element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + nvte_ep_prepare(handle, topk_idx_.data(), token_counts_.data(), + static_cast(config.dispatch_output_per_expert_alignment), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch ─────────────────────────────────────────────────────────────── + +Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, + Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, + Result_Type recv_topk_weights, Result_Type workspace, + EpDispatchConfig config) { + auto token_dims = tokens.dimensions(); + NVTE_CHECK(token_dims.size() >= 2, + "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + auto idx_dims = topk_idx.dimensions(); + NVTE_CHECK(idx_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k, + ") must match topk_idx last dim (", idx_dims.back(), ")"); + std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1), + static_cast(idx_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = idx_shape[0] * idx_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64); + + const size_t T_flat = product(token_dims, 0, token_dims.size() - 1); + const size_t H = static_cast(token_dims.back()); + std::vector tok_shape = {T_flat, H}; + auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type()); + auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype); + + auto tw_dims = topk_weights.dimensions(); + NVTE_CHECK(tw_dims.size() >= 2, + "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size()); + std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1), + static_cast(tw_dims.back())}; + auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32); + + // recv_tokens: flatten any leading dims into recv_capacity_per_rank. + auto recv_dims = recv_tokens->dimensions(); + NVTE_CHECK(recv_dims.size() >= 2, + "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size()); + const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1); + std::vector recv_shape = {recv_capacity_per_rank, H}; + auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype); + + auto recv_w_dims = recv_topk_weights->dimensions(); + NVTE_CHECK(recv_w_dims.size() >= 1, + "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size()); + const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size()); + NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total, + ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")"); + std::vector recv_w_shape = {recv_capacity_per_rank}; + auto recv_topk_weights_ = + TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch(handle, topk_idx_.data(), tokens_.data(), no_win, topk_weights_.data(), no_win, + recv_tokens_.data(), no_win, recv_topk_weights_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine ──────────────────────────────────────────────────────────────── + +Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, + Result_Type result, EpCombineConfig config) { + auto eo_dims = expert_out.dimensions(); + NVTE_CHECK(eo_dims.size() >= 2, + "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1); + const size_t H = static_cast(eo_dims.back()); + std::vector eo_shape = {recv_capacity_per_rank, H}; + auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type()); + auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype); + + auto res_dims = result->dimensions(); + NVTE_CHECK(res_dims.size() >= 2, + "result must be at least 2D [..., H]; got ndim=", res_dims.size()); + const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); + NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens, + "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector res_shape = {res_T_flat, H}; + auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine(handle, expert_out_.data(), no_win, result_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── + +Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, + Result_Type grad_topk_weights, EpDispatchBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {recv_capacity_per_rank, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto gw_dims = g_recv_topk_weights.dimensions(); + NVTE_CHECK( + gw_dims.size() >= 1, + "g_recv_topk_weights rank must flatten to recv_capacity_per_rank; got ndim=", gw_dims.size()); + const size_t gw_total = product(gw_dims, 0, gw_dims.size()); + NVTE_CHECK(gw_total == recv_capacity_per_rank, "g_recv_topk_weights total (", gw_total, + ") must match grad recv_pr (", recv_capacity_per_rank, ")"); + std::vector gw_shape = {recv_capacity_per_rank}; + auto g_recv_topk_weights_ = + TensorWrapper(g_recv_topk_weights.untyped_data(), gw_shape, DType::kFloat32); + + auto out_dims = grad_tokens->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); + const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); + NVTE_CHECK(static_cast(T_flat) == config.num_local_tokens, + "grad_tokens leading-dim product (", T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector out_shape = {T_flat, H}; + auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); + + auto gtw_dims = grad_topk_weights->dimensions(); + NVTE_CHECK(gtw_dims.size() >= 2, + "grad_topk_weights must be at least 2D [..., top_k]; got ndim=", gtw_dims.size()); + const size_t gtw_T_flat = product(gtw_dims, 0, gtw_dims.size() - 1); + NVTE_CHECK(gtw_T_flat == T_flat, "grad_topk_weights leading-dim product (", gtw_T_flat, + ") must equal grad_tokens leading-dim product (", T_flat, ")"); + const size_t top_k = static_cast(gtw_dims.back()); + NVTE_CHECK(static_cast(top_k) == config.top_k, "top_k attr (", config.top_k, + ") must match grad_topk_weights last dim (", top_k, ")"); + std::vector gtw_shape = {T_flat, top_k}; + auto grad_topk_weights_ = + TensorWrapper(grad_topk_weights->untyped_data(), gtw_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch_bwd(handle, grad_.data(), no_win, g_recv_topk_weights_.data(), no_win, + grad_tokens_.data(), grad_topk_weights_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine_bwd ──────────────────────────────────────────────────────────── + +Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Result_Type grad_expert_out, EpCombineBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t T_flat = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {T_flat, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto out_dims = grad_expert_out->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_expert_out must be at least 2D [..., recv_pr, H]; got ndim=", out_dims.size()); + const size_t recv_capacity_per_rank = product(out_dims, 0, out_dims.size() - 1); + const size_t out_H = static_cast(out_dims.back()); + NVTE_CHECK(out_H == H, "grad_expert_out hidden dim (", out_H, ") must match grad H (", H, ")"); + std::vector out_shape = {recv_capacity_per_rank, H}; + auto grad_expert_out_ = TensorWrapper(grad_expert_out->untyped_data(), out_shape, g_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine_bwd(handle, grad_.data(), no_win, grad_expert_out_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out + .Attrs(), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpPrepareConfig, ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchBwdConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineBwdConfig, + ::xla::ffi::StructMember("handle_id")); + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..b34f8739ee 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -101,6 +101,15 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); +#ifdef NVTE_WITH_NCCL_EP + // Expert Parallelism + dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); + dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); + dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); + dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); + dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); +#endif // NVTE_WITH_NCCL_EP + // TopK dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); @@ -127,6 +136,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); +#ifdef NVTE_WITH_NCCL_EP + m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), + pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); + m.def("shutdown_ep_communicator", &EpShutdown); + m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), + pybind11::arg("dispatch_output_per_expert_alignment") = 0); +#endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py new file mode 100644 index 0000000000..40d07bc3d4 --- /dev/null +++ b/transformer_engine/jax/ep.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX Expert Parallelism (EP) API.""" + +import atexit +import ctypes +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jmu +import numpy as np + +import transformer_engine_jax +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.ep import EpHandle +from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size + +ep_prepare = tex.ep_prepare + +__all__ = [ + "EpHandle", + "ep_bootstrap", + "ep_prepare", + "ep_dispatch", + "ep_combine", +] + +_atexit_registered = False + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +def ep_bootstrap( + world_size, + rank, + ep_size, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=0, +): + """Initialize the EP communicator. Call once per process before any EP op. + + max_num_sms caps the SMs allotted to EP kernels (0 = auto). + """ + if world_size < 2: + raise ValueError( + f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" + " at least 2 ranks to form a group." + ) + if world_size % ep_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_size ({ep_size}); otherwise" + " some EP groups would have fewer than ep_size ranks and ncclCommInitRank would hang." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + if jax.local_device_count() != 1: + raise ValueError( + "ep_bootstrap requires one local device per process (got" + f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" + " support single-process multi-device setups." + ) + UID_SIZE = 128 + dp_color = rank // ep_size + rank_within_group = rank % ep_size + is_color_root = rank_within_group == 0 + if is_color_root: + try: + from nccl import get_unique_id + + uid_bytes = bytes(get_unique_id())[:UID_SIZE] + except ImportError: + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) + else: + uid_bytes = bytes(UID_SIZE) + + uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) + all_uids = jmu.process_allgather(uid_arr).reshape(world_size, UID_SIZE) + uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) + + ep_resource = global_mesh_resource().ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + mesh_ep_size = get_mesh_axis_size(ep_resource) + if mesh_ep_size != ep_size: + raise ValueError( + f"ep_bootstrap: EpConfig.ep_size ({ep_size}) does not match mesh axis" + f" '{ep_resource}' size ({mesh_ep_size})." + ) + + transformer_engine_jax.initialize_ep_communicator( + uid_bytes, + ep_size, + rank_within_group, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=int(max_num_sms), + ) + + # Shutdown ordering: + # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, + # which assumes jax.distributed.initialize() ran earlier, so JAX's + # distributed atexit hooks are already registered before this one. Ours + # therefore fires first at exit — fine, because EpShutdown only touches + # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on + # JAX's coordination service. Do not add JAX calls to EpShutdown. + # - Running before C++ static destructors avoids the cudartUnloading + # hazard; the C++ destructors are intentionally no-ops. + global _atexit_registered + if not _atexit_registered: + atexit.register(transformer_engine_jax.shutdown_ep_communicator) + _atexit_registered = True + + tex.ep.set_ep_config( + tex.ep.EpConfig( + world_size=world_size, + rank=rank, + ep_size=ep_size, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + ) + ) + + +# ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def ep_dispatch( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment=0, +): + """Scatter tokens and weights to expert ranks. + + Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``. Only the leading dim may + be sharded — axis ∈ {ep, (dp, ep), dp, None}; trailing dims replicated. + + Args: + topk_idx: ``[..., top_k]`` int32/int64 routing indices. + tokens: ``[..., H]`` activations (matching leading dims). + topk_weights: ``[..., top_k]`` float32 routing weights. + recv_capacity_per_rank: STATIC int. Per-rank recv slot count. + dispatch_output_per_expert_alignment: STATIC int. Per-expert slot + alignment; 0 disables. + + Returns: + ``(recv_tokens, recv_topk_weights, handle, token_counts)`` where + ``recv_tokens`` is 3D ``[num_procs, recv_capacity_per_rank, H]`` + sharded ``(("dp","ep"), None, None)`` (or ``("ep", None, None)`` if + DP is unset), and ``recv_topk_weights`` is 2D + ``[num_procs, recv_capacity_per_rank]`` similarly sharded. Pass + ``handle`` to the matching ``ep_combine``. + """ + return _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, + )[0] + + +def _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, +): + top_k = int(topk_weights.shape[-1]) + token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment) + recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( + handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank + ) + out_leading = tuple(tokens.shape[:-1]) + primal = (recv_tokens, recv_topk_weights, handle, token_counts) + return primal, (handle, out_leading, top_k) + + +def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): + del recv_capacity_per_rank, dispatch_output_per_expert_alignment + handle, out_leading, top_k = res + # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a + # single-fwd-output cotangent, landing a global tensor in the FFI. + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None else ep_axis + g_recv_tokens = jax.lax.with_sharding_constraint( + g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) + ) + g_recv_topk_weights = jax.lax.with_sharding_constraint( + g_outputs[1], jax.sharding.PartitionSpec(leading, None) + ) + grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( + handle, g_recv_tokens, g_recv_topk_weights, top_k, out_leading + ) + return (None, grad_tokens, grad_topk_weights) + + +ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) + + +# ── ep_combine (custom_vjp) ────────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +def ep_combine( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding=None +): + """Reduce weighted expert outputs back to source ranks. + + Args: + handle: ``EpHandle`` from a matching ``ep_dispatch`` call. + token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). + expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. + recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights + returned by ``ep_dispatch``. + num_local_tokens: STATIC int or tuple. int → 2D output ``[T, H]``; + tuple → N-D output ``[*tuple, H]``. + out_sharding: STATIC optional ``PartitionSpec`` tuple for the + output. Defaults to ``(("dp","ep"), *None)`` when + DP is set, else ``("ep", *None)``. Pass a custom + spec to override; only the leading dim may be + sharded. + + Returns: + ``[..., H]`` combined output shaped per ``num_local_tokens``. + """ + return _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + )[0] + + +def _make_valid_mask(recv_topk_weights, dtype): + # recv_topk_weights == 0 marks a padded slot. + return (recv_topk_weights != 0).astype(dtype)[..., None] + + +def _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding +): + del token_counts + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + result = tex.ep_combine_fwd(handle, weighted, num_local_tokens, out_partition_spec=out_sharding) + return result, (handle, recv_topk_weights, expert_out) + + +def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): + handle, recv_topk_weights, expert_out = res + # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. + recv_capacity_per_rank = expert_out.shape[-2] + # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. + gsr = global_mesh_resource() + if _out_sharding is not None: + spec = jax.sharding.PartitionSpec(*_out_sharding) + else: + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis + spec = ( + jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) + if leading is not None + else None + ) + if spec is not None: + g_result = jax.lax.with_sharding_constraint(g_result, spec) + grad_weighted = tex.ep_combine_bwd(handle, g_result, recv_capacity_per_rank) + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + grad_weighted_f32 = grad_weighted.astype(jnp.float32) + grad_expert_out = (grad_weighted_f32 * w * mask).astype(grad_weighted.dtype) + grad_recv_topk_weights = ( + (grad_weighted_f32 * expert_out.astype(jnp.float32) * mask) + .sum(axis=-1) + .astype(recv_topk_weights.dtype) + ) + return (None, None, grad_expert_out, grad_recv_topk_weights) + + +ep_combine.defvjp(_combine_fwd, _combine_bwd) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 182a4a2e00..1dbdfbc533 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -332,7 +332,12 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None - ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None + ep_resource: Axis name for expert parallelism. Dispatch input tokens + must be sharded on their leading dim by ``ep_resource`` (alone or + compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. + ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output + ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` + on the leading ``ep_size`` dim. """ dp_resource: str = None @@ -475,3 +480,8 @@ def dp_or_fsdp_axis_size(): dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) return dp_size if dp_size > 1 else fsdp_size + + +def ep_axis_size(): + """Get the size of the dispatch/EP axis (ep_resource). Returns 1 if unset.""" + return get_mesh_axis_size(global_mesh_resource().ep_resource) From b43710e538f6626ad91b0dee4a9b735bfe2a5fe9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 00:31:54 +0000 Subject: [PATCH 10/29] JAX EP: tie NCCL comm lifetime to JAX executables via XLA stateful FFI Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 11 + transformer_engine/jax/csrc/extensions.h | 21 +- transformer_engine/jax/csrc/extensions/ep.cpp | 273 +++++++++++------- .../jax/csrc/extensions/pybind.cpp | 28 +- transformer_engine/jax/ep.py | 15 +- 5 files changed, 222 insertions(+), 126 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6eb588c849..2cdef4bfe7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") +# Register EpInstanceState (no-op when TE is built without NCCL EP). +if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"): + ffi.register_ffi_type( + "EpInstanceState", + { + "type_id": transformer_engine_jax.get_ep_instance_state_type_id(), + "type_info": transformer_engine_jax.get_ep_instance_state_type_info(), + }, + platform="CUDA", + ) + def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): """ diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 4d8b097f27..62e762a5bb 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -200,19 +200,20 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); -// EP bootstrap (called once per process) -void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms); -// EP shutdown — registered as a Python atexit hook so it runs before -// C++ static destructors of the JAX extension and libtransformer_engine.so. -void EpShutdown(); -// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where -// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync -// per op) and handle_mem_size sizes the caller's handle_mem buffer. +// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +void ReleaseEpResources(); +// Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); +// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type. +pybind11::capsule GetEpInstanceStateTypeIdCapsule(); +pybind11::capsule GetEpInstanceStateTypeInfoCapsule(); + // EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index e2c50135aa..5dc05de0ae 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -10,8 +10,10 @@ #include +#include #include #include +#include #include #include "../extensions.h" @@ -21,52 +23,85 @@ namespace transformer_engine { namespace jax { -namespace { +// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI. + +struct EpBootstrapParams { + std::array uid_bytes{}; + int ep_size = 0; + int rank_within_group = 0; + int num_experts = 0; + int max_tokens_per_rank = 0; + int max_recv_tokens_per_rank = 0; + int hidden_dim = 0; + int max_num_sms = 0; +}; -// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast -// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a -// Python atexit hook from ep.py so it runs before C++ static destructors). -class EpCommManager { +class EpResources { public: - static EpCommManager& get() { - static EpCommManager inst; - return inst; - } - - void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { - std::lock_guard lock(mutex_); - NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + explicit EpResources(const EpBootstrapParams& p) { ncclUniqueId uid; - std::memcpy(&uid, uid_bytes, sizeof(uid)); - NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + .num_experts = p.num_experts, + .max_tokens_per_rank = p.max_tokens_per_rank, + .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, + .hidden_dim = p.hidden_dim, + .max_num_sms = p.max_num_sms}; + try { + nvte_ep_initialize(static_cast(comm_), cfg); + } catch (...) { + ncclCommDestroy(comm_); + comm_ = nullptr; + throw; + } } - ncclComm_t comm() const { return comm_; } - - void shutdown() { - std::lock_guard lock(mutex_); + ~EpResources() { if (comm_ == nullptr) return; + nvte_ep_shutdown(); ncclCommDestroy(comm_); - comm_ = nullptr; } + EpResources(const EpResources&) = delete; + EpResources& operator=(const EpResources&) = delete; + + ncclComm_t comm() const { return comm_; } + private: - EpCommManager() = default; - // Intentionally no NCCL teardown in the destructor: this runs at static-dtor - // time, after Python has finalized and possibly after the CUDA driver - // detaches the context. Calling ncclCommDestroy there has been observed to - // hang or report cudartUnloading. Normal teardown goes through the Python - // atexit hook (shutdown_ep_communicator) registered from ep.py; any path - // that skips that (os._exit, fatal signal) leaks the comm, which the OS - // reaps on process exit. - ~EpCommManager() = default; - EpCommManager(const EpCommManager&) = delete; - EpCommManager& operator=(const EpCommManager&) = delete; - - std::mutex mutex_; ncclComm_t comm_{nullptr}; }; +struct EpInstanceState { + static ::xla::ffi::TypeId id; + static ::xla::ffi::TypeInfo info; + std::shared_ptr resources; +}; + +::xla::ffi::TypeId EpInstanceState::id = {}; +::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo(); + +namespace { + +std::mutex g_ep_mu; +EpBootstrapParams g_ep_params; +bool g_ep_params_set = false; +std::weak_ptr g_ep_resources_weak; +// Python-held anchor so trace-time ep_register_layer finds EPBackend ready. +std::shared_ptr g_ep_resources_anchor; + +std::shared_ptr AcquireEpResources() { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(g_ep_params_set, + "EP bootstrap params not set; call transformer_engine_jax." + "set_ep_bootstrap_params() (typically via ep_bootstrap) first."); + auto sp = g_ep_resources_weak.lock(); + if (sp) return sp; + sp = std::make_shared(g_ep_params); + g_ep_resources_weak = sp; + return sp; +} + } // namespace // handle_id is baked at jit trace time and carried as a static FFI attribute. @@ -98,36 +133,44 @@ struct EpCombineBwdConfig { // ── Bootstrap helpers ───────────────────────────────────────────────────────── -void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms) { +// Caches uid + group config and eagerly creates the NCCL comm (ranks +// synchronize via the UID broadcast). +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); - EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, - rank_within_group); - NVTEEpGroupConfig cfg{.ep_size = ep_size, - .num_experts = num_experts, - .max_tokens_per_rank = max_tokens_per_rank, - .max_recv_tokens_per_rank = max_recv_tokens_per_rank, - .hidden_dim = hidden_dim, - .max_num_sms = max_num_sms}; - // If common rejects the config (validate_config / ncclEpCreateGroup), roll - // the comm back so the two singletons don't end up in inconsistent states - // and the comm doesn't strand until process exit. - try { - nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); - } catch (...) { - EpCommManager::get().shutdown(); - throw; + std::shared_ptr anchor; + { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(!g_ep_resources_anchor, + "EP bootstrap already initialized; call release_ep_resources() before re-init."); + std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128); + g_ep_params.ep_size = ep_size; + g_ep_params.rank_within_group = rank_within_group; + g_ep_params.num_experts = num_experts; + g_ep_params.max_tokens_per_rank = max_tokens_per_rank; + g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; + g_ep_params.hidden_dim = hidden_dim; + g_ep_params.max_num_sms = max_num_sms; + g_ep_params_set = true; } + // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is + // a collective and may block on peer ranks. + anchor = AcquireEpResources(); + std::lock_guard lock(g_ep_mu); + g_ep_resources_anchor = std::move(anchor); } -void EpShutdown() { - // Order matters: ep_group_ in common reads from the comm, so tear it down - // first, then destroy the comm. - nvte_ep_shutdown(); - EpCommManager::get().shutdown(); +// Drops the anchor; comm tears down once the last executable also releases. +void ReleaseEpResources() { + std::shared_ptr to_drop; + { + std::lock_guard lock(g_ep_mu); + to_drop = std::move(g_ep_resources_anchor); + } + // to_drop dtor runs outside the lock. } pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { @@ -137,10 +180,35 @@ pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_ali return pybind11::make_tuple(handle_id, handle_mem_size); } +pybind11::capsule GetEpInstanceStateTypeIdCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id"); +} + +pybind11::capsule GetEpInstanceStateTypeInfoCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info"); +} + +// ── Instantiate handler ───────────────────────────────────────────────────── + +static ::xla::ffi::ErrorOr> EpInstantiateImpl() { + auto state = std::make_unique(); + try { + state->resources = AcquireEpResources(); + } catch (const std::exception& e) { + return ::xla::ffi::Unexpected( + ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what())); + } + return state; +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate()); + // ── ep_prepare ──────────────────────────────────────────────────────────────── -Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, - Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { +Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, + Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + EpPrepareConfig config) { + (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); @@ -178,20 +246,22 @@ Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type t XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, FFI::Bind() - .Ctx() // stream - .Arg() // topk_idx - .Ret() // token_counts - .Ret() // handle_mem - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch ─────────────────────────────────────────────────────────────── -Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, - Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, - Result_Type recv_topk_weights, Result_Type workspace, - EpDispatchConfig config) { +Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, + Result_Type recv_tokens, Result_Type recv_topk_weights, + Result_Type workspace, EpDispatchConfig config) { + (void)ep_state; auto token_dims = tokens.dimensions(); NVTE_CHECK(token_dims.size() >= 2, "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); @@ -264,21 +334,23 @@ Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Typ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // topk_idx - .Arg() // tokens - .Arg() // topk_weights - .Ret() // recv_tokens - .Ret() // recv_topk_weights - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine ──────────────────────────────────────────────────────────────── -Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, - Result_Type result, EpCombineConfig config) { +Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type expert_out, Result_Type result, EpCombineConfig config) { + (void)ep_state; auto eo_dims = expert_out.dimensions(); NVTE_CHECK(eo_dims.size() >= 2, "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); @@ -311,18 +383,21 @@ Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // expert_out - .Ret() // result + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── -Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, - Result_Type grad_topk_weights, EpDispatchBwdConfig config) { +Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Buffer_Type g_recv_topk_weights, + Result_Type grad_tokens, Result_Type grad_topk_weights, + EpDispatchBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); @@ -380,19 +455,22 @@ Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. recv_tokens) - .Arg() // g_recv_topk_weights - .Ret() // grad_tokens - .Ret() // grad_topk_weights + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine_bwd ──────────────────────────────────────────────────────────── -Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Result_Type grad_expert_out, EpCombineBwdConfig config) { +Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Result_Type grad_expert_out, + EpCombineBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); @@ -424,10 +502,11 @@ Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_T XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. result) - .Ret() // grad_expert_out + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out .Attrs(), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index b34f8739ee..0304f37691 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -102,12 +102,22 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); #ifdef NVTE_WITH_NCCL_EP - // Expert Parallelism - dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); - dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); - dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); - dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); - dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); + // Expert Parallelism (instantiate handler pins NCCL comm to executable lifetime). + dict["te_ep_prepare_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpPrepareHandler)); + dict["te_ep_dispatch_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchHandler)); + dict["te_ep_combine_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineHandler)); + dict["te_ep_dispatch_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchBwdHandler)); + dict["te_ep_combine_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineBwdHandler)); #endif // NVTE_WITH_NCCL_EP // TopK @@ -137,13 +147,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); #ifdef NVTE_WITH_NCCL_EP - m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); - m.def("shutdown_ep_communicator", &EpShutdown); + m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); + m.def("get_ep_instance_state_type_id", &GetEpInstanceStateTypeIdCapsule); + m.def("get_ep_instance_state_type_info", &GetEpInstanceStateTypeInfoCapsule); #endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 40d07bc3d4..d2850defaf 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -100,7 +100,8 @@ def ep_bootstrap( f" '{ep_resource}' size ({mesh_ep_size})." ) - transformer_engine_jax.initialize_ep_communicator( + # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. + transformer_engine_jax.set_ep_bootstrap_params( uid_bytes, ep_size, rank_within_group, @@ -111,18 +112,10 @@ def ep_bootstrap( max_num_sms=int(max_num_sms), ) - # Shutdown ordering: - # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, - # which assumes jax.distributed.initialize() ran earlier, so JAX's - # distributed atexit hooks are already registered before this one. Ours - # therefore fires first at exit — fine, because EpShutdown only touches - # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on - # JAX's coordination service. Do not add JAX calls to EpShutdown. - # - Running before C++ static destructors avoids the cudartUnloading - # hazard; the C++ destructors are intentionally no-ops. + # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. global _atexit_registered if not _atexit_registered: - atexit.register(transformer_engine_jax.shutdown_ep_communicator) + atexit.register(transformer_engine_jax.release_ep_resources) _atexit_registered = True tex.ep.set_ep_config( From cb44374b9f329718eab090f75dfd61284eb1b3d8 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 20:02:43 +0000 Subject: [PATCH 11/29] JAX EP: expose allow_handle_mem_reloc as opt-in ep_bootstrap parameter Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 2 ++ tests/jax/test_multi_process_ep.py | 2 ++ transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/ep.cpp | 7 +++++-- transformer_engine/jax/csrc/extensions/pybind.cpp | 3 ++- transformer_engine/jax/ep.py | 7 +++++++ 6 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 8dcac02a04..b2f48a6ad3 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -288,6 +288,8 @@ def main(): max_tokens_per_rank=args.num_tokens, recv_capacity_per_rank=args.recv_capacity_per_rank, hidden_dim=args.hidden, + # XLA reallocates handle_mem between JIT executables. + allow_handle_mem_reloc=True, ) ( diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 0658ad9750..7d070fb353 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -122,6 +122,8 @@ def setUpClass(cls): max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=cls.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, + # XLA reallocates handle_mem between JIT executables. + allow_handle_mem_reloc=True, ) # ── Bootstrap precondition ──────────────────────────────────────────── diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 62e762a5bb..9e64cf4d73 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -203,7 +203,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); // Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms); + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc); void ReleaseEpResources(); // Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 5dc05de0ae..39e2d8be3f 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -34,6 +34,7 @@ struct EpBootstrapParams { int max_recv_tokens_per_rank = 0; int hidden_dim = 0; int max_num_sms = 0; + int allow_handle_mem_reloc = 0; }; class EpResources { @@ -47,7 +48,8 @@ class EpResources { .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, - .max_num_sms = p.max_num_sms}; + .max_num_sms = p.max_num_sms, + .allow_handle_mem_reloc = p.allow_handle_mem_reloc}; try { nvte_ep_initialize(static_cast(comm_), cfg); } catch (...) { @@ -137,7 +139,7 @@ struct EpCombineBwdConfig { // synchronize via the UID broadcast). void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms) { + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); @@ -154,6 +156,7 @@ void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; g_ep_params.hidden_dim = hidden_dim; g_ep_params.max_num_sms = max_num_sms; + g_ep_params.allow_handle_mem_reloc = allow_handle_mem_reloc; g_ep_params_set = true; } // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0304f37691..aeca99510a 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -150,7 +150,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), - pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0, + pybind11::arg("allow_handle_mem_reloc") = 0); m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index d2850defaf..55b4ebec6c 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -42,10 +42,16 @@ def ep_bootstrap( recv_capacity_per_rank, hidden_dim, max_num_sms=0, + allow_handle_mem_reloc=False, ): """Initialize the EP communicator. Call once per process before any EP op. max_num_sms caps the SMs allotted to EP kernels (0 = auto). + + Set ``allow_handle_mem_reloc=True`` only if the caller cannot guarantee a + stable ``handle_mem`` device pointer across calls (e.g. XLA-managed + buffers reallocated between JIT executables). Default raises on + relocation so callers detect handle-aliasing bugs. """ if world_size < 2: raise ValueError( @@ -110,6 +116,7 @@ def ep_bootstrap( recv_capacity_per_rank, hidden_dim, max_num_sms=int(max_num_sms), + allow_handle_mem_reloc=int(bool(allow_handle_mem_reloc)), ) # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. From 2012b0ad716e1a37e10252220bbc302275a619b4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 16:30:19 -0700 Subject: [PATCH 12/29] jax/ep: decorate EP ops with @compute_on("gpu_stream:collective") Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 7d112ad5f4..26d0291124 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi +from jax.experimental.compute_on import compute_on from jax.sharding import NamedSharding, PartitionSpec import jax.tree_util as jtu @@ -876,6 +877,7 @@ def shardy_sharding_rule(*args): _HANDLE_ID_CALLSITE_CACHE = {} +@compute_on("gpu_stream:collective") def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" import sys as _sys @@ -900,6 +902,7 @@ def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): return token_counts, EpHandle(handle_mem, handle_id) +@compute_on("gpu_stream:collective") def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" top_k = int(topk_weights.shape[-1]) @@ -916,6 +919,7 @@ def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_ra return recv_tokens, recv_topk_weights, handle +@compute_on("gpu_stream:collective") def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" out_leading = _normalize_leading_shape(num_local_tokens) @@ -928,6 +932,7 @@ def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None ) +@compute_on("gpu_stream:collective") def ep_dispatch_bwd( handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None ): @@ -944,6 +949,7 @@ def ep_dispatch_bwd( ) +@compute_on("gpu_stream:collective") def ep_combine_bwd(handle, grad, recv_capacity_per_rank): """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" return EpCombineBwdPrimitive.outer_primitive.bind( From c04bebb09ae3cd3e1446e251352c4dd89c75f795 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 18:18:51 -0700 Subject: [PATCH 13/29] ep_bootstrap: add XLA-collective fallback for UID allgather Signed-off-by: Phuong Nguyen --- transformer_engine/jax/ep.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 55b4ebec6c..b0e404972e 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -30,6 +30,35 @@ _atexit_registered = False +def _allgather_uid(uid_arr, world_size, uid_size): + """Allgather UID bytes across all processes. + + Tries ``jax.experimental.multihost_utils.process_allgather`` first; + falls back to an XLA collective (process-local sharded global array + replicated via ``jax.jit``) when the multihost helper returns a + short buffer, which has been observed under some launchers. + """ + try: + gathered = jmu.process_allgather(uid_arr, tiled=True) + if gathered.size == world_size * uid_size: + return np.asarray(gathered).reshape(world_size, uid_size) + except Exception: # pylint: disable=broad-except + pass + devices = np.asarray(jax.devices()) + if devices.size != world_size: + raise RuntimeError( + f"_allgather_uid fallback expected {world_size} global devices," + f" got {devices.size}." + ) + mesh = jax.sharding.Mesh(devices, ("_uid_all",)) + sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + local = np.asarray(uid_arr).reshape(1, uid_size) + g_in = jax.make_array_from_process_local_data(sharded, local, (world_size, uid_size)) + g_out = jax.jit(lambda x: x, out_shardings=replicated)(g_in) + return np.asarray(g_out).reshape(world_size, uid_size) + + # ── Bootstrap ──────────────────────────────────────────────────────────────── @@ -90,7 +119,7 @@ def ep_bootstrap( uid_bytes = bytes(UID_SIZE) uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) - all_uids = jmu.process_allgather(uid_arr).reshape(world_size, UID_SIZE) + all_uids = _allgather_uid(uid_arr, world_size, UID_SIZE) uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) ep_resource = global_mesh_resource().ep_resource From 141558051ebaac7819fba9e6d7ccb4245e883f9d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 29 May 2026 12:25:06 -0700 Subject: [PATCH 14/29] jax/ep: introduce per-layer EpHandle, drop callsite-frame handle_id cache Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 15 +-- tests/jax/test_multi_process_ep.py | 122 ++++++++++++++------ transformer_engine/jax/cpp_extensions/ep.py | 107 ++++++++--------- transformer_engine/jax/ep.py | 107 +++++++---------- 4 files changed, 184 insertions(+), 167 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index b2f48a6ad3..dae3710526 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -14,7 +14,7 @@ import numpy as np from jax.sharding import Mesh, NamedSharding, PartitionSpec -from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine from transformer_engine.jax.sharding import MeshResource, global_shard_guard @@ -199,6 +199,7 @@ def _moe_step(args, topk_idx, tokens, topk_w, kernels): kernel_spec = PartitionSpec("ep", None, None, None) kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + ep_handle = ep_make_handle(args.top_k, dispatch_output_per_expert_alignment=16) @jax.jit def step(topk_idx, tokens, topk_w, local_kernels): @@ -208,20 +209,16 @@ def step(topk_idx, tokens, topk_w, local_kernels): local_kernels = jax.lax.with_sharding_constraint( local_kernels, NamedSharding(mesh, kernel_spec) ) - slots_per_expert = args.recv_capacity_per_rank // NLE - recv_tokens, recv_topk_w, handle, _tc = ep_dispatch( - topk_idx, - tokens, - topk_w, - args.recv_capacity_per_rank, - dispatch_output_per_expert_alignment=slots_per_expert, + recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( + ep_handle, topk_idx, tokens, topk_w, args.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) return ep_combine( - handle, + ep_handle, + handle_mem, _tc, expert_out, recv_topk_w, diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 7d070fb353..abdbcd32ec 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -29,7 +29,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from transformer_engine.jax.sharding import MeshResource, global_shard_guard -from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine from transformer_engine.jax.cpp_extensions.ep import ( ep_prepare, ep_dispatch_fwd, @@ -125,6 +125,8 @@ def setUpClass(cls): # XLA reallocates handle_mem between JIT executables. allow_handle_mem_reloc=True, ) + # One handle key shared by all single-layer tests below. + cls.hk = ep_make_handle(TOP_K) # ── Bootstrap precondition ──────────────────────────────────────────── @@ -204,29 +206,76 @@ def _make_random_inputs(self, seed=42, nonuniform=True): # ── Individual primitives (cpp_extensions level) ────────────────────── - def test_two_prepares_distinct_handle_ids(self): - """Two ep_prepare sites with matching (top_k, alignment) must produce - distinct handle_ids — distinct logical layers cannot share a - HandleEntry. Verified by tracing through jit so the primitive's - outer_primitive.bind path is exercised.""" + def test_two_handles_distinct_ids(self): + """Two ``ep_make_handle`` calls must yield distinct ``handle_id``s; + distinct logical layers cannot share a HandleEntry. Verified through a + jit so each ``ep_prepare`` bind path is exercised.""" _T, topk_idx, _tokens, _w = self._make_identity_inputs() - captured: list = [] + ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K) dp_spec = PartitionSpec(("dp", "ep"), None) with self.mesh, global_shard_guard(self.mr): idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) @jax.jit def run(idx): - _tc_a, ha = ep_prepare(idx) - _tc_b, hb = ep_prepare(idx) - captured.append((ha.handle_id, hb.handle_id)) - return ha.handle_mem, hb.handle_mem + _tc_a, ha = ep_prepare(idx, ka) + _tc_b, hb = ep_prepare(idx, kb) + return ha, hb hm_a, hm_b = run(idx_s) hm_a.block_until_ready() hm_b.block_until_ready() - id_a, id_b = captured[0] - self.assertNotEqual(id_a, id_b, "two ep_prepare calls returned the same handle_id") + self.assertNotEqual(ka.handle_id, kb.handle_id) + + def test_two_layer_dispatch_no_handle_aliasing(self): + """Two ep_dispatch calls in one jit with distinct ``EpHandle``s must + not clobber each other's routing state. Different inputs per layer with + identity routing + uniform weights => both recv buffers must independently + identity-round-trip via ep_combine.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + tokens_b = (tokens.astype(jnp.float32) * -1.0 + 0.25).astype(tokens.dtype) + ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + ta = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + tb = jax.lax.with_sharding_constraint(tokens_b, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + def one_layer(hk, idx, toks, w_): + recv_t, recv_w, hm, tc = ep_dispatch( + hk, idx, toks, w_, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) + return ep_combine( + hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + + @jax.jit + def run(idx, ta_, tb_, w_): + return one_layer(ka, idx, ta_, w_), one_layer(kb, idx, tb_, w_) + + out_a, out_b = run(idx_s, ta, tb, w) + out_a.block_until_ready() + out_b.block_until_ready() + out_a_g = jmu.process_allgather(out_a, tiled=True) + out_b_g = jmu.process_allgather(out_b, tiled=True) + + self.assertNotEqual(ka.handle_id, kb.handle_id) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_a_g.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) + np.testing.assert_allclose( + np.asarray(out_b_g.astype(jnp.float32)), + np.asarray(tokens_b.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) def test_primitive_prepare(self): """ep_prepare returns the expected shapes and a valid handle id.""" @@ -238,8 +287,8 @@ def test_primitive_prepare(self): @jax.jit def run(idx): - tc, handle = ep_prepare(idx) - return tc, handle.handle_mem + tc, hm = ep_prepare(idx, self.hk) + return tc, hm tc, hm = run(idx_s) tc.block_until_ready() @@ -260,9 +309,9 @@ def _run_identity_round_trip(self, nonuniform): @jax.jit def run(idx, toks, w): - _tc, handle = ep_prepare(idx) - recv_t, recv_w, handle = ep_dispatch_fwd( - handle, idx, toks, w, self.recv_capacity_per_rank + _tc, hm = ep_prepare(idx, self.hk) + recv_t, recv_w = ep_dispatch_fwd( + self.hk, hm, idx, toks, w, self.recv_capacity_per_rank ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) @@ -279,7 +328,8 @@ def run(idx, toks, w): weighted, NamedSharding(self.mesh, ep_spec_3d) ) out = ep_combine_fwd( - handle, weighted, T_global, out_partition_spec=(("dp", "ep"), None) + self.hk, hm, weighted, T_global, + out_partition_spec=(("dp", "ep"), None), ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -322,7 +372,7 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -330,7 +380,7 @@ def loss_fn(toks): recv_w, NamedSharding(self.mesh, ep_spec_2d) ) out = ep_combine( - handle, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) ) return 0.5 * (out.astype(jnp.float32) ** 2).sum() @@ -370,11 +420,12 @@ def test_dispatch_combine_3d_input_output(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( - handle, + self.hk, + hm, _tc, recv_t, recv_w, @@ -412,11 +463,12 @@ def test_dispatch_combine_dp_only_first_dim(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( - handle, + self.hk, + hm, _tc, recv_t, recv_w, @@ -457,8 +509,8 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_tokens, _recv_w, _handle, _tc = ep_dispatch( - idx, toks, w, self.recv_capacity_per_rank + recv_tokens, _recv_w, _hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint( recv_tokens, NamedSharding(self.mesh, ep_spec_3d) @@ -503,13 +555,13 @@ def loss_fn(eo): toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - _recv_tokens, recv_w, handle, tc = ep_dispatch( - idx, toks, w, self.recv_capacity_per_rank + _recv_tokens, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank ) recv_w = jax.lax.with_sharding_constraint( recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) ) - combined = ep_combine(handle, tc, eo, recv_w, T_global) + combined = ep_combine(self.hk, hm, tc, eo, recv_w, T_global) # Pin combined to dp-sharded so autodiff transpose feeds # ep_combine_bwd a per-shard cotangent. combined = jax.lax.with_sharding_constraint( @@ -549,7 +601,7 @@ def loss_fn(idx_in, tok_in, w_in): tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) _recv_t, recv_w, _h, _tc = ep_dispatch( - idx_in, tok_in, w_in, self.recv_capacity_per_rank + self.hk, idx_in, tok_in, w_in, self.recv_capacity_per_rank ) # Per-slot index scale ⇒ each slot's contribution differs. scale = jnp.asarray( @@ -589,7 +641,7 @@ def run(idx, toks, w): idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -597,7 +649,7 @@ def run(idx, toks, w): recv_w, NamedSharding(self.mesh, ep_spec_2d) ) out = ep_combine( - handle, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -634,9 +686,9 @@ def fwd(eo, toks, idx, w): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - _rt, rw, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) - combined = ep_combine(handle, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 26d0291124..8fb0d90f8a 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -8,7 +8,7 @@ Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else ``ep_resource`` alone. - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first - dim may be sharded, with axis ∈ {ep, (dp, ep), dp, None}. Trailing dims + dim may be sharded, with axis in {ep, (dp, ep), dp, None}. Trailing dims must be replicated. ``dp`` alone gets ``ep`` folded in locally. - EpCombine output sharding comes from ``out_sharding`` or defaults to the compound ``(dp, ep)`` axis on the leading dim. @@ -21,7 +21,6 @@ from jax import dtypes, ffi from jax.experimental.compute_on import compute_on from jax.sharding import NamedSharding, PartitionSpec -import jax.tree_util as jtu import transformer_engine_jax from .base import BasePrimitive, register_primitive @@ -34,6 +33,7 @@ "get_ep_config", "get_ep_num_local_experts", "ep_allocate_handle_id", + "ep_make_handle", "ep_prepare", "ep_dispatch_fwd", "ep_combine_fwd", @@ -42,24 +42,6 @@ ] -# Routing-state container threaded through dispatch/combine/*_bwd. -@jtu.register_pytree_node_class -class EpHandle: - def __init__(self, handle_mem, handle_id): - self.handle_mem = handle_mem - self.handle_id = int(handle_id) - - def tree_flatten(self): - return (self.handle_mem,), (self.handle_id,) - - @classmethod - def tree_unflatten(cls, aux, children): - return cls(children[0], aux[0]) - - def __repr__(self): - return f"EpHandle(handle_id={self.handle_id})" - - # ── Module-level EP config ────────────────────────────────────────────────── @@ -101,11 +83,7 @@ def get_ep_num_local_experts() -> int: def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int: - """Reserve a fresh handle_id for an EP layer. - - Distinct logical layers must each call this — sharing a handle_id across - layers corrupts the routing state, even when (top_k, alignment) match. - """ + """Low-level: reserve a fresh handle_id. Prefer ``ep_make_handle``.""" handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer( int(top_k), int(dispatch_output_per_expert_alignment) ) @@ -114,6 +92,36 @@ def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int return handle_id +@dataclass(frozen=True) +class EpHandle: + """Per-layer EP config + routing-slot identity. + + Carries static layer config and a ``handle_id`` that pins the C++ routing + slot across re-traces. Allocate via ``ep_make_handle``; distinct layers + must hold distinct handles. + """ + + handle_id: int + top_k: int + dispatch_output_per_expert_alignment: int = 0 + + +def ep_make_handle(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> EpHandle: + """Allocate a per-layer EP handle. + + Call once per logical MoE layer at model init (outside ``jax.jit``), then + pass the same handle into every ``ep_dispatch`` / ``ep_combine`` for that + layer. The handle's ``handle_id`` survives re-traces, ``jax.checkpoint`` + rematerialization, and separate inference/training compilations. + """ + handle_id = ep_allocate_handle_id(top_k, dispatch_output_per_expert_alignment) + return EpHandle( + handle_id=handle_id, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + def _ep_handle_mem_size(handle_id: int) -> int: """Return the handle_mem byte size for an id from ep_allocate_handle_id.""" try: @@ -874,40 +882,23 @@ def shardy_sharding_rule(*args): # ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── -_HANDLE_ID_CALLSITE_CACHE = {} - - @compute_on("gpu_stream:collective") -def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): - """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" - import sys as _sys - - top_k = int(topk_idx.shape[-1]) - alignment = int(dispatch_output_per_expert_alignment) - # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces - # the same call site (e.g. custom_vjp fwd vs primal) and the resulting - # EpHandles must share the same id to compare equal in pytree aux. - f = _sys._getframe(1) - cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) - handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) - if handle_id is None: - handle_id = ep_allocate_handle_id(top_k, alignment) - _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id - token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( +def ep_prepare(topk_idx, handle): + """Exchange routing metadata for ``handle``; return ``(token_counts, handle_mem)``.""" + return EpPreparePrimitive.outer_primitive.bind( topk_idx, - handle_id=handle_id, - dispatch_output_per_expert_alignment=alignment, + handle_id=handle.handle_id, + dispatch_output_per_expert_alignment=handle.dispatch_output_per_expert_alignment, is_outer=True, ) - return token_counts, EpHandle(handle_mem, handle_id) @compute_on("gpu_stream:collective") -def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): - """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" +def ep_dispatch_fwd(handle, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights).""" top_k = int(topk_weights.shape[-1]) - recv_tokens, recv_topk_weights = EpDispatchPrimitive.outer_primitive.bind( - handle.handle_mem, + return EpDispatchPrimitive.outer_primitive.bind( + handle_mem, topk_idx, tokens, topk_weights, @@ -916,15 +907,14 @@ def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_ra top_k=top_k, is_outer=True, ) - return recv_tokens, recv_topk_weights, handle @compute_on("gpu_stream:collective") -def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): +def ep_combine_fwd(handle, handle_mem, expert_out, num_local_tokens, out_partition_spec=None): """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" out_leading = _normalize_leading_shape(num_local_tokens) return EpCombinePrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, expert_out, handle_id=handle.handle_id, out_leading_shape=out_leading, @@ -934,12 +924,13 @@ def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None @compute_on("gpu_stream:collective") def ep_dispatch_bwd( - handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None + handle, handle_mem, grad, g_recv_topk_weights, top_k, num_local_tokens, + out_partition_spec=None, ): """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" out_leading = _normalize_leading_shape(num_local_tokens) return EpDispatchBwdPrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, grad, g_recv_topk_weights, handle_id=handle.handle_id, @@ -950,10 +941,10 @@ def ep_dispatch_bwd( @compute_on("gpu_stream:collective") -def ep_combine_bwd(handle, grad, recv_capacity_per_rank): +def ep_combine_bwd(handle, handle_mem, grad, recv_capacity_per_rank): """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" return EpCombineBwdPrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, grad, handle_id=handle.handle_id, recv_capacity_per_rank=recv_capacity_per_rank, diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index b0e404972e..62bd6691fd 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -14,14 +14,16 @@ import transformer_engine_jax import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.cpp_extensions.ep import EpHandle from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size ep_prepare = tex.ep_prepare +ep_make_handle = tex.ep_make_handle +EpHandle = tex.EpHandle __all__ = [ "EpHandle", "ep_bootstrap", + "ep_make_handle", "ep_prepare", "ep_dispatch", "ep_combine", @@ -171,64 +173,34 @@ def ep_bootstrap( # ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── -@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) -def ep_dispatch( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment=0, -): +@partial(jax.custom_vjp, nondiff_argnums=(0, 4)) +def ep_dispatch(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): """Scatter tokens and weights to expert ranks. - Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``. Only the leading dim may - be sharded — axis ∈ {ep, (dp, ep), dp, None}; trailing dims replicated. - - Args: - topk_idx: ``[..., top_k]`` int32/int64 routing indices. - tokens: ``[..., H]`` activations (matching leading dims). - topk_weights: ``[..., top_k]`` float32 routing weights. - recv_capacity_per_rank: STATIC int. Per-rank recv slot count. - dispatch_output_per_expert_alignment: STATIC int. Per-expert slot - alignment; 0 disables. - - Returns: - ``(recv_tokens, recv_topk_weights, handle, token_counts)`` where - ``recv_tokens`` is 3D ``[num_procs, recv_capacity_per_rank, H]`` - sharded ``(("dp","ep"), None, None)`` (or ``("ep", None, None)`` if - DP is unset), and ``recv_topk_weights`` is 2D - ``[num_procs, recv_capacity_per_rank]`` similarly sharded. Pass - ``handle`` to the matching ``ep_combine``. + ``handle`` is a per-layer ``EpHandle`` from ``ep_make_handle``; distinct + layers must hold distinct handles. Inputs are 2D ``[T, H]`` or 3D + ``[B, S, H]`` with only the leading dim sharded + (axis in {ep, (dp, ep), dp, None}). Returns + ``(recv_tokens, recv_topk_weights, handle_mem, token_counts)``; pass + ``handle_mem`` and ``token_counts`` to the matching ``ep_combine``. """ - return _dispatch_fwd( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment, - )[0] + return _dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank)[0] -def _dispatch_fwd( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment, -): +def _dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): top_k = int(topk_weights.shape[-1]) - token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment) - recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( - handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank + token_counts, handle_mem = tex.ep_prepare(topk_idx, handle) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + handle, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank ) out_leading = tuple(tokens.shape[:-1]) - primal = (recv_tokens, recv_topk_weights, handle, token_counts) - return primal, (handle, out_leading, top_k) + primal = (recv_tokens, recv_topk_weights, handle_mem, token_counts) + return primal, (handle_mem, out_leading, top_k) -def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): - del recv_capacity_per_rank, dispatch_output_per_expert_alignment - handle, out_leading, top_k = res +def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs): + del recv_capacity_per_rank + handle_mem, out_leading, top_k = res # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a # single-fwd-output cotangent, landing a global tensor in the FFI. gsr = global_mesh_resource() @@ -242,7 +214,7 @@ def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, g_outputs[1], jax.sharding.PartitionSpec(leading, None) ) grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( - handle, g_recv_tokens, g_recv_topk_weights, top_k, out_leading + handle, handle_mem, g_recv_tokens, g_recv_topk_weights, top_k, out_leading ) return (None, grad_tokens, grad_topk_weights) @@ -253,31 +225,33 @@ def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, # ── ep_combine (custom_vjp) ────────────────────────────────────────────────── -@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +@partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6)) def ep_combine( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding=None + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding=None, ): """Reduce weighted expert outputs back to source ranks. Args: - handle: ``EpHandle`` from a matching ``ep_dispatch`` call. + handle: ``EpHandle`` matching the ``ep_dispatch`` call. + handle_mem: Routing-state buffer returned by ``ep_dispatch``. token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights returned by ``ep_dispatch``. - num_local_tokens: STATIC int or tuple. int → 2D output ``[T, H]``; - tuple → N-D output ``[*tuple, H]``. + num_local_tokens: STATIC int or tuple. int -> 2D output ``[T, H]``; + tuple -> N-D output ``[*tuple, H]``. out_sharding: STATIC optional ``PartitionSpec`` tuple for the output. Defaults to ``(("dp","ep"), *None)`` when - DP is set, else ``("ep", *None)``. Pass a custom - spec to override; only the leading dim may be - sharded. + DP is set, else ``("ep", *None)``. Only the leading + dim may be sharded. Returns: ``[..., H]`` combined output shaped per ``num_local_tokens``. """ return _combine_fwd( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, )[0] @@ -287,18 +261,21 @@ def _make_valid_mask(recv_topk_weights, dtype): def _combine_fwd( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, ): del token_counts w = recv_topk_weights[..., None] mask = _make_valid_mask(recv_topk_weights, jnp.float32) weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) - result = tex.ep_combine_fwd(handle, weighted, num_local_tokens, out_partition_spec=out_sharding) - return result, (handle, recv_topk_weights, expert_out) + result = tex.ep_combine_fwd( + handle, handle_mem, weighted, num_local_tokens, out_partition_spec=out_sharding + ) + return result, (handle_mem, recv_topk_weights, expert_out) -def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): - handle, recv_topk_weights, expert_out = res +def _combine_bwd(handle, _num_local_tokens, _out_sharding, res, g_result): + handle_mem, recv_topk_weights, expert_out = res # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. recv_capacity_per_rank = expert_out.shape[-2] # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. @@ -316,7 +293,7 @@ def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): ) if spec is not None: g_result = jax.lax.with_sharding_constraint(g_result, spec) - grad_weighted = tex.ep_combine_bwd(handle, g_result, recv_capacity_per_rank) + grad_weighted = tex.ep_combine_bwd(handle, handle_mem, g_result, recv_capacity_per_rank) w = recv_topk_weights[..., None] mask = _make_valid_mask(recv_topk_weights, jnp.float32) grad_weighted_f32 = grad_weighted.astype(jnp.float32) From 0eee8b8fc82a2b50a8542477f008478a3dfa5e88 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:26:29 -0700 Subject: [PATCH 15/29] [JAX] EP: wire NVTEEpGroupConfig.max_token_dtype through bootstrap PR #3034 commit 9b225cbe added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR #3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR #3036 exposes the field upstream. See PR #3034 (commit 9b225cbe, ep.h:43) for the field definition. --- transformer_engine/jax/csrc/extensions.h | 3 +- transformer_engine/jax/csrc/extensions/ep.cpp | 8 +++-- .../jax/csrc/extensions/pybind.cpp | 2 +- transformer_engine/jax/ep.py | 31 +++++++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 9e64cf4d73..d6392819c0 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -203,7 +203,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); // Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms, int allow_handle_mem_reloc); + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc, + int max_token_dtype); void ReleaseEpResources(); // Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 39e2d8be3f..84f24d75bf 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -35,6 +35,7 @@ struct EpBootstrapParams { int hidden_dim = 0; int max_num_sms = 0; int allow_handle_mem_reloc = 0; + int max_token_dtype = 0; }; class EpResources { @@ -49,7 +50,8 @@ class EpResources { .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, .max_num_sms = p.max_num_sms, - .allow_handle_mem_reloc = p.allow_handle_mem_reloc}; + .allow_handle_mem_reloc = p.allow_handle_mem_reloc, + .max_token_dtype = static_cast(p.max_token_dtype)}; try { nvte_ep_initialize(static_cast(comm_), cfg); } catch (...) { @@ -139,7 +141,8 @@ struct EpCombineBwdConfig { // synchronize via the UID broadcast). void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms, int allow_handle_mem_reloc) { + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc, + int max_token_dtype) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); @@ -157,6 +160,7 @@ void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int g_ep_params.hidden_dim = hidden_dim; g_ep_params.max_num_sms = max_num_sms; g_ep_params.allow_handle_mem_reloc = allow_handle_mem_reloc; + g_ep_params.max_token_dtype = max_token_dtype; g_ep_params_set = true; } // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index aeca99510a..6020a228e3 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -151,7 +151,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0, - pybind11::arg("allow_handle_mem_reloc") = 0); + pybind11::arg("allow_handle_mem_reloc") = 0, pybind11::arg("max_token_dtype") = 0); m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 62bd6691fd..17d00bef87 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -64,6 +64,30 @@ def _allgather_uid(uid_arr, world_size, uid_size): # ── Bootstrap ──────────────────────────────────────────────────────────────── +_TE_DTYPE_FOR_NUMPY = { + np.dtype(np.uint8): transformer_engine_jax.DType.kByte, + np.dtype(np.int32): transformer_engine_jax.DType.kInt32, + np.dtype(np.int64): transformer_engine_jax.DType.kInt64, + np.dtype(np.float32): transformer_engine_jax.DType.kFloat32, + np.dtype(np.float16): transformer_engine_jax.DType.kFloat16, +} + + +def _to_te_dtype_int(dtype): + """Map jax/numpy dtype -> NVTEDType int. bf16 / fp8 / fp4 handled explicitly.""" + if dtype is None: + return int(transformer_engine_jax.DType.kByte) + if dtype == jnp.bfloat16: + return int(transformer_engine_jax.DType.kBFloat16) + np_dtype = np.dtype(dtype) + if np_dtype in _TE_DTYPE_FOR_NUMPY: + return int(_TE_DTYPE_FOR_NUMPY[np_dtype]) + raise ValueError( + f"ep_bootstrap: unsupported max_token_dtype={dtype!r}; supported = " + "uint8 / int32 / int64 / float32 / float16 / bfloat16." + ) + + def ep_bootstrap( world_size, rank, @@ -74,6 +98,7 @@ def ep_bootstrap( hidden_dim, max_num_sms=0, allow_handle_mem_reloc=False, + max_token_dtype=None, ): """Initialize the EP communicator. Call once per process before any EP op. @@ -83,6 +108,11 @@ def ep_bootstrap( stable ``handle_mem`` device pointer across calls (e.g. XLA-managed buffers reallocated between JIT executables). Default raises on relocation so callers detect handle-aliasing bugs. + + ``max_token_dtype`` is the widest token dtype the group will dispatch + (sizes NCCL EP staging buffers at group create). Pass a jax/numpy + dtype, e.g. ``jnp.bfloat16``. Default ``None`` keeps the legacy ``kByte`` + behavior, which only accepts 1-byte tensors. """ if world_size < 2: raise ValueError( @@ -148,6 +178,7 @@ def ep_bootstrap( hidden_dim, max_num_sms=int(max_num_sms), allow_handle_mem_reloc=int(bool(allow_handle_mem_reloc)), + max_token_dtype=_to_te_dtype_int(max_token_dtype), ) # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. From 10f4b1c7d833f975b5ba2f9bf47a521a8f9e77a8 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 2 Jun 2026 16:19:50 -0700 Subject: [PATCH 16/29] [JAX] MoE: enforce (outer_dp, ep) ordering for TE EP compatibility [JAX] MoE: soft re-pin inbound activations sharding at moe() entry [JAX] MoE: scope gate_logits 2D reshape to topk primitive call [JAX] MoE: add apply_topk_weights_early flag (TE EP backend only) [JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat Signed-off-by: tdophung --- transformer_engine/jax/flax/moe.py | 3 + transformer_engine/jax/moe.py | 150 ++++++++++++++++++++--------- 2 files changed, 105 insertions(+), 48 deletions(-) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 91346a7a48..b5a4afc2ad 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -147,6 +147,8 @@ class _MoEBlock(TransformerEngineBase): permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX _align_size: int = 0 + apply_topk_weights_early: bool = False + # Dtypes / init / misc dtype: DType = jnp.float32 kernel_init: Optional[Initializer] = None @@ -273,6 +275,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: permutation_backend=self.permutation_backend, align_size=self._align_size, gate_inside_vjp=True, + apply_topk_weights_early=self.apply_topk_weights_early, ep_axis=ep_axis, data_parallelism_axes=self.data_parallelism_axes, input_axes=self.input_axes, diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 2a1c818cb3..4479b9f176 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -53,11 +53,12 @@ from enum import Enum from functools import partial from typing import Any, NewType, Optional, Tuple, Union +import warnings import jax import jax.numpy as jnp from flax import struct as flax_struct -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec as P from . import cpp_extensions as tex from .permutation import ( @@ -212,7 +213,7 @@ class _BodyCtx: routing_map: Any dispatch: Any # _DispatchState casted_sorted_x_lhs_trans: Any - casted_wi_rhs_trans: Any # combined [E, H, 2M] residual for fused wi_0|wi_1 bwd + casted_wi_rhs_trans: Any # stacked [E, H, 2, M] residual for fused wi_0|wi_1 bwd gate_proj_out: Any up_proj_out: Any casted_intermediate_lhs_trans: Any @@ -966,12 +967,20 @@ def _body_fwd( # pylint: disable=unused-argument num_ep: int, num_experts_local: int, recv_buffer_rows: int, + apply_topk_weights_early: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. ``aux_loss`` is always materialized (zeros scalar when disabled) so the ``shard_map``'s ``out_specs`` has a static structure. """ + if apply_topk_weights_early: + # Requires row-aligned per-token weights at the FFN intermediate; + # only available on the TE EP (tex.ep_dispatch) path. + raise NotImplementedError( + "apply_topk_weights_early=True is supported only with the TE EP " + "(tex.ep_dispatch / tex.ep_combine) backend." + ) if not gate_inside_vjp: raise NotImplementedError( "gate_inside_vjp=False is deferred to a follow-up PR; for now" @@ -992,7 +1001,8 @@ def _body_fwd( # pylint: disable=unused-argument # ---------------- Stage 1: gate ---------------- gate_kernel_cast = gate_kernel.astype(x.dtype) - gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) # [B, S, E] + # tex.fused_topk_with_score_function_* requires rank-2 input. logits_2d = gate_logits.reshape(-1, num_experts) inputs_2d = x.reshape(-1, hidden) @@ -1025,7 +1035,9 @@ def _body_fwd( # pylint: disable=unused-argument if aux_loss_coeff > 0.0: if ep_active: collective_axes: Any = ( - ep_axis if not data_parallelism_axes else (ep_axis, *data_parallelism_axes) + ep_axis + if not data_parallelism_axes + else (*data_parallelism_axes, ep_axis) ) global_logits_2d = jax.lax.all_gather( logits_2d, axis_name=collective_axes, axis=0, tiled=True @@ -1100,22 +1112,17 @@ def _body_fwd( # pylint: disable=unused-argument if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_x.dtype) - # GEMM 1+2 (fused): up_proj_combined = sorted_x @ wi where - # wi := concat([wi_0, wi_1], axis=-1) -> shape [E, H, 2M] - # combined_out := sorted_x @ wi -> shape [T, 2M] - # Splitting the output back into ``gate_proj_out`` / ``up_proj_out`` - # is free (it's a slicing reshape). This collapses two grouped - # GEMMs and two grouped quantizes of ``sorted_x`` (one per kernel) - # into one of each. Bias is concatenated the same way. + # Fused gate+up projection: stack wi_0 / wi_1 on a new axis-(-2) so the + # downstream split is a slice on the (unsharded) stack axis. concat on + # axis=-1 would cross the M axis and force a reshard when M is TP-sharded. # - # FP8/MXFP8 caveat: per-expert amax is now computed over [H, 2M] - # rather than [H, M] for each of wi_0 / wi_1 separately, so the - # representable range for one of the two halves may shift slightly - # vs. the pre-fusion code. Numerics tests cover this. + # FP8/MXFP8 caveat: per-expert amax is computed over [H, 2, M] rather than + # [H, M] for each of wi_0 / wi_1 separately, so the representable range for + # one half may shift slightly vs. an unfused pair of casts. inter_M = wi_0.shape[-1] - wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) + wi_combined = jnp.stack([wi_0, wi_1], axis=-2) wi_combined_bias = ( - jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None + jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None ) casted_sorted_x = tex.grouped_quantize(sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1) casted_wi = tex.grouped_quantize(wi_combined, q_set_w0.kernel, flatten_axis=-1) @@ -1125,8 +1132,8 @@ def _body_fwd( # pylint: disable=unused-argument contracting_dims=((1,), (1,)), bias=wi_combined_bias, ) - gate_proj_out = combined_out[..., :inter_M] - up_proj_out = combined_out[..., inter_M:] + gate_proj_out = combined_out[..., 0, :] + up_proj_out = combined_out[..., 1, :] casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): @@ -1253,15 +1260,21 @@ def _body_bwd( # pylint: disable=unused-argument has_wo_bias: bool, has_expert_bias: bool, x_shape: Tuple[int, ...], + apply_topk_weights_early: bool = False, ) -> dict: """Per-shard backward body. Returns a dict of grads keyed identically to the ``captured`` dict consumed by :func:`_body_fwd`.""" + if apply_topk_weights_early: + raise NotImplementedError( + "apply_topk_weights_early=True is supported only with the TE EP " + "(tex.ep_dispatch / tex.ep_combine) backend." + ) if not gate_inside_vjp: raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") d_output, d_aux_loss = dy_pair # The fused FFN bwd quantizes via ``q_set_w0`` only (one quantize for - # the [E, H, 2M] fused wi tensor and one for the [T, 2M] fused dgrad), + # the [E, H, 2, M] stacked wi tensor and one for the [T, 2, M] stacked dgrad), # so ``q_set_w1`` is intentionally unused here. q_set_w0, _q_set_w1, q_set_wo = quantizer_sets batch_size, sequence_length, hidden = x_shape @@ -1347,33 +1360,37 @@ def _body_bwd( # pylint: disable=unused-argument (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx.up_proj_out) # ---------------- FFN bwd: GEMM 1+2 fused (wi_0 | wi_1) ---------------- - # Concat the two upstream grads along the output (M) axis, do one - # grouped quantize + one dgrad GEMM + one wgrad GEMM, then split. - # ``ctx.casted_wi_rhs_trans`` has shape [E, H, 2M] from the fwd - # fused quantize, so the dgrad math is: + # Mirror of the fwd stack: combine d_gate / d_up on a new axis=-2, + # run one dgrad + one wgrad GEMM, then split on axis=-2. # d_sorted_x = [d_gate | d_up] @ wi_rhs_trans # = d_gate @ wi_0^T + d_up @ wi_1^T inter_M = d_gate_proj_out.shape[-1] - d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) + d_combined = jnp.stack([d_gate_proj_out, d_up_proj_out], axis=-2) casted_d_combined = tex.grouped_quantize( d_combined, q_set_w0.dgrad, ctx.local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), ctx.casted_wi_rhs_trans, - contracting_dims=((1,), (2,)), + contracting_dims=((1, 2), (2, 3)), ) d_wi_combined = tex.grouped_gemm( ctx.casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0 = d_wi_combined[..., :inter_M] - d_wi_1 = d_wi_combined[..., inter_M:] + d_wi_0 = d_wi_combined[..., 0, :] + d_wi_1 = d_wi_combined[..., 1, :] if has_wi_bias: - d_wi_combined_bias = tex.grouped_dbias(d_combined, ctx.local_group_sizes) - d_wi_0_bias = d_wi_combined_bias[..., :inter_M] - d_wi_1_bias = d_wi_combined_bias[..., inter_M:] + # grouped_dbias requires rank-2 input; reshape around the call. + # M is not TP-sharded on the bias path, so the reshape is free. + d_combined_2d = d_combined.reshape(d_combined.shape[0], -1) + d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, ctx.local_group_sizes) + d_wi_combined_bias = d_wi_combined_bias_2d.reshape( + *d_wi_combined_bias_2d.shape[:-1], 2, inter_M + ) + d_wi_0_bias = d_wi_combined_bias[..., 0, :] + d_wi_1_bias = d_wi_combined_bias[..., 1, :] else: d_wi_0_bias = None d_wi_1_bias = None @@ -1458,23 +1475,18 @@ def _body_bwd( # pylint: disable=unused-argument score_function=score_function, compute_aux_scores=True, ) - # Step 3: under EP the aux logits were all_gathered along - # ``(ep_axis, *data_parallelism_axes)`` (the latter being FSDP - # axes that shard the batch). The bwd is the inverse of that - # multi-axis tiled all_gather: ``dynamic_slice`` to pick out - # this shard's local rows from the global cotangent. - # - # JAX's convention for tiled ``all_gather(axis_name=(a, b, ...))`` - # is row-major over the tuple: the shard at mesh position - # ``(i_a, i_b, ...)`` writes to rows - # ``[(i_a * size_b * ... + i_b * ... + ...) * local_T : - # + local_T)``. We invert that by computing the same flat - # index here and slicing. + # Inverse of the fwd tiled all_gather along + # ``(*data_parallelism_axes, ep_axis)``: pick out this shard's + # local rows from the global cotangent. JAX's tiled all_gather + # is row-major over the axis-name tuple, so the shard at mesh + # position (i_a, i_b, ...) writes to a contiguous row block + # starting at flat_index * local_T. if ep_active: local_T_aux = ctx.logits_2d.shape[0] - flat_shard = shard_id # ep is the outermost axis in the gather tuple + flat_shard = 0 for ax, sz in zip(data_parallelism_axes, fsdp_sizes): flat_shard = flat_shard * sz + jax.lax.axis_index(ax) + flat_shard = flat_shard * num_ep + shard_id d_aux_logits_local = jax.lax.dynamic_slice( d_aux_logits.astype(ctx.logits_2d.dtype), start_indices=(flat_shard * local_T_aux, 0), @@ -1698,6 +1710,7 @@ def _moe_fwd_rule( # pylint: disable=unused-argument wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ): x = with_sharding_constraint_by_logical_axes(x, input_axes) ep_active = ep_axis is not None @@ -1718,6 +1731,7 @@ def _moe_fwd_rule( # pylint: disable=unused-argument "dtype": dtype, "ep_axis": ep_axis, "data_parallelism_axes": data_parallelism_axes, + "apply_topk_weights_early": apply_topk_weights_early, } captured: dict = { "inputs": x, @@ -1792,7 +1806,10 @@ def _moe_fwd_rule( # pylint: disable=unused-argument if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the + # comm only stays within one model replica under (outer_dp, ep). + batch_pspec_axis = (*data_parallelism_axes, ep_axis) dp_size = 1 for ax in data_parallelism_axes: dp_size *= mesh.shape[ax] @@ -1876,6 +1893,7 @@ def _moe_bwd_rule( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ctx, dy_pair, ): @@ -1917,6 +1935,7 @@ def _moe_bwd_rule( "has_wo_bias": has_wo_bias, "has_expert_bias": has_expert_bias, "x_shape": x_shape, + "apply_topk_weights_early": apply_topk_weights_early, } if not ep_active: @@ -1936,7 +1955,10 @@ def _moe_bwd_rule( if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the + # comm only stays within one model replica under (outer_dp, ep). + batch_pspec_axis = (*data_parallelism_axes, ep_axis) ctx_spec = _build_ctx_specs( ep_axis, batch_pspec_axis, @@ -1995,7 +2017,7 @@ def _grads_dict_to_tuple( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 29))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 30))) def _moe( x, gate_kernel, @@ -2026,6 +2048,7 @@ def _moe( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ): # Call in `_moe`'s own signature order to match what JAX will pass # the fwd rule via ``_argnums_partial``. See the comment block at @@ -2061,6 +2084,7 @@ def _moe( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ) return output_pair @@ -2093,8 +2117,14 @@ def moe( # Permutation permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, align_size: int = 0, - # Gate placement (Phuong: "perhaps as an option") + # Gate placement gate_inside_vjp: bool = True, + # When True, fold per-token top-k weights into the FFN intermediate + # (next to act(gate)*up) instead of into the post-down-projection + # combine. Both placements are mathematically equivalent (down-proj is + # linear); the early placement gives XLA a chance to fuse the multiply + # with the activation. Off by default. + apply_topk_weights_early: bool = False, # Parallelism (resolved by caller from MeshResource) ep_axis: Optional[str] = None, data_parallelism_axes: Tuple[str, ...] = (), @@ -2129,6 +2159,29 @@ def moe( # we bypass also normalizes here. score_function = _validate_score_function(score_function) + # Enforce ((outer_dp..., ep), None, None) on inbound activations. The + # EP comm groups consecutive global ranks (dp_color = rank // ep_size), + # so ep MUST be innermost in the partition spec. Soft re-pin: free if + # upstream already matches, single reshard otherwise. + if ep_axis is not None: + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + expected_leading: Any = ( + (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + ) + expected_spec = P(expected_leading, None, None) + actual_spec = getattr(getattr(x, "sharding", None), "spec", None) + if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): + warnings.warn( + f"moe(...): inbound x sharding {actual_spec} does not match expected " + f"{expected_spec}; inserting a reshard. Apply " + "jax.lax.with_sharding_constraint upstream to avoid this overhead.", + UserWarning, + stacklevel=2, + ) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + output, aux_loss = _moe( x, gate_kernel, @@ -2159,6 +2212,7 @@ def moe( wo_kernel_axes=wo_kernel_axes, quantizer_sets=quantizer_sets, dtype=dtype, + apply_topk_weights_early=apply_topk_weights_early, ) if aux_loss_coeff <= 0.0: aux_loss = None From 9194fe3e0d7991f7f9f991e01970223fa2aaa48f Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 2 Jun 2026 17:59:03 -0700 Subject: [PATCH 17/29] integrate tex.* calls, remove all ragged-a2a + triton/pure jax step by step paths. change tests to collapse in 1 bigger one with different parameters instead of smaller meaningless dtypes/shapes/finite chhecks Signed-off-by: tdophung --- tests/jax/run_te_ep_moe.sh | 122 ++ tests/jax/test_te_ep_moe.py | 762 ++++++++ transformer_engine/jax/flax/moe.py | 109 +- transformer_engine/jax/moe.py | 2620 ++++++++-------------------- 4 files changed, 1678 insertions(+), 1935 deletions(-) create mode 100755 tests/jax/run_te_ep_moe.sh create mode 100644 tests/jax/test_te_ep_moe.py diff --git a/tests/jax/run_te_ep_moe.sh b/tests/jax/run_te_ep_moe.sh new file mode 100755 index 0000000000..32d5f21956 --- /dev/null +++ b/tests/jax/run_te_ep_moe.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Multiprocess (one-GPU-per-process) launcher for the TE-EP MoE custom_vjp +# test suite. Forks one pytest invocation per visible GPU, passing each +# its own --num-process=N --process-id=i, and waits for all of them. Each +# child calls jax.distributed.initialize(..., local_device_ids=process_id) +# so each Python process only sees its one GPU as a local device and the +# participating processes form a global (ep, fsdp) mesh. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +TEST_FILE="$TE_ROOT/tests/jax/test_te_ep_moe.py" +PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini" + +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}" +if [ "$NUM_GPUS" -lt 4 ]; then + echo "[run_te_ep_moe.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2 + exit 1 +fi + +export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}" +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}" +export TE_EP_MOE_COORDINATOR_ADDRESS="${TE_EP_MOE_COORDINATOR_ADDRESS:-127.0.0.1:13457}" + +echo "============================================================" +echo "TE-EP MoE MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)" +echo " test file : $TEST_FILE" +echo " coordinator : $TE_EP_MOE_COORDINATOR_ADDRESS" +echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE" +echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION" +echo "============================================================" + +if [ -n "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + LOG_DIR="$TE_EP_MOE_MP_LOG_DIR" + mkdir -p "$LOG_DIR" +else + LOG_DIR=$(mktemp -d -t te_ep_moe_mp_XXXXXX) +fi +echo "Per-process logs: $LOG_DIR" + +PIDS=() + +cleanup() { + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -TERM "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -KILL "$pid" 2>/dev/null || true + fi + done +} +trap cleanup EXIT INT TERM + +for i in $(seq 0 $((NUM_GPUS - 1))); do + LOG_FILE="$LOG_DIR/proc_${i}.log" + PYTEST_CMD=( + python3 -m pytest -c "$PYTEST_INI" + "$TEST_FILE" + -p no:typeguard + -v -s + --num-process="$NUM_GPUS" + --process-id="$i" + ) + if [ "$i" -eq 0 ]; then + echo "=== Live output from process 0 ===" + "${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" & + else + "${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 & + fi + PIDS+=("$!") +done + +EXITS=() +for pid in "${PIDS[@]}"; do + if wait "$pid"; then + EXITS+=("0") + else + EXITS+=("$?") + fi +done + +echo +echo "============================================================" +echo "Per-process exit codes:" +for i in "${!EXITS[@]}"; do + echo " proc $i -> ${EXITS[$i]}" +done + +# Treat exit 0 (pass) and exit 5 (pytest "no tests collected", which the +# file emits via pytest.skip(allow_module_level=True) on pre-Blackwell +# GPUs) as success. +FAILED=0 +for e in "${EXITS[@]}"; do + if [ "$e" != "0" ] && [ "$e" != "5" ]; then + FAILED=1 + break + fi +done + +echo +if [ "$FAILED" -eq 0 ]; then + echo "[run_te_ep_moe.sh] all processes PASSED" + if [ -z "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + rm -rf "$LOG_DIR" + fi + exit 0 +fi + +echo "[run_te_ep_moe.sh] at least one process FAILED" +echo " retaining logs at $LOG_DIR for diagnosis" +echo " process 0 tail:" +tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true +exit 1 diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py new file mode 100644 index 0000000000..cc878e0bd1 --- /dev/null +++ b/tests/jax/test_te_ep_moe.py @@ -0,0 +1,762 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-process (one-GPU-per-process) tests for the TE-EP MoE custom_vjp. + +The launcher ``tests/jax/run_te_ep_moe.sh`` forks one pytest process per +visible GPU (mirroring ``run_multiprocess_moe_vjp.sh``). Each process binds +to exactly one device via +``jax.distributed.initialize(..., local_device_ids=process_id)``; the +participating processes form a global ``(ep, fsdp)`` mesh through JAX's +distributed runtime. + +How to run +---------- + +You typically do NOT invoke pytest on this file directly -- use the +launcher, which passes ``--num-process=N --process-id=i`` to each +forked process. Driving it directly with only one process will skip +every test because :func:`jax.distributed.initialize` requires +multiple participants, and the TE EP NCCL primitives require at +least four ranks. + + bash tests/jax/run_te_ep_moe.sh + +What this suite covers +---------------------- + +This file is the TE-EP-only successor to ``test_moe_vjp.py`` and +``test_multiprocess_moe_vjp.py``. Each test exercises one MoE-block +run and bundles every check that single run supports — shape, dtype, +finiteness AND numerical parity vs a pure-JAX reference. Variations +on the block are pytest parametrize values rather than separate test +classes: + +* ``test_forward`` covers the forward across a curated set of + configurations (apply_topk_weights_early on/off, align_size=0/128, + softmax/sigmoid scoring, optional expert_bias). Each config asserts + shape, dtype, finiteness and numerical parity vs the reference in + one run. +* ``test_backward`` mirrors that for gradients. +* ``TestTeEpMoeAuxLoss`` covers the second return value end-to-end + (returned + parity + aux-only grad propagates to gate + combined + main+aux grads stay finite) in two consolidated tests. +* ``TestTeEpMoEBlockFlax`` exercises the Flax wrapper with the same + parity reference. +* ``TestZZZTeEpMoeBootstrap`` verifies the per-process NCCL bootstrap + rejects a mismatched signature. + +FP8 / MXFP8 recipes are deferred — the ``quantizer_sets`` plumbing +has not yet been re-wired across the TE-EP ``shard_map`` boundary +(see ``.pr3036-review/INTEGRATION_DESIGN.md``). +""" + +import os + +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5") + +import sys +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from flax.linen import partitioning as nn_partitioning + + +def _init_distributed(num_process: int, process_id: int) -> bool: + """Initialize jax.distributed for this pytest process. + + Returns True on a real multi-process launch, False otherwise so + the module can fast-skip when pytest collects it without the + launcher. + """ + if num_process <= 1: + return False + coord = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457") + jax.distributed.initialize( + coordinator_address=coord, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + assert jax.local_device_count() == 1, "one GPU per process is required for TE EP" + assert ( + jax.device_count() == num_process + ), f"global device_count {jax.device_count()} != num_process {num_process}" + return True + + +def _read_mp_options(): + num = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + pid = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, a in enumerate(sys.argv): + if a.startswith("--num-process="): + num = int(a.split("=", 1)[1]) + elif a == "--num-process" and i + 1 < len(sys.argv): + num = int(sys.argv[i + 1]) + elif a.startswith("--process-id="): + pid = int(a.split("=", 1)[1]) + elif a == "--process-id" and i + 1 < len(sys.argv): + pid = int(sys.argv[i + 1]) + return num, pid + + +_MP_NUM_PROCESS, _MP_PROCESS_ID = _read_mp_options() +_MP_ACTIVE = _init_distributed(_MP_NUM_PROCESS, _MP_PROCESS_ID) + +if not _MP_ACTIVE: + pytest.skip( + "test_te_ep_moe.py requires the multiprocess launcher " + "(run_te_ep_moe.sh). Skipping.", + allow_module_level=True, + ) + +from transformer_engine_jax import get_device_compute_capability + +# Grouped GEMM in the MoE custom_vjp requires Blackwell (sm_100+). The +# TE EP NCCL primitives themselves need SM>=90, but the FFN body uses +# grouped_gemm, so the file as a whole gates on sm_100+. +if get_device_compute_capability(0) < 100: + pytest.skip( + "MoE TE EP tests require Blackwell (sm_100+) for grouped GEMM", + allow_module_level=True, + ) + +from transformer_engine.jax.flax import _MoEBlock as MoEBlock +from transformer_engine.jax.moe import moe +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ----------------------------------------------------------------------------- +# Mesh / shape config +# ----------------------------------------------------------------------------- + +EP_AXIS = "ep" +FSDP_AXIS = "fsdp" +EP_SIZE = 2 +assert ( + jax.device_count() % EP_SIZE == 0 +), f"device_count {jax.device_count()} must be divisible by EP_SIZE={EP_SIZE}" +FSDP_SIZE = jax.device_count() // EP_SIZE +NUM_DEVICES_REQUIRED = EP_SIZE * FSDP_SIZE + +LOGICAL_AXIS_RULES = ( + ("exp", EP_AXIS), + ("embed", FSDP_AXIS), + ("mlp", None), + ("batch", (EP_AXIS, FSDP_AXIS)), +) + +# Small shapes so the parity tests stay tight on bf16. The block still +# has all four ranks participating in dispatch/combine. +DTYPE = jnp.bfloat16 +BATCH = EP_SIZE * FSDP_SIZE * 2 # 8 on 4-GPU, 16 on 8-GPU +SEQ = 32 +HIDDEN = 64 +INTER = 128 +NUM_EXPERTS = 8 +TOPK = 2 + +# bf16 grouped_gemm + softmax-topk + ep all-to-all stack drifts ~1e-1 vs a +# fp32 numpy reference. Keep these tight enough to catch real bugs but +# loose enough to absorb expected bf16 rounding. +FWD_ATOL = 5e-2 +FWD_RTOL = 5e-2 +GRAD_FFN_ATOL = 1e-1 +GRAD_FFN_RTOL = 1e-1 +GRAD_GATE_ATOL = 5e-1 +GRAD_GATE_RTOL = 5e-1 + +# Two TE EP runs that should be bitwise-equal modulo XLA fusion order +# (align_size rounding, etc.). +TE_TO_TE_ATOL = 5e-3 +TE_TO_TE_RTOL = 5e-3 + +# Aux loss is computed in float32 from the SAME logits as the routing +# path. Numerical drift between TE-EP and the reference is dominated by +# the bf16-rounded softmax inside the topk kernel. +AUX_ATOL = 1e-3 +AUX_RTOL = 1e-3 + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def mesh(): + if jax.device_count() < NUM_DEVICES_REQUIRED: + pytest.skip( + f"Need >={NUM_DEVICES_REQUIRED} devices for ep={EP_SIZE} x fsdp={FSDP_SIZE};" + f" have {jax.device_count()}" + ) + # ``ep`` must be the inner axis: ``ep_bootstrap`` forms NCCL EP groups + # from consecutive global ranks via ``dp_color = rank // ep_size``, so + # only an (outer_fsdp, inner_ep) device layout groups ranks correctly. + devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE)) + return Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + + +# ----------------------------------------------------------------------------- +# Pure-JAX reference MoE (no EP). Mirrors the exact math of TE's fused +# router primitive (see tests/jax/test_fused_router.py for the same +# reference applied to the standalone router kernel): +# +# softmax + post-softmax (use_pre_softmax=False, the default): +# 1. top_k by raw logits +# 2. softmax over just the K selected logits (so weights sum to 1) +# +# sigmoid + optional expert_bias: +# 1. scores = sigmoid(logits) +# 2. top_k by (scores + expert_bias) [bias only steers selection] +# 3. weights = scores at top_k positions, normalized when K > 1 +# +# Then for both: +# * weights *= scaling_factor (we leave scaling_factor=1.0 in this +# suite, matching _make_block's default). +# * per-expert FFN: silu(layer_w0) * layer_w1 → wo. +# ----------------------------------------------------------------------------- + + +@partial( + jax.jit, + static_argnames=( + "num_experts", + "num_experts_per_tok", + "aux_loss_coeff", + "score_function", + ), +) +def _pure_jax_moe_reference( + x, + gate_kernel, + wi_0, + wi_1, + wo, + expert_bias=None, + *, + num_experts, + num_experts_per_tok, + aux_loss_coeff: float = 0.0, + score_function: str = "softmax", +): + B, S, H = x.shape + T = B * S + K = num_experts_per_tok + x_2d = x.reshape(T, H) + + gate_kernel_cast = gate_kernel.astype(x.dtype) + logits = (x_2d @ gate_kernel_cast).astype(jnp.float32) # [T, E] + + if score_function == "softmax": + # use_pre_softmax=False: topk on raw logits, then softmax over K. + top_logits, top_indices = jax.lax.top_k(logits, k=K) + weights = jax.nn.softmax(top_logits, axis=-1) # [T, K], sums to 1 + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits) # [T, E] + if expert_bias is not None and expert_bias.shape != (0,): + scores_for_routing = scores + expert_bias.astype(jnp.float32)[None, :] + _, top_indices = jax.lax.top_k(scores_for_routing, k=K) + weights = jnp.take_along_axis(scores, top_indices, axis=-1) + else: + weights, top_indices = jax.lax.top_k(scores, k=K) + # Sigmoid weights are normalized when K > 1 (matches the kernel). + if K > 1: + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + else: + raise ValueError(f"Unsupported score_function={score_function!r}") + + routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32) + routing_weights_full = routing_weights_full.at[ + jnp.arange(T)[:, None], top_indices + ].set(weights) + + # FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't + # change the math (wo is linear), so the reference is identical for + # both placements. + layer_w0 = jnp.einsum("th,ehm->tem", x_2d, wi_0) + layer_w1 = jnp.einsum("th,ehm->tem", x_2d, wi_1) + intermediate = jax.nn.silu(layer_w0.astype(jnp.float32)) * layer_w1.astype(jnp.float32) + intermediate = intermediate.astype(x.dtype) + expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H] + output_2d = jnp.einsum( + "te,teh->th", routing_weights_full.astype(x.dtype), expert_out + ) + output = output_2d.reshape(B, S, H).astype(x.dtype) + + if aux_loss_coeff > 0.0: + # tex.fused_moe_aux_loss formula (matches the same + # reference_aux_loss helper from test_fused_router.py). The + # "aux scores" use the same score_function but always with + # K-normalised sigmoid (when sigmoid) / plain softmax (when + # softmax) — see tex.fused_topk_with_score_function_fwd with + # compute_aux_scores=True. + if score_function == "softmax": + aux_scores = jax.nn.softmax(logits, axis=-1) + else: # sigmoid + aux_scores = jax.nn.sigmoid(logits) + if K > 1: + aux_scores = aux_scores / ( + aux_scores.sum(axis=-1, keepdims=True) + 1e-20 + ) + routing_map = (routing_weights_full > 0).astype(jnp.int32) + tokens_per_expert = jnp.sum(routing_map, axis=0) # [E] + sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E] + aux_loss = (num_experts * aux_loss_coeff / (K * (T**2))) * jnp.sum( + sum_probs_per_expert * tokens_per_expert.astype(jnp.float32) + ) + aux_loss = aux_loss.astype(x.dtype) + else: + aux_loss = jnp.zeros((), dtype=x.dtype) + return output, aux_loss + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_block( + *, + apply_topk_weights_early=False, + align_size=0, + aux_loss_coeff=0.0, + use_expert_bias=False, + score_function="softmax", + bias_init=None, +): + kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + data_parallelism_axes=(FSDP_AXIS,), + apply_topk_weights_early=apply_topk_weights_early, + align_size=align_size, + aux_loss_coeff=aux_loss_coeff, + use_expert_bias=use_expert_bias, + score_function=score_function, + dtype=DTYPE, + ) + # Custom bias_init lets tests inject a non-zero expert_bias without + # poking variables['params'] post-init. + if bias_init is not None: + kwargs["bias_init"] = bias_init + return MoEBlock(**kwargs) + + +def _strong_expert_bias_init(key, shape, dtype): + """Half +5, half -5 — large enough to force topk onto the +ve half.""" + del key + n = shape[0] + return jnp.concatenate( + [ + jnp.full((n // 2,), 5.0, dtype=dtype), + jnp.full((n - n // 2,), -5.0, dtype=dtype), + ] + ) + + +def _shard_inputs(x, mesh): + # Match the layout moe.py re-pins to: outer dp axes, then ep innermost. + return jax.lax.with_sharding_constraint( + x, NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None)) + ) + + +def _ctx(mesh): + """Combined mesh + global_shard_guard + axis_rules context.""" + + class _Combo: + def __enter__(self_inner): + self_inner._m = mesh.__enter__() + self_inner._gs = global_shard_guard( + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + ) + self_inner._gs.__enter__() + self_inner._ar = nn_partitioning.axis_rules(LOGICAL_AXIS_RULES) + self_inner._ar.__enter__() + return self_inner._m + + def __exit__(self_inner, *args): + self_inner._ar.__exit__(*args) + self_inner._gs.__exit__(*args) + mesh.__exit__(*args) + + return _Combo() + + +def _init_apply(block, mesh, x, key): + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + variables = jax.jit(block.init)(key, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) + output, aux = jax.jit(block.apply)(variables, x_sh) + jax.block_until_ready(output) + return variables, output, aux + + +def _grad_step(block, variables, mesh, x, *, include_aux=False): + """Run jax.grad of mean(out^2) [+ aux if include_aux] vs params.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def loss_fn(variables, x): + output, aux = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if include_aux and aux is not None: + loss = loss + aux.astype(jnp.float32) + return loss + + grads = jax.jit(jax.grad(loss_fn))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _grad_aux_only(block, variables, mesh, x): + """Jit'd grad of just the aux loss scalar — proves it reaches the + gate even when no main-output contribution is present.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def aux_only(variables, x): + _, aux = block.apply(variables, x) + return aux.astype(jnp.float32) + + grads = jax.jit(jax.grad(aux_only))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _unwrap(x): + return x.value if hasattr(x, "value") else x + + +def _to_global_numpy(arr, mesh): + """Replicate a sharded JAX array onto every rank and return as numpy. + + Triggers an all-gather inside JIT. The resulting addressable_data(0) + contains the full global array on every process, so we can run the + pure-JAX reference and compare against it from any process. + """ + rep = NamedSharding(mesh, P()) + with mesh: + full = jax.jit(lambda a: jax.lax.with_sharding_constraint(a, rep))(arr) + full.block_until_ready() + return np.asarray(jax.device_get(full.addressable_data(0))) + + +def _params_global_numpy(variables, mesh): + """Pull every entry of variables['params'] to a replicated numpy array.""" + params = variables["params"] + return {name: _to_global_numpy(_unwrap(p), mesh) for name, p in params.items()} + + +def _make_inputs(key): + """Generate a globally-identical input tensor on every process.""" + return jax.random.normal(key, (BATCH, SEQ, HIDDEN), dtype=DTYPE) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +# Parametrize variants exercised by both the forward and the backward +# parity tests. Each config is one MoE-block configuration the suite +# wants covered; the test body checks shape, dtype, finiteness AND +# numerical parity vs the same pure-JAX reference (which understands +# the same set of knobs). +# ----------------------------------------------------------------------------- + +_CONFIGS = [ + pytest.param( + dict(score_function="softmax"), + id="softmax", + ), + pytest.param( + dict(score_function="softmax", apply_topk_weights_early=True), + id="softmax-topk-early", + ), + pytest.param( + dict(score_function="softmax", align_size=128), + id="softmax-align128", + ), + pytest.param( + dict(score_function="sigmoid"), + id="sigmoid", + ), + pytest.param( + dict(score_function="sigmoid", use_expert_bias=True), + id="sigmoid-bias-zero", + ), + pytest.param( + dict( + score_function="sigmoid", + use_expert_bias=True, + bias_init=_strong_expert_bias_init, + ), + id="sigmoid-bias-strong", + ), +] + + +def _reference_kwargs_from_config(config, params_np): + """Pick out the reference-relevant pieces of a parametrize config.""" + return dict( + score_function=config.get("score_function", "softmax"), + expert_bias=( + jnp.asarray(params_np["expert_bias"]) + if config.get("use_expert_bias", False) + else None + ), + ) + + +class TestTeEpMoeForward: + """Per-config forward correctness in a single run: shape, dtype, + finiteness AND numerical parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_forward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(0)) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) + + # Shape / dtype / finiteness (cheap; on the local shard). + assert output.shape == x.shape + assert output.dtype == x.dtype + out_local = np.asarray(jax.device_get(output.addressable_data(0))) + assert np.all(np.isfinite(out_local)), "output has NaN/Inf" + assert aux is None, "aux_loss should be None when aux_loss_coeff == 0" + + # Numerical parity (replicated global view -> single rank's numpy). + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + out_te_np = _to_global_numpy(output, mesh) + + out_ref, _ = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **_reference_kwargs_from_config(config, params_np), + ) + np.testing.assert_allclose( + out_te_np.astype(np.float32), + np.asarray(jax.device_get(out_ref)).astype(np.float32), + atol=FWD_ATOL, + rtol=FWD_RTOL, + err_msg=f"forward parity breach for config={config}", + ) + + +class TestTeEpMoeBackward: + """Per-config backward correctness in a single run: per-tensor + grads finite, non-zero AND parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_backward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(2)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(3)) + grads_te = _grad_step(block, variables, mesh, x) + + # Reference grads via jax.grad over the pure-JAX MoE with the + # same config. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + ref_kwargs = _reference_kwargs_from_config(config, params_np) + ref_expert_bias = ref_kwargs.pop("expert_bias") + + def loss_fn(params, x): + out, _ = _pure_jax_moe_reference( + x, + params["gate_kernel"], + params["wi_0"], + params["wi_1"], + params["wo"], + ref_expert_bias, + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **ref_kwargs, + ) + return jnp.mean(out.astype(jnp.float32) ** 2) + + grads_ref = jax.jit(jax.grad(loss_fn))( + {k: jnp.asarray(v) for k, v in params_np.items() if k != "expert_bias"}, + jnp.asarray(x_np), + ) + grads_ref_np = {k: np.asarray(jax.device_get(v)) for k, v in grads_ref.items()} + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + # Per-tensor: finite + non-zero + parity in one pass. + g_te = _to_global_numpy(_unwrap(grads_te["params"][name]), mesh) + assert np.all(np.isfinite(g_te)), f"{name} grad has NaN/Inf [config={config}]" + assert np.any(g_te != 0.0), f"{name} grad identically zero [config={config}]" + atol, rtol = ( + (GRAD_GATE_ATOL, GRAD_GATE_RTOL) + if name == "gate_kernel" + else (GRAD_FFN_ATOL, GRAD_FFN_RTOL) + ) + np.testing.assert_allclose( + g_te.astype(np.float32), + grads_ref_np[name].astype(np.float32), + atol=atol, + rtol=rtol, + err_msg=f"grad parity breach on {name} [config={config}]", + ) + + +class TestTeEpMoeAuxLoss: + """Aux-loss path. Consolidated into: + * ``test_aux_loss``: one run that checks the returned scalar's + shape / dtype / finiteness / magnitude AND numerical parity vs the + reference AND that the aux-only bwd propagates to gate_kernel. + * ``test_combined_loss_grads``: one run for joint main+aux bwd + finite + non-zero per tensor. + """ + + def test_aux_loss(self, mesh): + coeff = 1e-2 + block = _make_block(aux_loss_coeff=coeff) + x = _make_inputs(jax.random.PRNGKey(20)) + variables, _, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(21)) + + # Shape / dtype / finiteness / magnitude. + assert aux is not None, "aux_loss should be returned when coeff > 0" + assert aux.shape == (), f"aux_loss must be 0-d scalar, got {aux.shape}" + assert aux.dtype == DTYPE, f"aux_loss dtype {aux.dtype} != {DTYPE}" + aux_np = _to_global_numpy(aux, mesh) + assert np.isfinite(aux_np), "aux_loss is NaN/Inf" + assert abs(float(aux_np)) < 1e2, f"aux_loss looks unreasonable: {aux_np}" + + # Numerical parity vs the reference. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + _, aux_ref = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + aux_loss_coeff=coeff, + ) + np.testing.assert_allclose( + float(aux_np), + float(jax.device_get(aux_ref)), + atol=AUX_ATOL, + rtol=AUX_RTOL, + ) + + # Aux-only bwd must propagate to gate_kernel — proves the + # fused_moe_aux_loss_bwd → topk(compute_aux_scores)_bwd chain is + # wired. + aux_grads = _grad_aux_only(block, variables, mesh, x) + g_gate = np.asarray( + jax.device_get( + _unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0) + ) + ) + assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss" + assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel" + + def test_combined_loss_grads(self, mesh): + """Joint main + aux loss bwd: per-tensor finite + non-zero in + one pass.""" + block = _make_block(aux_loss_coeff=1e-2) + x = _make_inputs(jax.random.PRNGKey(22)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23)) + grads = _grad_step(block, variables, mesh, x, include_aux=True) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = np.asarray( + jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) + ) + assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux" + assert np.any(g_local != 0.0), f"{name} grad zero under main+aux" + + +class TestTeEpMoEBlockFlax: + """Flax wrapper end-to-end in one run: shape/dtype/finiteness on the + forward, numerical parity vs the same reference, and per-tensor + grad finiteness + non-zeroness.""" + + def test_init_apply_parity(self, mesh): + block = _make_block() + x = _make_inputs(jax.random.PRNGKey(12)) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(13)) + + assert aux is None + assert output.shape == x.shape + assert output.dtype == x.dtype + out_local = np.asarray(jax.device_get(output.addressable_data(0))) + assert np.all(np.isfinite(out_local)) + + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + out_te_np = _to_global_numpy(output, mesh) + out_ref, _ = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + ) + np.testing.assert_allclose( + out_te_np.astype(np.float32), + np.asarray(jax.device_get(out_ref)).astype(np.float32), + atol=FWD_ATOL, + rtol=FWD_RTOL, + ) + + grads = _grad_step(block, variables, mesh, x) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = np.asarray( + jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) + ) + assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf" + assert np.any(g_local != 0.0), f"{name} grad zero" + + +# Keep the bootstrap-signature test last in the module (the "ZZZ" prefix +# ensures pytest's alphabetic class ordering picks it last): it +# intentionally mismatches the NCCL EP bootstrap signature, which +# permanently taints the per-process bootstrap cache for the rest of +# the file. +class TestZZZTeEpMoeBootstrap: + """Per-process NCCL bootstrap re-bootstrap rejection.""" + + def test_bootstrap_signature_mismatch_raises(self, mesh): + block_a = _make_block() + x_a = _make_inputs(jax.random.PRNGKey(14)) + _init_apply(block_a, mesh, x_a, jax.random.PRNGKey(15)) + + # Different hidden dim → different bootstrap signature. + bigger_hidden = HIDDEN * 2 + x_b = jax.random.normal( + jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE + ) + block_b = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + data_parallelism_axes=(FSDP_AXIS,), + dtype=DTYPE, + ) + with pytest.raises(ValueError, match="bootstrapped"): + _init_apply(block_b, mesh, x_b, jax.random.PRNGKey(17)) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index b5a4afc2ad..67b2f5dfdd 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -5,27 +5,20 @@ """Flax Linen MoE block for TransformerEngine JAX. This module exposes :class:`_MoEBlock`, an experimental Flax Linen layer -that is a thin wrapper around the framework-agnostic functional MoE entry -point :func:`transformer_engine.jax.moe.moe`. The wrapper's only job is -to: +that wraps the framework-agnostic functional MoE entry point +:func:`transformer_engine.jax.moe.moe`. The wrapper's only job is to: -1. Register the gate kernel, per-expert FFN kernels, and optional biases - as ``self.param`` slots (with the right +1. Register the gate kernel, per-expert FFN kernels, and optional FFN + biases as ``self.param`` slots (with the right :func:`flax.linen.with_logical_partitioning` annotations so JAX's sharding layer FSDPs the params correctly). 2. Resolve the EP axis name from the active :class:`transformer_engine.jax.sharding.MeshResource`. 3. Forward all knobs to :func:`moe`. -All routing, dispatch, FFN, combine, and aux-loss logic lives in -``moe.py`` under a *single* ``jax.custom_vjp`` so future fusions -(FP8-on-the-wire EP, fused ``ragged_all_to_all + grouped_gemm``, gate + -route + dispatch fusion) can land without touching this wrapper. - The class is intentionally underscore-prefixed; the public ``MoEBlock`` alias will be introduced once TE's NCCL-backed EP component (and the -recipe-driven alignment follow-up) stabilises (target: the TE release -following the 2.16 code freeze). +recipe-driven alignment follow-up) stabilises. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -37,8 +30,7 @@ # import P`` without a second jax.sharding import. from jax.sharding import PartitionSpec as P # noqa: F401 # pylint: disable=unused-import -from ..moe import PermutationBackend, moe -from ..quantize import noop_quantizer_set +from ..moe import moe from ..router import ScoreFunction from ..sharding import get_active_resource_axis from .module import TransformerEngineBase @@ -50,22 +42,19 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["PermutationBackend", "_MoEBlock"] +__all__ = ["_MoEBlock"] class _MoEBlock(TransformerEngineBase): """Experimental Flax MoE layer over TransformerEngine. See module docstring for the design (this class is a thin Flax - wrapper around :func:`transformer_engine.jax.moe.moe`). Constructor - knob set kept compatible with the previous bespoke implementation so - existing call sites need no changes. + wrapper around :func:`transformer_engine.jax.moe.moe`). Parameters ---------- num_experts : int - Total number of experts. Under EP this must be divisible by the - EP mesh axis size. + Total number of experts. Must be divisible by the EP mesh axis size. num_experts_per_tok : int Top-k value for routing. intermediate_size : int @@ -82,41 +71,30 @@ class _MoEBlock(TransformerEngineBase): Grouped top-k knobs (DeepSeek-style). ``None`` disables grouping. scaling_factor : float Multiplier on the routing weights. - use_expert_bias : bool - If ``True``, registers a per-expert routing bias (shape ``[E]``). - Only meaningful with ``score_function="sigmoid"``; the underlying - primitive validates the pairing. - aux_loss_coeff : float - If ``> 0``, return the MoE auxiliary load-balancing loss scalar - in addition to the main output. + + apply_topk_weights_early : bool + When True, fold per-token top-k weights into the FFN intermediate + (next to ``act(gate) * up``) instead of into the post-down-projection + combine. Both placements are mathematically equivalent (the down + projection is linear); the early placement gives XLA a chance to + fuse the multiply with the activation. gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, input_axes : Logical sharding axis tuples (consumed by Flax's :func:`with_logical_partitioning` and our internal :func:`with_sharding_constraint_by_logical_axes`). data_parallelism_axes : tuple[str, ...] - FSDP axes over which the input *batch* dim is sharded IN - ADDITION to the EP axis. Empty (default) means activations are - replicated across non-EP axes within an EP group; set e.g. - ``("fsdp",)`` for true FSDP-of-batch where each device owns a - unique slice of the batch. - permutation_backend : PermutationBackend - ``PURE_JAX`` (default) or ``TRITON``. - _align_size : int - Per-expert group-size alignment (``0`` disables; required > 0 - for quantized grouped GEMM). Internal knob; will be inferred - from the active quantization recipe in a follow-up PR. + FSDP axes over which the input *batch* dim is sharded IN ADDITION + to the EP axis. Empty (default) means activations are replicated + across non-EP axes within an EP group; set e.g. ``("fsdp",)`` for + true FSDP-of-batch where each device owns a unique slice of the + batch. dtype : jnp.dtype Compute / parameter dtype. - kernel_init, bias_init, expert_bias_init : Initializers. + kernel_init, bias_init : Initializers. use_bias : bool Register per-expert FFN biases. - - Quantization is currently configured via the standard TE autocast - context (``fp8_autocast``/``with_quantizer_set``); per-call - quantizer sets can also be passed through ``__call__``'s - ``quantizer_sets`` keyword once we stabilise the recipe pipeline. """ # Architecture @@ -131,8 +109,6 @@ class _MoEBlock(TransformerEngineBase): num_groups: Optional[int] = None group_topk: Optional[int] = None scaling_factor: float = 1.0 - use_expert_bias: bool = False - aux_loss_coeff: float = 0.0 # Sharding (logical axes) gate_kernel_axes: Tuple[Optional[str], ...] = () @@ -143,18 +119,27 @@ class _MoEBlock(TransformerEngineBase): # Parallelism data_parallelism_axes: Tuple[str, ...] = () - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX - _align_size: int = 0 + # Aux loss (global expert-load balancing). 0.0 disables; non-zero + # enables the second return value and routes its gradient back to + # the gate. + aux_loss_coeff: float = 0.0 + # Fusion knob apply_topk_weights_early: bool = False + # Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0 + # uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM + # tile alignment. + align_size: int = 0 + # Dtypes / init / misc dtype: DType = jnp.float32 kernel_init: Optional[Initializer] = None bias_init: Initializer = nn.initializers.zeros - expert_bias_init: Initializer = nn.initializers.zeros use_bias: bool = False + # Per-expert router bias added before the top-k. Only meaningful when + # score_function='sigmoid'. + use_expert_bias: bool = False def __post_init__(self): if self.kernel_init is None: @@ -165,11 +150,6 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", dtype=self.dtype ), ) - if not isinstance(self.permutation_backend, PermutationBackend): - raise TypeError( - "permutation_backend must be a PermutationBackend, got" - f" {self.permutation_backend!r}" - ) super().__post_init__() @nn.compact @@ -186,18 +166,17 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: output : jnp.ndarray ``[batch, sequence, hidden]``. aux_loss : Optional[jnp.ndarray] - Scalar load-balancing loss when ``aux_loss_coeff > 0``, - else ``None``. + 0-d scalar when ``aux_loss_coeff > 0``, ``None`` otherwise. """ assert ( inputs.ndim == 3 ), f"_MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" _, _, hidden_size = inputs.shape - # Param registrations -- must run OUTSIDE any JAX transform that + # Param registrations must run OUTSIDE any JAX transform that # alters the variable scope (e.g. shard_map). The functional - # ``moe(...)`` opens its own shard_map internally for the EP - # path, so registering params here is correct. + # ``moe(...)`` opens its own shard_map internally for the FFN + # body, so registering params here is correct. gate_kernel = self.param( "gate_kernel", nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), @@ -242,13 +221,14 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: (self.num_experts, hidden_size), self.dtype, ) + expert_bias = None if self.use_expert_bias: expert_bias = self.param( "expert_bias", - nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + nn.with_logical_partitioning(self.bias_init, ("exp",)), (self.num_experts,), - self.dtype, + jnp.float32, ) ep_axis = get_active_resource_axis("ep_resource") @@ -272,16 +252,13 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: group_topk=self.group_topk, scaling_factor=self.scaling_factor, aux_loss_coeff=self.aux_loss_coeff, - permutation_backend=self.permutation_backend, - align_size=self._align_size, - gate_inside_vjp=True, apply_topk_weights_early=self.apply_topk_weights_early, + align_size=self.align_size, ep_axis=ep_axis, data_parallelism_axes=self.data_parallelism_axes, input_axes=self.input_axes, gate_kernel_axes=self.gate_kernel_axes, wi_kernel_axes=self.wi_kernel_axes, wo_kernel_axes=self.wo_kernel_axes, - quantizer_sets=(noop_quantizer_set, noop_quantizer_set, noop_quantizer_set), dtype=self.dtype, ) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 4479b9f176..162ea8f7e5 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -1,77 +1,51 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Functional Mixture-of-Experts (MoE) entry point with a single fused VJP. - -This module exposes :func:`moe`, the framework-agnostic flat function that -implements an entire MoE block (gate -> top-k routing -> token dispatch -> -per-expert FFN -> token combine, plus optional expert parallelism via a -shard_map / ragged_all_to_all collective) under a *single* -``jax.custom_vjp``. It is the moral analog of -:func:`transformer_engine.jax.layernorm_mlp.layernorm_mlp` for MoE: one -custom_vjp boundary covers the whole block so future fusions (FP8 over the -EP wire, fused ``ragged_all_to_all + grouped_gemm``, gate+route+dispatch -fusion) can land without re-architecting the call site. - -Design rationale ----------------- - -The earlier MoE block (:class:`transformer_engine.jax.flax.moe._MoEBlock`) -composed many narrower custom_vjps -- one per :func:`grouped_dense`, one -per :func:`token_dispatch`, etc. Every nested custom_vjp is a place where -a quantized :class:`ScaledTensor` cannot survive (JAX requires custom_vjp -inputs / outputs to be plain ``jnp.ndarray`` ish pytrees). To enable -end-to-end FP8 flow -- in particular FP8 carried over the EP -ragged_all_to_all -- the dispatch's quantize, the a2a, the per-expert -FFN, the inverse a2a, and the combine all have to live inside the same -VJP. This file collapses them into one. - -Implementation conventions --------------------------- - -* No nested ``custom_vjp``. Every primitive's ``_fwd`` and ``_bwd`` is - called directly (e.g. :func:`tex.fused_topk_with_score_function_fwd` / - ``_bwd``, :func:`unpermute_with_mask_map`, - :func:`unpermute_bwd_with_merging_probs`, - :func:`sort_chunks_by_map(is_forward=False)`, - forward + reverse :func:`jax.lax.ragged_all_to_all`) so the outer - ``_moe_bwd_rule`` controls the bwd graph end-to-end without invoking - ``jax.vjp`` for re-linearization. -* The fwd/bwd context (``ctx``) is a plain ``dict`` whose keys depend on - the static configuration (permutation backend, EP active or not, - presence of biases, aux loss enabled). The ``_moe_fwd_rule`` builds a - matching ``ctx_specs`` dict in lockstep when opening the EP shard_map - so ``out_specs`` structurally matches the body's return. -* :func:`_dispatch` is the helper that wraps - ``permute -> a2a -> local_permute`` (forward); :func:`_combine` is its - inverse. Their ``_bwd`` siblings drive the inverse collectives in the - bwd rule. None of these helpers form a custom_vjp boundary. +"""Mixture-of-Experts (MoE) layer for TransformerEngine JAX. + +This module exposes :func:`moe`, a single fused MoE forward pass + bwd +built on top of TE's NCCL-backed Expert Parallelism primitives +(``tex.ep_dispatch`` / ``tex.ep_combine``). The block runs:: + + gate -> topk -> ep_dispatch -> per-expert FFN (grouped GEMMs) + -> ep_combine -> output + +under a single ``jax.custom_vjp`` so the routing, dispatch, FFN and +combine steps fuse cleanly under XLA without leaking intermediate +residuals into the user-facing autograd graph. + +Sharding model +-------------- +* Inbound activations are 3D ``[B, S, H]`` sharded + ``((*data_parallelism_axes, ep_axis), None, None)``. The public + :func:`moe` soft-repins this on entry and warns when a reshard is + inserted. +* The EP primitives operate at global view (their custom_partitioning + rules handle per-shard execution). The FFN GEMMs run per-shard inside + a small ``shard_map`` whose ``in_specs`` and ``out_specs`` mirror the + same ``((dp, ep), ...)`` layout. + +Out-of-scope (for now) +---------------------- +FP8 / MXFP8 quantizer sets are not yet wired on this path; turning +them on requires recipe-aware residual specs and ``ScaledTensor`` +leaves across the ``shard_map`` boundary. ``aux_loss_coeff`` and +``expert_bias`` are supported (the former forces a per-step +all-gather over the routing-side logits, which lives off the critical +path and overlaps with the dispatch collective). """ -import math from dataclasses import dataclass -from enum import Enum from functools import partial -from typing import Any, NewType, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import warnings import jax import jax.numpy as jnp -from flax import struct as flax_struct from jax.sharding import NamedSharding, PartitionSpec as P from . import cpp_extensions as tex -from .permutation import ( - PureJaxPermState, - compute_ragged_all_to_all_params, - compute_reverse_ragged_all_to_all_params, - pure_jax_token_combine, - pure_jax_token_dispatch, - routing_map_to_selected_experts, -) from .quantize import ( - QuantizerSet, - ScaledTensor, TensorUsage, noop_quantizer_set, with_sharding_constraint_by_logical_axes, @@ -80,1052 +54,132 @@ from .router import ScoreFunction, _validate_score_function from .sharding import _get_mesh -# Triton-backed primitives are imported lazily: callers on the PURE_JAX -# permutation backend should not need ``triton`` installed. The TRITON -# branches in this module call ``_require_triton()`` first to raise a -# clear error if the import failed. -try: - from .triton_extensions.permutation import ( - make_chunk_sort_map, - make_row_id_map, - permute_with_mask_map, - permute_with_mask_map_and_pad, - sort_chunks_by_map, - unpermute_bwd_with_merging_probs, - unpermute_bwd_with_merging_probs_and_unpad, - unpermute_with_mask_map, - unpermute_with_mask_map_and_unpad, - ) - - _TRITON_AVAILABLE = True -except ImportError: - _TRITON_AVAILABLE = False - make_chunk_sort_map = None - make_row_id_map = None - permute_with_mask_map = None - permute_with_mask_map_and_pad = None - sort_chunks_by_map = None - unpermute_bwd_with_merging_probs = None - unpermute_bwd_with_merging_probs_and_unpad = None - unpermute_with_mask_map = None - unpermute_with_mask_map_and_unpad = None - - -def _require_triton(): - """Raise a clear error if Triton permutation kernels are unavailable.""" - if not _TRITON_AVAILABLE: - raise ImportError( - "PermutationBackend.TRITON requires" - " ``transformer_engine.jax.triton_extensions`` (and ``triton``)." - " Install Triton or pass PermutationBackend.PURE_JAX." - ) - - -PRNGKey = Any -Shape = Tuple[int, ...] -DType = NewType("DType", jnp.dtype) -Array = NewType("Array", jnp.ndarray) - - -__all__ = ["moe", "PermutationBackend"] +__all__ = ["moe"] # ============================================================================= -# Enums +# Process-level NCCL EP bootstrap # ============================================================================= +# +# ``tex.ep_bootstrap`` initialises the NCCL EP communicator exactly once per +# process and stashes its state in a C++ singleton. Subsequent calls with the +# same signature are a no-op; calls with a different signature raise. +_te_ep_bootstrap_signature: Optional[Tuple[int, int, int, int, int]] = None -class PermutationBackend(Enum): - """Token-dispatch / combine backend used by :func:`moe`. - * ``TRITON``: TE's fused Triton kernels. Faster than ``PURE_JAX`` - on current hardware and the recommended default. - * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain - XLA; useful as a numerical reference and on builds without - Triton available. - """ +def _te_ep_bootstrap_if_needed( + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Bootstrap the NCCL EP communicator on first use within a process.""" + global _te_ep_bootstrap_signature + sig = (num_experts, max_tokens_per_rank, recv_capacity_per_rank, hidden_dim, ep_size) + if _te_ep_bootstrap_signature == sig: + return + if _te_ep_bootstrap_signature is not None: + raise ValueError( + "TE EP was already bootstrapped with signature " + f"{_te_ep_bootstrap_signature}; got {sig}. Re-bootstrap with" + " different params is not supported within a single process." + ) + from transformer_engine.jax.ep import ep_bootstrap # local: avoids import cycle - PURE_JAX = "pure_jax" - TRITON = "triton" + ep_bootstrap( + world_size=jax.process_count(), + rank=jax.process_index(), + ep_size=ep_size, + num_experts=num_experts, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + # XLA may relocate the C++ handle buffer between JIT executables; + # allow it rather than asserting on handle aliasing. + allow_handle_mem_reloc=True, + ) + _te_ep_bootstrap_signature = sig # ============================================================================= -# Dispatch-state records (carried _dispatch -> _combine / *_bwd) +# Residual container threaded fwd -> bwd # ============================================================================= -# -# Two NamedTuples (one per permutation backend) so we get type -# discrimination at the consumer side via ``isinstance``. The backend- -# specific residuals are required fields; the EP-only residuals are -# Optional and are populated only when the run is EP-active. Each field -# is either an ``ndarray`` or ``None`` -- nothing static, since these -# values cross the shard_map pytree boundary and would otherwise be -# coerced into JitTracers. - - -@flax_struct.dataclass -class _PureJaxDispatchState: - """Residuals saved by :func:`_dispatch` on the PURE_JAX path. - - Registered as a JAX pytree via ``flax.struct.dataclass``: each - annotated field is a leaf, ``None`` is a non-leaf sentinel. The - matching spec built by :func:`_build_dispatch_specs` mirrors this - layout so shard_map's value and spec trees line up. - """ - - group_sizes: jnp.ndarray - sorted_indices: jnp.ndarray - routing_weights: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None - - -@flax_struct.dataclass -class _TritonDispatchState: - """Residuals saved by :func:`_dispatch` on the TRITON path.""" - group_sizes: jnp.ndarray - row_id_map: jnp.ndarray - pad_offsets: Optional[jnp.ndarray] # populated only when align_size > 0 - merging_probs: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None +@dataclass +class _Ctx: + """Residuals carried from the fwd rule into the bwd rule.""" -_DispatchState = Union[_PureJaxDispatchState, _TritonDispatchState] - - -@flax_struct.dataclass -class _BodyCtx: - """Residuals carried fwd_rule -> bwd_rule by :func:`_body_fwd`. - - Optional fields (``expert_bias``, ``aux_*``) are ``None`` when the - matching feature is disabled. :func:`_build_ctx_specs` mirrors that - layout so the shard_map spec and value trees match leaf-for-leaf. - """ - - # Always present. - x: Any - gate_kernel: Any - logits_2d: Any - saved_scores: Any - routing_map: Any - dispatch: Any # _DispatchState + x: jnp.ndarray + gate_kernel: jnp.ndarray + expert_bias: jnp.ndarray + logits_2d: jnp.ndarray + saved_scores: jnp.ndarray + routing_map: jnp.ndarray + handle: Any + token_counts: jnp.ndarray + recv_topk_weights: jnp.ndarray casted_sorted_x_lhs_trans: Any - casted_wi_rhs_trans: Any # stacked [E, H, 2, M] residual for fused wi_0|wi_1 bwd - gate_proj_out: Any - up_proj_out: Any + casted_wi_rhs_trans: Any + gate_proj_out: jnp.ndarray + up_proj_out: jnp.ndarray casted_intermediate_lhs_trans: Any casted_wo_rhs_trans: Any - expert_outputs: Any - local_group_sizes: Any - # Feature-gated. - expert_bias: Any = None + expert_outputs: jnp.ndarray + local_group_sizes: jnp.ndarray + # Aux-loss residuals; None when aux_loss_coeff == 0. aux_const_buf: Any = None aux_tokens_per_expert: Any = None - aux_logits_for_score: Any = None aux_saved_scores: Any = None # ============================================================================= -# ctx / dispatch-state key conventions -# ============================================================================= -# -# Both ``ctx`` (carried fwd_rule -> bwd_rule) and the dispatch state -# (carried _dispatch -> _combine / _dispatch_bwd / _combine_bwd) are plain -# python dicts. Using a dict (rather than a flax_struct.dataclass) lets us -# vary the populated keys with the static config without breaking -# ``shard_map``'s ``out_specs`` structural match: the spec dict and the -# value dict are built with the SAME keys via :func:`_build_ctx_specs`. -# -# Below is the key glossary so the rest of the file reads cleanly. -# -# DispatchState (dict): values are jnp.ndarray unless noted -# Always present: -# "group_sizes" [n_groups] per-expert token counts -# (n_groups = E for no-EP, -# E_local for EP) -# "ep_active" bool (carried as a Python flag, -# not in the dict; passed -# alongside) -# PURE_JAX backend: -# "sorted_indices" [num_real + padding] argsort indices -# "routing_weights" [num_tokens, topk] per-token-per-expert weights -# TRITON backend: -# "row_id_map" [num_tokens, 2*E + 1] -# "pad_offsets" [E] or None -# "merging_probs" [num_tokens, E] -# EP-only: -# "all_shards_tokens_per_expert" [num_ep, E] -# "local_perm_row_id_map" [recv_buffer_rows] -# "local_perm_inv_row_id_map" [recv_buffer_rows] -# -# NOTE: per-shard compile-time-constant shapes (num_real_tokens, -# padding_size, pre/post_a2a_buffer_shape) are NOT stored in this -# dict; they are recomputed in _body_fwd/_body_bwd via -# _compute_static_shape_info and passed as Python ints / int tuples to -# the dispatch/combine helpers. Storing them in the dict would cause -# JAX's pytree-flatten across the shard_map boundary to coerce them -# into JitTracer 0-d arrays, which breaks Python-level control flow -# (e.g. ``if padding > 0``) and ``jnp.zeros(shape)`` in the bwd. -# -# See :class:`_BodyCtx` (NamedTuple) for the ctx layout and field -# documentation. :func:`_build_ctx_specs` returns a matching ``_BodyCtx`` -# of ``P(...)`` specs so shard_map's value/spec trees line up -# leaf-for-leaf. - - -# ============================================================================= -# Static shape helper -# ============================================================================= -# -# A set of per-shard shape/size values that the dispatch and combine -# helpers (both fwd and bwd) need. They're all derivable from existing -# static args, so we recompute them in both ``_body_fwd`` and -# ``_body_bwd`` and pass them as Python ints / int-tuples through -# explicit kwargs. We MUST NOT stash them inside the dynamic -# ``state`` / ``ctx`` dict: when the dict crosses the EP shard_map's -# out_specs/in_specs boundary, JAX's pytree-flatten coerces any Python -# int leaves into traced 0-d arrays, which then breaks dependent Python -# code in the bwd (e.g. ``if padding > 0`` and ``jnp.zeros(shape)``). - - -@dataclass(frozen=True) -class _StaticShapeInfo: - """Per-shard compile-time-constant shape info used by dispatch / - combine fwd and bwd. Fields are Python ints / int tuples (NOT jnp - arrays) so they can be passed as ordinary static keyword args. - - Attributes - ---------- - num_real_tokens : int - Per-shard count of real (non-padding) permuted tokens, - i.e. ``per_shard_num_tokens * num_experts_per_tok``. - padding_size : int - Per-shard number of alignment-padding tokens appended to the - sort buffer (``num_experts * (align_size - 1)`` when - ``align_size > 0``, else ``0``). - pre_a2a_buffer_shape : tuple[int, int] - ``(num_real_tokens + padding_size, hidden)`` -- the per-shard - shape of the sorted-inputs buffer sent over the EP - ragged_all_to_all in the fwd direction. - post_a2a_buffer_shape : Optional[tuple[int, int]] - ``(recv_buffer_rows, hidden)`` when EP is active, ``None`` - otherwise. - """ - - num_real_tokens: int - padding_size: int - pre_a2a_buffer_shape: Tuple[int, int] - post_a2a_buffer_shape: Optional[Tuple[int, int]] - - -def _compute_static_shape_info( - *, - batch_size: int, - sequence_length: int, - hidden: int, - num_experts: int, - num_experts_per_tok: int, - align_size: int, - ep_active: bool, - num_ep: int = 1, - fsdp_sizes: Tuple[int, ...] = (), - recv_buffer_rows: int = 0, - batch_is_per_shard: bool = True, -) -> _StaticShapeInfo: - """Build a :class:`_StaticShapeInfo` for the current rank. - - ``batch_is_per_shard`` controls whether ``batch_size`` is already - sharded (True -- e.g. when this is called from inside a shard_map - body, where ``x.shape[0]`` reports the per-shard batch size) or - global (False -- e.g. when computing from x.shape outside the - shard_map body). - """ - if ep_active and not batch_is_per_shard: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - else: - per_shard_batch = batch_size - per_shard_num_tokens = per_shard_batch * sequence_length - num_real_tokens = per_shard_num_tokens * num_experts_per_tok - padding_size = num_experts * (align_size - 1) if align_size > 0 else 0 - pre_a2a_buffer_shape = (num_real_tokens + padding_size, hidden) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) if ep_active else None - return _StaticShapeInfo( - num_real_tokens=num_real_tokens, - padding_size=padding_size, - pre_a2a_buffer_shape=pre_a2a_buffer_shape, - post_a2a_buffer_shape=post_a2a_buffer_shape, - ) - - -# ============================================================================= -# Dispatch / combine helpers (no VJP boundary -- pure Python) -# ============================================================================= - - -def _dispatch( - inputs_2d: jnp.ndarray, - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - *, - backend: PermutationBackend, - num_experts: int, - num_experts_per_tok: int, - align_size: int, - # EP-only: - ep_active: bool, - ep_axis: Optional[str], - num_ep: int, - recv_buffer_rows: int, - shard_id: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, dict]: - """``permute -> (a2a -> local_permute) iff ep_active``. - - Returns ``(sorted_x, state)`` where ``sorted_x`` has shape - ``[buffer_rows, hidden]`` -- ``E`` groups (no-EP) or ``E_local`` groups - (EP) -- and ``state`` is a dict carrying everything :func:`_combine` - and the bwd helpers need to reverse the operation. - - Bypasses the ``custom_vjp``-wrapped public ``token_dispatch`` / - ``pure_jax_token_dispatch`` wrappers (well, mostly: PURE_JAX still - composes through ``pure_jax_token_dispatch`` because that helper has - no ``custom_vjp`` itself -- only its inner ``_sort_activations`` does, - which is fine since we never auto-diff through it from this layer). - For TRITON we call the underlying ``permute_with_mask_map`` / - ``permute_with_mask_map_and_pad`` primitives directly. - """ - num_tokens, hidden = inputs_2d.shape - topk = num_experts_per_tok - - # Backend-specific residuals collected here, then packaged into the - # appropriate _*DispatchState below. - sorted_indices = None - routing_weights_kept = None - row_id_map = None - pad_offsets = None - merging_probs = None - - # ------------------------------------------------------------------ - # Step 1: global permute (every shard routes its own tokens over the - # full expert axis). Backend-specific. - # ------------------------------------------------------------------ - if backend is PermutationBackend.PURE_JAX: - selected_experts, routing_weights = routing_map_to_selected_experts( - sparse_probs, routing_map, topk - ) - sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( - inputs_2d, - selected_experts, - num_experts=num_experts, - num_experts_per_tok=topk, - align_size=align_size, - ) - # NOTE: ``perm_state.num_real_tokens`` and ``perm_state.padding_size`` - # are compile-time Python ints; intentionally NOT stored in the - # returned state (would be coerced to JitTracer 0-d arrays under - # the EP shard_map's pytree flatten). Recompute via - # ``_compute_static_shape_info`` in the bwd / EP-combine - # call sites that need them. - sorted_indices = perm_state.sorted_indices - routing_weights_kept = routing_weights - else: - # TRITON backend -- inline the underlying primitive sequence - # (mirrors ``_token_dispatch_fwd_rule`` but exposes the residuals - # to our ctx instead of saving them inside another custom_vjp). - num_out_tokens = num_tokens * topk - row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) - if align_size > 0: - target_tokens_per_expert = ( - jnp.ceil(tokens_per_expert / align_size) * align_size - ).astype(jnp.int32) - pad_lengths = target_tokens_per_expert - tokens_per_expert - cum_pad = jnp.cumsum(pad_lengths) - pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) - worst_case_out_tokens = ( - (num_out_tokens + num_experts * (align_size - 1)) // align_size - ) * align_size - sorted_inputs, _ = permute_with_mask_map_and_pad( - inputs_2d, - row_id_map, - None, - pad_offsets, - num_tokens, - num_experts, - worst_case_out_tokens, - hidden, - align_size=align_size, - ) - group_sizes = target_tokens_per_expert - else: - sorted_inputs, _ = permute_with_mask_map( - inputs_2d, - row_id_map, - None, - num_tokens, - num_experts, - num_out_tokens, - hidden, - ) - pad_offsets = None - group_sizes = tokens_per_expert - merging_probs = sparse_probs - - def _build_state(group_sizes_val, ep_all=None, ep_local=None): - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=group_sizes_val, - sorted_indices=sorted_indices, - routing_weights=routing_weights_kept, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=group_sizes_val, - row_id_map=row_id_map, - pad_offsets=pad_offsets, - merging_probs=merging_probs, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - - if not ep_active: - return sorted_inputs, _build_state(group_sizes) - - # ------------------------------------------------------------------ - # Step 2 (EP only): all_gather per-expert counts so every shard knows - # the [num_ep, num_experts] token-count matrix. - # ------------------------------------------------------------------ - all_shards_tokens_per_expert = jax.lax.all_gather( - group_sizes[None, :], - axis_name=ep_axis, - axis=0, - tiled=True, - ) - - # ------------------------------------------------------------------ - # Step 3 (EP only): forward ragged_all_to_all over the EP axis. - # ------------------------------------------------------------------ - in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) - recv_buf = jnp.zeros(post_a2a_buffer_shape, dtype=sorted_inputs.dtype) - x_recv = jax.lax.ragged_all_to_all( - sorted_inputs, recv_buf, in_off, send_sz, out_off, recv_sz, axis_name=ep_axis - ) - - # ------------------------------------------------------------------ - # Step 4 (EP only): local permute -- (source_shard, expert) -> - # (expert, shard). Inlined ``local_permute_after_a2a`` so we control - # both the row_id_map and its inverse for the bwd. - # ------------------------------------------------------------------ - num_experts_local = num_experts // num_ep - local_expert_start = shard_id * num_experts_local - local_expert_columns = jax.lax.dynamic_slice( - all_shards_tokens_per_expert, - start_indices=(0, local_expert_start), - slice_sizes=(num_ep, num_experts_local), - ) - split_sizes = local_expert_columns.reshape(-1) # source-major - indices_matrix = jnp.arange(num_ep * num_experts_local, dtype=jnp.int32).reshape( - num_ep, num_experts_local - ) - sorted_chunk_indices = indices_matrix.T.reshape(-1) # source-major -> expert-major - num_chunks = num_ep * num_experts_local - # Build a SINGLE row_id_map. ``is_forward=True`` permutes - # source-major -> expert-major; ``is_forward=False`` is the exact - # inverse (this is exactly what ``_sort_chunks_by_index_bwd_rule`` - # uses on the saved residual). _MoEBlock builds two row_id_maps - # only because it calls ``sort_chunks_by_index`` twice -- once in - # ``local_permute_after_a2a`` and again in ``local_unpermute_before_a2a``; - # each of those wrappers calls ``make_chunk_sort_map`` internally. - # Here we share one map across (fwd permute, fwd inverse-permute, - # bwd permute, bwd inverse-permute). - local_perm_row_id_map = make_chunk_sort_map( - split_sizes, sorted_chunk_indices, recv_buffer_rows, num_chunks - ) - sorted_x, _ = sort_chunks_by_map( - x_recv, local_perm_row_id_map, None, recv_buffer_rows, hidden, is_forward=True - ) - local_group_sizes = jnp.sum(local_expert_columns, axis=0) - - # NOTE: pre_a2a_buffer_shape and post_a2a_buffer_shape are compile- - # time int tuples; intentionally NOT stored in the returned state - # (would be coerced to JitTracer 0-d arrays under the EP shard_map's - # pytree flatten). Recompute via ``_compute_static_shape_info`` in - # the bwd call sites that need them. For EP, ``group_sizes`` here is - # the per-local-expert count (the FFN runs over E_local groups, not - # E). The global ``group_sizes`` lives inside - # ``all_shards_tokens_per_expert`` if anyone needs it for - # diagnostics. - return sorted_x, _build_state( - local_group_sizes, - ep_all=all_shards_tokens_per_expert, - ep_local=local_perm_row_id_map, - ) - - -def _combine( - expert_outputs: jnp.ndarray, - state: _DispatchState, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # Computed by _compute_static_shape_info in the caller, passed here - # rather than stored in ``state`` to survive shard_map crossings. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Inverse of :func:`_dispatch`. - - Returns ``(output, expert_outputs_post_ep)``. ``output`` is the - ``[B, S, H]`` combined activations. ``expert_outputs_post_ep`` is - the FFN-output tensor in the shape that Step 3 of the combine - actually consumed (i.e. after the reverse ragged_all_to_all on EP - runs, or the original input on non-EP). The caller stashes this as - the bwd residual so that ``_combine_bwd``'s Step-3 inverse sees - the same tensor the forward Step 3 used. - """ - if ep_active: - # Step 1 (EP): inverse local permute. Reuse the SAME row_id_map - # built in _dispatch by setting is_forward=False (this is the - # exact inverse, identical to what - # ``_sort_chunks_by_index_bwd_rule`` does with the saved residual). - recv_buffer_rows, hidden = expert_outputs.shape - x_send_back, _ = sort_chunks_by_map( - expert_outputs, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 2 (EP): reverse ragged_all_to_all. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - send_back_buf = jnp.zeros(pre_a2a_buffer_shape, dtype=expert_outputs.dtype) - expert_outputs = jax.lax.ragged_all_to_all( - x_send_back, - send_back_buf, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # Step 3: global combine. ``expert_outputs`` here is the post-A2A - # tensor under EP, or the original input under non-EP -- whichever - # value Step 3 actually consumes. Returned as the second tuple - # element so the caller can stash it as the bwd residual. - if backend is PermutationBackend.PURE_JAX: - # Reuse the reference pure-jax implementation; it has no - # custom_vjp on its outer surface so we can call it freely. - perm_state = PureJaxPermState( - sorted_indices=state.sorted_indices, - num_real_tokens=num_real_tokens, - padding_size=padding_size, - ) - output = pure_jax_token_combine( - expert_outputs, - perm_state, - state.routing_weights, - num_experts_per_tok=num_experts_per_tok, - batch_size=batch_size, - sequence_length=sequence_length, - ) - return output, expert_outputs - # TRITON - num_tokens = state.row_id_map.shape[0] - num_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = expert_outputs.shape[-1] - if state.pad_offsets is not None: - out_2d, _ = unpermute_with_mask_map_and_unpad( - expert_outputs, - state.row_id_map, - state.merging_probs, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, - ) - else: - out_2d, _ = unpermute_with_mask_map( - expert_outputs, - state.row_id_map, - state.merging_probs, - None, - num_tokens, - num_experts, - hidden, - ) - return out_2d.reshape(batch_size, sequence_length, hidden).astype(dtype), expert_outputs - - -def _combine_bwd( # pylint: disable=unused-argument - d_output: jnp.ndarray, - state: _DispatchState, - expert_outputs: jnp.ndarray, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - post_a2a_buffer_shape: Optional[Tuple[int, int]], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Inverse of :func:`_combine` on the cotangent. - - Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``. - - ``expert_outputs`` is the *forward* output of the FFN (same value the - fwd handed to :func:`_combine`). It's required by the TRITON - combine_bwd kernel; for PURE_JAX we don't need it but accept it for - a symmetric signature. - """ - # Step 3 inverse: global combine bwd. - d_output_2d = d_output.reshape(-1, d_output.shape[-1]) - if backend is PermutationBackend.PURE_JAX: - # The pure-jax combine is: - # unsort = _sort_activations(expert_outputs, argsort(sorted_indices)) - # if pad: unsort = unsort[:num_real] - # reshape -> einsum BKE,BK -> BE -> reshape to BSE - # Hand-derive the bwd in plain JAX (no custom_vjp involved): - unsort_indices = jnp.argsort(state.sorted_indices) - topk = num_experts_per_tok - num_real = num_real_tokens - padding = padding_size - # Recover the unsorted intermediate that the fwd produced (we - # need it for the d_routing_weights pullback). Apply the same - # gather the fwd did. - unsort_intermediate = expert_outputs[unsort_indices] - if padding > 0: - unsort_intermediate = unsort_intermediate[:num_real] - # Bwd of einsum/reshape: - # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K] - # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K] - # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] - rw = state.routing_weights.reshape(-1, topk) - intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1) - rw_cast = rw.astype(intermediate_3d.dtype) - d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast) - d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype( - state.routing_weights.dtype - ) - d_routing_weights = d_routing_weights.reshape(state.routing_weights.shape) - d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1) - # Pad back with zeros if the fwd stripped padding. - if padding > 0: - d_unsort_intermediate = jnp.concatenate( - [ - d_unsort_intermediate, - jnp.zeros( - (padding, d_unsort_intermediate.shape[-1]), - dtype=d_unsort_intermediate.dtype, - ), - ], - axis=0, - ) - # Bwd of the gather is gather-by-original-indices: - # sorted = unsort[argsort(sorted_indices)] - # d_sorted = scatter d_unsort via argsort(sorted_indices) - # = d_unsort[sorted_indices] (gather by original sorted_indices, - # which is the inverse of argsort(sorted_indices)). - d_expert_outputs_global = d_unsort_intermediate[state.sorted_indices] - else: - # TRITON combine bwd: requires fwd_input (expert_outputs). - num_tokens = state.row_id_map.shape[0] - n_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = d_output_2d.shape[-1] - num_out_tokens = expert_outputs.shape[0] - if state.pad_offsets is not None: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - state.pad_offsets, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - # The kernel only writes positions tokens map to; padded - # positions may contain NaN. Replace with zeros (matches - # ``_token_combine_bwd_rule``). - d_expert_outputs_global = jnp.where( - jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global - ) - else: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - d_routing_weights = d_merging_probs - - if not ep_active: - return d_expert_outputs_global, d_routing_weights - - # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward - # ragged_all_to_all using the SAME forward parameters (sender / - # receiver roles swap from the reverse direction back to forward). - in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype) - d_x_send_back = jax.lax.ragged_all_to_all( - d_expert_outputs_global, - recv_buf_for_bwd, - in_off_f, - send_sz_f, - out_off_f, - recv_sz_f, - axis_name=ep_axis, - ) - # Step 1 (EP) inverse: combine fwd applied is_forward=False; the - # bwd is is_forward=True with the SAME row_id_map. - recv_buffer_rows, hidden = d_x_send_back.shape - d_expert_outputs, _ = sort_chunks_by_map( - d_x_send_back, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=True, - ) - return d_expert_outputs, d_routing_weights - - -def _dispatch_bwd( - d_sorted_x: jnp.ndarray, - state: _DispatchState, - inputs_2d_shape: Tuple[int, ...], - *, - backend: PermutationBackend, - ep_active: bool, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> jnp.ndarray: - """Inverse of :func:`_dispatch` on the cotangent. Returns ``d_inputs_2d``. - - The probs path through dispatch is always discarded (PURE_JAX never - threads probs through dispatch; TRITON technically does but the - caller drops ``permuted_probs``, so its cotangent is structurally - zero). The probs gradient instead flows back through - :func:`_combine_bwd`. - """ - if ep_active: - # Step 4 inverse: dispatch fwd applied is_forward=True; bwd is - # is_forward=False with the SAME row_id_map. - recv_buffer_rows, hidden = d_sorted_x.shape - d_x_recv, _ = sort_chunks_by_map( - d_sorted_x, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 3 inverse: bwd of forward ragged_a2a is the reverse-direction - # ragged_a2a using the SAME params with sender/receiver swapped. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_pre = jnp.zeros(pre_a2a_buffer_shape, dtype=d_x_recv.dtype) - d_sorted_x = jax.lax.ragged_all_to_all( - d_x_recv, - recv_buf_pre, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # Step 1 inverse: global permute bwd. - if backend is PermutationBackend.PURE_JAX: - # Fwd was: replicated = repeat(inputs_2d, topk, axis=0) - # padded = pad(replicated, (0, padding_size)) - # sorted = padded[sorted_indices] - # Bwd: d_padded = scatter via sorted_indices - # = d_sorted[argsort(sorted_indices)] - # d_replicated = d_padded[:num_real] - # d_inputs_2d = d_replicated.reshape(T, topk, H).sum(axis=1) - sorted_indices = state.sorted_indices - num_real = num_real_tokens - padding = padding_size - topk = num_experts_per_tok - unsort_indices = jnp.argsort(sorted_indices) - d_padded = d_sorted_x[unsort_indices] - if padding > 0: - d_replicated = d_padded[:num_real] - else: - d_replicated = d_padded - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - d_inputs_2d = d_replicated.reshape(num_tokens, topk, hidden).sum(axis=1) - return d_inputs_2d - - # TRITON: bwd is unpermute_with_mask_map[_and_unpad]. - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - if state.pad_offsets is not None: - d_inputs_2d, _ = unpermute_with_mask_map_and_unpad( - d_sorted_x, - state.row_id_map, - None, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, - ) - else: - d_inputs_2d, _ = unpermute_with_mask_map( - d_sorted_x, - state.row_id_map, - None, - None, - num_tokens, - num_experts, - hidden, - ) - return d_inputs_2d - - -# ============================================================================= -# Per-shard body +# Per-shard FFN body (runs inside shard_map) # ============================================================================= -def _body_fwd( # pylint: disable=unused-argument - captured: dict, +def _ffn_fwd_per_shard( + recv_tokens_local: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray], + wi_1_bias: Optional[jnp.ndarray], + wo_bias: Optional[jnp.ndarray], *, - # Statics - num_experts: int, - num_experts_per_tok: int, + num_local_experts: int, + slots_per_expert: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - # EP-only statics - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - apply_topk_weights_early: bool = False, -) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: - """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. + apply_topk_weights_early: bool, +): + """Per-shard FFN forward. - ``aux_loss`` is always materialized (zeros scalar when disabled) so - the ``shard_map``'s ``out_specs`` has a static structure. + Operates on the shard-local ``[1, recv_pr, H]`` slice that + ``tex.ep_dispatch`` produces. Returns the expert outputs (shaped + ``[1, recv_pr, H_out]`` so the surrounding ``shard_map`` reassembles + them as ``[num_procs, recv_pr, H_out]``) plus the residuals consumed + by the bwd. """ - if apply_topk_weights_early: - # Requires row-aligned per-token weights at the FFN intermediate; - # only available on the TE EP (tex.ep_dispatch) path. - raise NotImplementedError( - "apply_topk_weights_early=True is supported only with the TE EP " - "(tex.ep_dispatch / tex.ep_combine) backend." - ) - if not gate_inside_vjp: - raise NotImplementedError( - "gate_inside_vjp=False is deferred to a follow-up PR; for now" - " the gate GEMM lives inside the MoE VJP." - ) - - x = captured["inputs"] - gate_kernel = captured["gate_kernel"] - wi_0 = captured["wi_0"] - wi_1 = captured["wi_1"] - wo = captured["wo"] - wi_0_bias = captured.get("wi_0_bias") - wi_1_bias = captured.get("wi_1_bias") - wo_bias = captured.get("wo_bias") - expert_bias = captured.get("expert_bias") - - batch_size, sequence_length, hidden = x.shape - - # ---------------- Stage 1: gate ---------------- - gate_kernel_cast = gate_kernel.astype(x.dtype) - gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) # [B, S, E] - # tex.fused_topk_with_score_function_* requires rank-2 input. - logits_2d = gate_logits.reshape(-1, num_experts) - inputs_2d = x.reshape(-1, hidden) - - # ---------------- Stage 2: routing ---------------- - # Under EP, expert_bias is sharded P(ep_axis); the router needs the - # full E-dim view, so all_gather it. - if ep_active and expert_bias is not None: - full_expert_bias = jax.lax.all_gather(expert_bias, axis_name=ep_axis, tiled=True) - else: - full_expert_bias = expert_bias - # Pass an empty array sentinel when expert_bias is unused (the - # underlying primitive expects a real ndarray, not None). - eb_arg = ( - full_expert_bias if full_expert_bias is not None else jnp.zeros((0,), dtype=jnp.float32) - ) - sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( - logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - sparse_probs = sparse_probs.astype(dtype) + hidden = recv_tokens_local.shape[-1] + sorted_x = recv_tokens_local.reshape(-1, hidden) + recv_w_flat = recv_topk_weights_local.reshape(-1) + local_group_sizes = jnp.full((num_local_experts,), slots_per_expert, dtype=jnp.int32) - # ---------------- Stage 2b: aux loss ---------------- - if aux_loss_coeff > 0.0: - if ep_active: - collective_axes: Any = ( - ep_axis - if not data_parallelism_axes - else (*data_parallelism_axes, ep_axis) - ) - global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=collective_axes, axis=0, tiled=True - ) - _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( - global_logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = global_logits_2d - else: - aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = logits_2d - # Aux-side scores: clean per-expert scores (no grouped routing, - # no bias). compute_aux_scores=True takes a separate path that - # ignores the grouping knobs. - aux_probs, _aux_routing_map, aux_saved_scores = tex.fused_topk_with_score_function_fwd( - aux_logits_for_score.astype(jnp.float32), - topk=num_experts_per_tok, - use_pre_softmax=False, - num_groups=-1, - group_topk=-1, - scaling_factor=1.0, - score_function=score_function, - expert_bias=jnp.zeros((0,), dtype=jnp.float32), - compute_aux_scores=True, - ) - aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( - aux_probs.astype(jnp.float32), - aux_tokens_per_expert.astype(jnp.int32), - topk=num_experts_per_tok, - coeff=aux_loss_coeff, - ) - else: - aux_loss = jnp.zeros((), dtype=dtype) - aux_const_buf = None - aux_tokens_per_expert = None - aux_logits_for_score = None - aux_saved_scores = None + wi_0 = wi_0.astype(sorted_x.dtype) + wi_1 = wi_1.astype(sorted_x.dtype) + wo = wo.astype(sorted_x.dtype) - # ---------------- Stage 3: dispatch ---------------- - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - sorted_x, dispatch_state = _dispatch( - inputs_2d, - sparse_probs, - routing_map, - backend=permutation_backend, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - ep_axis=ep_axis, - num_ep=num_ep, - recv_buffer_rows=recv_buffer_rows, - shard_id=shard_id, - ) - local_group_sizes = dispatch_state.group_sizes - - # ---------------- Stage 4: per-expert FFN (inlined) ---------------- - q_set_w0, q_set_w1, q_set_wo = quantizer_sets - if q_set_w0 == noop_quantizer_set: - wi_0 = wi_0.astype(sorted_x.dtype) - if q_set_w1 == noop_quantizer_set: - wi_1 = wi_1.astype(sorted_x.dtype) - if q_set_wo == noop_quantizer_set: - wo = wo.astype(sorted_x.dtype) - - # Fused gate+up projection: stack wi_0 / wi_1 on a new axis-(-2) so the - # downstream split is a slice on the (unsharded) stack axis. concat on - # axis=-1 would cross the M axis and force a reshard when M is TP-sharded. - # - # FP8/MXFP8 caveat: per-expert amax is computed over [H, 2, M] rather than - # [H, M] for each of wi_0 / wi_1 separately, so the representable range for - # one half may shift slightly vs. an unfused pair of casts. - inter_M = wi_0.shape[-1] wi_combined = jnp.stack([wi_0, wi_1], axis=-2) wi_combined_bias = ( jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None ) - casted_sorted_x = tex.grouped_quantize(sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1) - casted_wi = tex.grouped_quantize(wi_combined, q_set_w0.kernel, flatten_axis=-1) + + q_set = noop_quantizer_set + casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1) + casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1) combined_out = tex.grouped_gemm( casted_sorted_x.get_tensor(usage=TensorUsage.LHS), casted_wi.get_tensor(usage=TensorUsage.RHS), @@ -1136,20 +190,22 @@ def _body_fwd( # pylint: disable=unused-argument up_proj_out = combined_out[..., 1, :] casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): - casted_sorted_x_lhs_trans = casted_sorted_x_lhs_trans.checkpoint(q_set_w0.x) - if isinstance(casted_wi_rhs_trans, ScaledTensor): - casted_wi_rhs_trans = casted_wi_rhs_trans.checkpoint(q_set_w0.kernel) - # Activation: intermediate = act(gate_proj_out) * up_proj_out act_fn = _convert_to_activation_function(activation_type) intermediate = act_fn(gate_proj_out) * up_proj_out - # GEMM 3: expert_outputs = intermediate @ wo + if apply_topk_weights_early: + # Fold the per-token combine weights into the FFN intermediate; + # the downstream wo GEMM is linear so this is equivalent to the + # late-weighting path, modulo elementwise op fusion gains. + w_b = recv_w_flat[:, None] + mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] + intermediate = intermediate * w_b * mask_b + casted_intermediate = tex.grouped_quantize( - intermediate, q_set_wo.x, local_group_sizes, flatten_axis=-1 + intermediate, q_set.x, local_group_sizes, flatten_axis=-1 ) - casted_wo = tex.grouped_quantize(wo, q_set_wo.kernel, flatten_axis=-1) + casted_wo = tex.grouped_quantize(wo, q_set.kernel, flatten_axis=-1) expert_outputs = tex.grouped_gemm( casted_intermediate.get_tensor(usage=TensorUsage.LHS), casted_wo.get_tensor(usage=TensorUsage.RHS), @@ -1158,234 +214,100 @@ def _body_fwd( # pylint: disable=unused-argument ) casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_intermediate_lhs_trans, ScaledTensor): - casted_intermediate_lhs_trans = casted_intermediate_lhs_trans.checkpoint(q_set_wo.x) - if isinstance(casted_wo_rhs_trans, ScaledTensor): - casted_wo_rhs_trans = casted_wo_rhs_trans.checkpoint(q_set_wo.kernel) - - # ---------------- Stage 5: combine ---------------- - # Compute per-shard static shape info once and pass through both - # _combine and (later) the bwd helpers via kwargs -- never via the - # state dict, which gets pytree-flattened across shard_map and would - # coerce Python ints into JitTracer 0-d arrays. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - ) - # ``expert_outputs_residual`` is the post-A2A FFN-output tensor that - # Step 3 of the combine actually consumed. Saving this (rather than - # the pre-A2A shard-local FFN output) is what makes - # ``_combine_bwd``'s Step-3 inverse see the same value the forward - # Step 3 saw -- otherwise EP + TRITON yields wrong d_expert_outputs. - output, expert_outputs_residual = _combine( - expert_outputs, - dispatch_state, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - # ---------------- Build ctx ---------------- - aux_enabled = aux_loss_coeff > 0.0 - ctx = _BodyCtx( - x=x, - gate_kernel=gate_kernel, - logits_2d=logits_2d, - saved_scores=saved_scores, - routing_map=routing_map, - dispatch=dispatch_state, - casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, - casted_wi_rhs_trans=casted_wi_rhs_trans, - gate_proj_out=gate_proj_out, - up_proj_out=up_proj_out, - casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, - casted_wo_rhs_trans=casted_wo_rhs_trans, - expert_outputs=expert_outputs_residual, - local_group_sizes=local_group_sizes, - expert_bias=expert_bias if expert_bias is not None else None, - aux_const_buf=aux_const_buf if aux_enabled else None, - aux_tokens_per_expert=aux_tokens_per_expert if aux_enabled else None, - aux_logits_for_score=aux_logits_for_score if aux_enabled else None, - aux_saved_scores=aux_saved_scores if aux_enabled else None, + expert_outputs_3d = expert_outputs.reshape(1, expert_outputs.shape[0], expert_outputs.shape[1]) + residuals = ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes, ) - - return output, aux_loss, ctx - - -def _body_bwd( # pylint: disable=unused-argument - ctx: _BodyCtx, - dy_pair: Tuple[jnp.ndarray, jnp.ndarray], + return expert_outputs_3d, residuals + + +def _ffn_bwd_per_shard( + d_expert_outputs_local: jnp.ndarray, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out: jnp.ndarray, + up_proj_out: jnp.ndarray, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, *, - num_experts: int, - num_experts_per_tok: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - # Static side info (kept here rather than inside ctx because they're - # python flags / shapes, not array leaves): - has_wi_bias: bool, - has_wo_bias: bool, - has_expert_bias: bool, - x_shape: Tuple[int, ...], - apply_topk_weights_early: bool = False, -) -> dict: - """Per-shard backward body. Returns a dict of grads keyed identically - to the ``captured`` dict consumed by :func:`_body_fwd`.""" - if apply_topk_weights_early: - raise NotImplementedError( - "apply_topk_weights_early=True is supported only with the TE EP " - "(tex.ep_dispatch / tex.ep_combine) backend." - ) - if not gate_inside_vjp: - raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") - - d_output, d_aux_loss = dy_pair - # The fused FFN bwd quantizes via ``q_set_w0`` only (one quantize for - # the [E, H, 2, M] stacked wi tensor and one for the [T, 2, M] stacked dgrad), - # so ``q_set_w1`` is intentionally unused here. - q_set_w0, _q_set_w1, q_set_wo = quantizer_sets - batch_size, sequence_length, hidden = x_shape - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - - # Recompute per-shard static shape info from existing statics - # (Python ints / int tuples). Plumbed via kwargs to _combine_bwd - # and _dispatch_bwd -- NOT through the ctx dict, because the - # dict gets pytree-flattened across the bwd shard_map's in_specs - # and Python ints would be coerced into JitTracer 0-d arrays - # (breaking ``if padding > 0`` and ``jnp.zeros(shape)`` callsites). - # ``batch_size`` here is the GLOBAL batch size (captured in - # ``x_shape`` by the outer fwd rule), hence ``batch_is_per_shard=False``. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - batch_is_per_shard=False, - ) + apply_topk_weights_early: bool, + has_bias: bool, +): + """Per-shard FFN backward. - # Compute per-shard input shape: under the EP shard_map body, the - # gradient tensors live at per-shard shape, so the dispatch_bwd - # reshape target and ``d_x_from_dispatch.reshape(x_shape)`` below - # must use the per-shard shape rather than the captured global - # ``x_shape``. - if ep_active: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - per_shard_x_shape: Tuple[int, ...] = (per_shard_batch, sequence_length, hidden) - else: - per_shard_x_shape = x_shape - - # ---------------- Combine bwd ---------------- - d_expert_outputs, d_routing_weights = _combine_bwd( - d_output, - ctx.dispatch, - ctx.expert_outputs, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - post_a2a_buffer_shape=_static_shape.post_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) + Mirrors :func:`_ffn_fwd_per_shard`. Returns + ``(d_sorted_x [1, recv_pr, H], d_recv_w [1, recv_pr], d_wi_0, d_wi_1, d_wo, + d_wi_0_bias, d_wi_1_bias, d_wo_bias)``. + """ + d_eo_2d = d_expert_outputs_local.reshape(-1, d_expert_outputs_local.shape[-1]) + recv_w_flat = recv_topk_weights_local.reshape(-1) + q_set = noop_quantizer_set - # ---------------- FFN bwd: GEMM 3 (wo) ---------------- - casted_d_eo = tex.grouped_quantize( - d_expert_outputs, q_set_wo.dgrad, ctx.local_group_sizes, flatten_axis=-1 - ) + # wo bwd + casted_d_eo = tex.grouped_quantize(d_eo_2d, q_set.dgrad, local_group_sizes, flatten_axis=-1) d_intermediate = tex.grouped_gemm( casted_d_eo.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wo_rhs_trans, + casted_wo_rhs_trans, contracting_dims=((1,), (2,)), ) d_wo = tex.grouped_gemm( - ctx.casted_intermediate_lhs_trans, + casted_intermediate_lhs_trans, casted_d_eo.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wo_bias = tex.grouped_dbias(d_expert_outputs, ctx.local_group_sizes) if has_wo_bias else None + d_wo_bias = tex.grouped_dbias(d_eo_2d, local_group_sizes) if has_bias else None - # ---------------- Activation bwd ---------------- - # intermediate = act(gate_proj_out) * up_proj_out - # d(gate_proj_out) = vjp(act, gate_proj_out)(d_intermediate * up_proj_out) - # d(up_proj_out) = d_intermediate * act(gate_proj_out) act_fn = _convert_to_activation_function(activation_type) - act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, ctx.gate_proj_out) + if apply_topk_weights_early: + # intermediate' = intermediate * w * mask. Split the cotangent + # across both factors before the activation bwd consumes it. + w_b = recv_w_flat[:, None] + mask_b = (recv_w_flat != 0).astype(d_intermediate.dtype)[:, None] + intermediate_unweighted = act_fn(gate_proj_out) * up_proj_out + d_recv_w_from_intermediate = jnp.sum( + d_intermediate * intermediate_unweighted * mask_b, axis=-1 + ).astype(recv_w_flat.dtype) + d_intermediate = d_intermediate * w_b * mask_b + else: + d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) + + # Activation bwd + act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, gate_proj_out) d_up_proj_out = d_intermediate * act_gate_proj_out - (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx.up_proj_out) + (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * up_proj_out) - # ---------------- FFN bwd: GEMM 1+2 fused (wi_0 | wi_1) ---------------- - # Mirror of the fwd stack: combine d_gate / d_up on a new axis=-2, - # run one dgrad + one wgrad GEMM, then split on axis=-2. - # d_sorted_x = [d_gate | d_up] @ wi_rhs_trans - # = d_gate @ wi_0^T + d_up @ wi_1^T + # wi bwd (fused gate/up) inter_M = d_gate_proj_out.shape[-1] d_combined = jnp.stack([d_gate_proj_out, d_up_proj_out], axis=-2) casted_d_combined = tex.grouped_quantize( - d_combined, q_set_w0.dgrad, ctx.local_group_sizes, flatten_axis=-1 + d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wi_rhs_trans, + casted_wi_rhs_trans, contracting_dims=((1, 2), (2, 3)), ) d_wi_combined = tex.grouped_gemm( - ctx.casted_sorted_x_lhs_trans, + casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) d_wi_0 = d_wi_combined[..., 0, :] d_wi_1 = d_wi_combined[..., 1, :] - if has_wi_bias: - # grouped_dbias requires rank-2 input; reshape around the call. - # M is not TP-sharded on the bias path, so the reshape is free. + if has_bias: + # tex.grouped_dbias takes a rank-2 input; reshape around the call. d_combined_2d = d_combined.reshape(d_combined.shape[0], -1) - d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, ctx.local_group_sizes) + d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, local_group_sizes) d_wi_combined_bias = d_wi_combined_bias_2d.reshape( *d_wi_combined_bias_2d.shape[:-1], 2, inter_M ) @@ -1395,292 +317,26 @@ def _body_bwd( # pylint: disable=unused-argument d_wi_0_bias = None d_wi_1_bias = None - # ---------------- Dispatch bwd ---------------- - inputs_2d_shape = (per_shard_x_shape[0] * per_shard_x_shape[1], hidden) - d_inputs_2d = _dispatch_bwd( - d_sorted_x, - ctx.dispatch, - inputs_2d_shape=inputs_2d_shape, - backend=permutation_backend, - ep_active=ep_active, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - d_x_from_dispatch = d_inputs_2d.reshape(per_shard_x_shape) - - # ---------------- Routing bwd ---------------- - # The probs cotangent comes from _combine_bwd. For PURE_JAX it's the - # cotangent of routing_weights (post-routing_map_to_selected_experts); - # we need to bridge back to sparse_probs. For TRITON it's already the - # cotangent of merging_probs == sparse_probs. - if d_routing_weights is not None: - if permutation_backend is PermutationBackend.PURE_JAX: - # routing_map_to_selected_experts: - # selected_experts = argsort(routing_map)[..., -topk:] - # weights = take_along_axis(sparse_probs, selected_experts, axis=-1) - # routing_map is bool (non-diff); the gradient of weights - # w.r.t. sparse_probs is a scatter-into-zero along the - # selected_experts indices. - selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -num_experts_per_tok:] - d_sparse_probs = jnp.zeros_like(ctx.saved_scores).astype(d_routing_weights.dtype) - d_sparse_probs = jnp.take_along_axis(d_sparse_probs, selected_experts, axis=-1) - # Actually scatter: build via jnp.zeros + .at[].set - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_routing_weights.dtype) - d_sparse_probs = d_sparse_probs.at[ - jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts - ].set(d_routing_weights) - else: - d_sparse_probs = d_routing_weights.astype(jnp.float32) - else: - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=jnp.float32) - - # Topk bwd primitive: returns d_logits (no d_expert_bias). - d_logits_2d_main = tex.fused_topk_with_score_function_bwd( - ctx.routing_map, - ctx.saved_scores, - d_sparse_probs.astype(ctx.saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - scaling_factor=scaling_factor, - score_function=score_function, - compute_aux_scores=False, - ) - - # ---------------- Aux loss bwd ---------------- - if aux_loss_coeff > 0.0: - # Step 1: aux_loss bwd -> d_aux_probs - aux_num_tokens = ctx.aux_logits_for_score.shape[0] - d_aux_probs = tex.fused_moe_aux_loss_bwd( - ctx.aux_const_buf, - ctx.aux_tokens_per_expert.astype(jnp.int32), - d_aux_loss.reshape(()), - num_tokens=aux_num_tokens, - ) - # Step 2: aux-side topk bwd (compute_aux_scores=True path). - # The routing_map argument is ignored in this branch (the kernel - # uses saved_scores); pass any shape-correct integer tensor. - d_aux_logits = tex.fused_topk_with_score_function_bwd( - jnp.zeros(ctx.aux_logits_for_score.shape, dtype=jnp.bool_), - ctx.aux_saved_scores, - d_aux_probs.astype(ctx.aux_saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=False, - scaling_factor=1.0, - score_function=score_function, - compute_aux_scores=True, - ) - # Inverse of the fwd tiled all_gather along - # ``(*data_parallelism_axes, ep_axis)``: pick out this shard's - # local rows from the global cotangent. JAX's tiled all_gather - # is row-major over the axis-name tuple, so the shard at mesh - # position (i_a, i_b, ...) writes to a contiguous row block - # starting at flat_index * local_T. - if ep_active: - local_T_aux = ctx.logits_2d.shape[0] - flat_shard = 0 - for ax, sz in zip(data_parallelism_axes, fsdp_sizes): - flat_shard = flat_shard * sz + jax.lax.axis_index(ax) - flat_shard = flat_shard * num_ep + shard_id - d_aux_logits_local = jax.lax.dynamic_slice( - d_aux_logits.astype(ctx.logits_2d.dtype), - start_indices=(flat_shard * local_T_aux, 0), - slice_sizes=(local_T_aux, num_experts), - ) - else: - d_aux_logits_local = d_aux_logits.astype(d_logits_2d_main.dtype) - d_logits_2d = d_logits_2d_main + d_aux_logits_local.astype(d_logits_2d_main.dtype) - else: - d_logits_2d = d_logits_2d_main - - # ---------------- Gate bwd ---------------- - d_gate_logits = d_logits_2d.reshape(per_shard_x_shape[0], per_shard_x_shape[1], num_experts) - gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) - d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) - d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) - d_x = d_x_from_gate + d_x_from_dispatch - - # Reduce per-rank partial contributions to match the out_specs - # declared by _build_grads_specs: - # gate_kernel : P() -> psum across (ep, *fsdp) - # wi_0/wi_1/wo : P(ep_axis, ...) -> psum across (*fsdp) only - # inputs : P((ep, fsdp), ...) -> already shard-local, no reduction - if ep_active: - replicate_all = (ep_axis,) + tuple(data_parallelism_axes) - d_gate_kernel = jax.lax.psum(d_gate_kernel, axis_name=replicate_all) - if data_parallelism_axes: - replicate_fsdp = tuple(data_parallelism_axes) - d_wi_0 = jax.lax.psum(d_wi_0, axis_name=replicate_fsdp) - d_wi_1 = jax.lax.psum(d_wi_1, axis_name=replicate_fsdp) - d_wo = jax.lax.psum(d_wo, axis_name=replicate_fsdp) - if has_wi_bias: - d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=replicate_fsdp) - d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=replicate_fsdp) - if has_wo_bias: - d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=replicate_fsdp) - - grads: dict = { - "inputs": d_x, - "gate_kernel": d_gate_kernel, - "wi_0": d_wi_0, - "wi_1": d_wi_1, - "wo": d_wo, - } - if has_wi_bias: - grads["wi_0_bias"] = d_wi_0_bias - grads["wi_1_bias"] = d_wi_1_bias - if has_wo_bias: - grads["wo_bias"] = d_wo_bias - if has_expert_bias: - # expert_bias has no gradient through topk (the topk bwd returns - # None for it). Emit a structural zero so the outer rule has - # something to package. - grads["expert_bias"] = jnp.zeros_like(ctx.expert_bias) - return grads - - -# ============================================================================= -# Spec builders for shard_map (lockstep with ctx_dict / captured_dict) -# ============================================================================= - - -def _build_in_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Build the ``in_specs`` dict for the EP fwd shard_map.""" - specs: dict = { - "inputs": P(batch_pspec_axis, None, None), - "gate_kernel": P(), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if has_bias: - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - specs[name] = P(ep_axis, None) - if has_expert_bias: - specs["expert_bias"] = P(ep_axis) - return specs - - -def _build_dispatch_specs( # pylint: disable=unused-argument - ep_axis: str, - *, - backend: PermutationBackend, - ep_active: bool, - align_size: int, -) -> _DispatchState: - """Build the shard_map ``out_specs`` for the dispatch state. - - Returns a :data:`_DispatchState` (either :class:`_PureJaxDispatchState` - or :class:`_TritonDispatchState`) whose fields are - :class:`PartitionSpec` placeholders. Optional fields are set to - ``P()`` when populated by :func:`_dispatch` and to ``None`` when - intentionally omitted, so the spec's pytree structure mirrors the - value's structure leaf-for-leaf. - """ - ep_all = P() if ep_active else None - ep_local = P() if ep_active else None - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=P(), - sorted_indices=P(), - routing_weights=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=P(), - row_id_map=P(), - pad_offsets=P() if align_size > 0 else None, - merging_probs=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - - -def _build_ctx_specs( # pylint: disable=unused-argument - ep_axis: str, - batch_pspec_axis: Any, - *, - backend: PermutationBackend, - ep_active: bool, - has_bias: bool, - has_expert_bias: bool, - aux_loss_enabled: bool, - align_size: int, -) -> _BodyCtx: - """Build the spec :class:`_BodyCtx` mirroring :func:`_body_fwd`'s ctx. - - Fields gated off by the static config (``expert_bias``, ``aux_*``) - are ``None`` here so the spec pytree matches the value pytree - leaf-for-leaf. - """ - return _BodyCtx( - # Per-shard local activations along the batch axis. - x=P(batch_pspec_axis, None, None), - gate_kernel=P(), - logits_2d=P(batch_pspec_axis, None), - saved_scores=P(batch_pspec_axis, None), - routing_map=P(batch_pspec_axis, None), - dispatch=_build_dispatch_specs( - ep_axis, backend=backend, ep_active=ep_active, align_size=align_size - ), - # FFN residuals: the LHS_TRANS / RHS_TRANS variants of - # grouped_quantize have leading "rows"/"experts" dims that are - # already shard-local (post-dispatch). Use P(ep_axis,...) on - # leading dim; that works whether the leaf is a plain ndarray - # or a ScaledTensor (shard_map applies the spec leaf-wise to - # the registered ScaledTensor pytree). - casted_sorted_x_lhs_trans=P(), - casted_wi_rhs_trans=P(ep_axis, None, None), - gate_proj_out=P(), - up_proj_out=P(), - casted_intermediate_lhs_trans=P(), - casted_wo_rhs_trans=P(ep_axis, None, None), - expert_outputs=P(), - local_group_sizes=P(), - expert_bias=P(ep_axis) if has_expert_bias else None, - aux_const_buf=P() if aux_loss_enabled else None, - aux_tokens_per_expert=P() if aux_loss_enabled else None, - aux_logits_for_score=P() if aux_loss_enabled else None, - aux_saved_scores=P() if aux_loss_enabled else None, - ) - - -def _build_grads_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Spec dict for the grads dict returned by :func:`_body_bwd`.""" - return _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + d_sorted_x_3d = d_sorted_x.reshape(1, d_sorted_x.shape[0], d_sorted_x.shape[1]) + d_recv_w_3d = d_recv_w_from_intermediate.reshape(1, -1) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, ) # ============================================================================= -# Top-level VJP rules +# Full fwd / bwd rules (custom_vjp halves) # ============================================================================= -def _moe_fwd_rule( # pylint: disable=unused-argument - # Args MUST match the positional order of ``_moe`` (diff first, - # then nondiff). See ``_moe_bwd_rule`` for the opposite convention. +def _moe_fwd_rule( x, gate_kernel, wi_0, @@ -1699,109 +355,71 @@ def _moe_fwd_rule( # pylint: disable=unused-argument group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - ep_active = ep_axis is not None - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - "apply_topk_weights_early": apply_topk_weights_early, - } - captured: dict = { - "inputs": x, - "gate_kernel": gate_kernel, - "wi_0": wi_0, - "wi_1": wi_1, - "wo": wo, - } - has_bias = wi_0_bias is not None - has_expert_bias = expert_bias is not None - if has_bias: - captured["wi_0_bias"] = wi_0_bias - captured["wi_1_bias"] = wi_1_bias - captured["wo_bias"] = wo_bias - if has_expert_bias: - captured["expert_bias"] = expert_bias - - if not ep_active: - output, aux_loss, ctx = _body_fwd( - captured, - **body_kwargs, - ep_active=False, - fsdp_sizes=(), - num_ep=1, - num_experts_local=num_experts, - recv_buffer_rows=0, - ) - # Carry static side info to the bwd rule alongside ctx. These - # are Python ints/bools/tuples (NOT pytree leaves), so we - # bundle them as a plain dict rather than putting them on the - # ``_BodyCtx`` NamedTuple where shard_map would try to flatten - # them into JitTracers. - static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x.shape, - "num_experts_local": num_experts, - "recv_buffer_rows": 0, - } - return (output, aux_loss), (ctx, static) - - # ---------------- EP path ---------------- + """Forward: gate -> topk -> ep_dispatch -> shard_map(FFN) -> ep_combine. + + Returns ``(output, aux_loss)``. ``aux_loss`` is a zero scalar when + ``aux_loss_coeff == 0``. + """ + del gate_kernel_axes, wi_kernel_axes, wo_kernel_axes # used in bwd only from jax.experimental.shard_map import shard_map + x = with_sharding_constraint_by_logical_axes(x, input_axes) + mesh = _get_mesh() if mesh is None or mesh.empty: - raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + if ep_axis is None: + raise ValueError("moe(...) requires ep_axis to be set (TE EP backend).") num_ep = mesh.shape[ep_axis] if num_experts % num_ep != 0: raise ValueError(f"num_experts={num_experts} must be divisible by EP size={num_ep}") - num_experts_local = num_experts // num_ep + num_local_experts = num_experts // num_ep - # Reject overlapping EP / FSDP axes. Listing ep_axis in - # data_parallelism_axes would produce a duplicate-axis PartitionSpec - # ((ep, ep, ...)) which JAX rejects, and would also double-count - # num_ep in dp_size (under-sizing recv_buffer_rows by a factor of - # num_ep). Catch it up front with a clear error. + dp_size = 1 for ax in data_parallelism_axes: - if ax not in mesh.shape: - raise ValueError( - f"data_parallelism_axes contains {ax!r} but mesh has" - f" axes {tuple(mesh.shape.keys())}" - ) - if ax == ep_axis: - raise ValueError( - f"data_parallelism_axes={data_parallelism_axes!r} contains the EP" - f" axis {ep_axis!r}; EP is implicit in the batch sharding and must" - " not also be listed as a data-parallel axis." - ) + dp_size *= mesh.shape[ax] + num_procs = num_ep * dp_size + + B, S, H = x.shape + K = num_experts_per_tok + if B % num_procs != 0: + raise ValueError(f"batch={B} not divisible by ep*dp={num_procs}") + + # Per-rank receive capacity (dropless): every rank may receive all of one + # replica's K-expanded tokens. ``slots_per_expert`` is rounded up to a + # multiple of ``align_size`` (FP8 recipes typically need 128 here); the + # rounded value is what we feed to ``tex.ep_prepare`` as the + # ``dispatch_output_per_expert_alignment`` so each local expert's slot + # block starts on the alignment boundary that grouped_gemm expects. + natural_recv_pr = (B // dp_size) * S * K + natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts + if align_size > 0: + slots_per_expert = ((natural_spe + align_size - 1) // align_size) * align_size + else: + slots_per_expert = natural_spe + recv_pr = num_local_experts * slots_per_expert + # Per-rank input token count: B/num_procs rows x S tokens. The bootstrap + # uses this to size the dispatch send buffer; recv_pr above sizes the + # per-rank receive buffer. + max_tokens_per_rank = (B // num_procs) * S + + _te_ep_bootstrap_if_needed( + num_experts=num_experts, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + ep_size=num_ep, + ) if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis @@ -1810,64 +428,211 @@ def _moe_fwd_rule( # pylint: disable=unused-argument # consecutive global ranks (dp_color = rank // ep_size), so the # comm only stays within one model replica under (outer_dp, ep). batch_pspec_axis = (*data_parallelism_axes, ep_axis) - dp_size = 1 - for ax in data_parallelism_axes: - dp_size *= mesh.shape[ax] + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) - global_batch_size, sequence_length, _hidden = x.shape - topk = num_experts_per_tok - if global_batch_size % (num_ep * dp_size) != 0: - raise ValueError(f"batch={global_batch_size} not divisible by ep*dp={num_ep * dp_size}") - recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk - if align_size > 0: - recv_buffer_rows += num_experts * (align_size - 1) + # ---------------- Gate (global view) ---------------- + gate_kernel_cast = gate_kernel.astype(x.dtype) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + logits_2d = gate_logits.reshape(-1, num_experts) - in_specs = _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + # ---------------- Routing (global view) ---------------- + # expert_bias is an empty (shape-(0,)) sentinel when the caller did + # not enable it; the primitive treats that as "no bias". + eb_arg = expert_bias if expert_bias.shape != (0,) else jnp.zeros((0,), dtype=jnp.float32) + sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( + logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, ) - output_spec = P(batch_pspec_axis, None, None) - aux_spec = P() - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + sparse_probs = sparse_probs.astype(dtype) + + # ---------------- Aux loss (global view, replicated) ---------------- + # ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across + # all tokens, which is wrong when T is sharded. Force-replicate the + # gate logits and recompute the routing map at global view so the + # kernel sees a complete [T_global, E] tensor. The replication is a + # single all-gather over (*dp, ep) and lives off the dispatch + # critical path. + if aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.with_sharding_constraint( + logits_2d, NamedSharding(mesh, P()) + ) + _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( + global_logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + # compute_aux_scores=True takes a separate kernel path: clean + # per-expert softmax, no grouping / bias / scaling. + aux_probs, _aux_rm, aux_saved_scores = tex.fused_topk_with_score_function_fwd( + global_logits_2d.astype(jnp.float32), + topk=K, + use_pre_softmax=False, + num_groups=-1, + group_topk=-1, + scaling_factor=1.0, + score_function=score_function, + expert_bias=jnp.zeros((0,), dtype=jnp.float32), + compute_aux_scores=True, + ) + aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( + aux_probs.astype(jnp.float32), + aux_tokens_per_expert.astype(jnp.int32), + topk=K, + coeff=aux_loss_coeff, + ) + aux_loss = aux_loss.astype(dtype) + else: + aux_loss = jnp.zeros((), dtype=dtype) + aux_const_buf = None + aux_tokens_per_expert = None + aux_saved_scores = None + + # ---------------- Routing -> (topk_idx, topk_w) at 3D ---------------- + # argsort on a bool tensor places True last (False=0 < True=1), so the + # last K indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -K:] + routing_weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + topk_idx_3d = selected_experts.reshape(B, S, K).astype(jnp.int32) + topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) + + # ---------------- TE EP dispatch (global view) ---------------- + token_counts, handle = tex.ep_prepare(topk_idx_3d, slots_per_expert) + recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( + handle, topk_idx_3d, x, topk_w_3d, recv_pr + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3_spec)) + recv_topk_weights = jax.lax.with_sharding_constraint( + recv_topk_weights, NamedSharding(mesh, ep2_spec) + ) + + # ---------------- FFN (per-shard via shard_map) ---------------- + has_bias = wi_0_bias is not None + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + ffn_in_specs = (ep3_spec, ep2_spec, kernel_spec, kernel_spec, kernel_spec) + ffn_in_args = [recv_tokens, recv_topk_weights, wi_0, wi_1, wo] + if has_bias: + ffn_in_specs = ffn_in_specs + (bias_spec, bias_spec, bias_spec) + ffn_in_args.extend([wi_0_bias, wi_1_bias, wo_bias]) + + # FFN residuals live entirely on the local ep rank, so the leading + # "experts" / "rows" dims map to P() (already shard-local). + residuals_spec = ( + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + P(), # local_group_sizes ) + out_specs = (ep3_spec, residuals_spec) - _fsdp_sizes: Tuple[int, ...] = tuple(mesh.shape[ax] for ax in data_parallelism_axes) - - def _shardmap_body(captured_local): - return _body_fwd( - captured_local, - **body_kwargs, - ep_active=True, - fsdp_sizes=_fsdp_sizes, - num_ep=num_ep, - num_experts_local=num_experts_local, - recv_buffer_rows=recv_buffer_rows, + def _body(*args): + if has_bias: + (r_tok, r_w, w0, w1, w_o, w0b, w1b, wob) = args + else: + (r_tok, r_w, w0, w1, w_o) = args + w0b = w1b = wob = None + return _ffn_fwd_per_shard( + r_tok, + r_w, + w0, + w1, + w_o, + w0b, + w1b, + wob, + num_local_experts=num_local_experts, + slots_per_expert=slots_per_expert, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, ) - output, aux_loss, ctx = shard_map( - _shardmap_body, + expert_outputs, ffn_residuals = shard_map( + _body, mesh=mesh, - in_specs=(in_specs,), - out_specs=(output_spec, aux_spec, ctx_spec), + in_specs=ffn_in_specs, + out_specs=out_specs, check_rep=False, - )(captured) + )(*ffn_in_args) + expert_outputs = jax.lax.with_sharding_constraint( + expert_outputs, NamedSharding(mesh, ep3_spec) + ) + + # ---------------- TE EP combine (global view) ---------------- + out_partition_spec = (batch_pspec_axis, None, None) + if apply_topk_weights_early: + # expert_outputs is already weighted upstream. + output = tex.ep_combine_fwd( + handle, + expert_outputs, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + else: + w = recv_topk_weights[..., None] + mask = (recv_topk_weights != 0).astype(expert_outputs.dtype)[..., None] + weighted = expert_outputs * w * mask + output = tex.ep_combine_fwd( + handle, + weighted, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + + ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes, + ) = ffn_residuals + + ctx = _Ctx( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + handle=handle, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, + ) static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, + "has_bias": has_bias, "x_shape": x.shape, - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, + "recv_pr": recv_pr, } return (output, aux_loss), (ctx, static) @@ -1882,133 +647,260 @@ def _moe_bwd_rule( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, - ctx, - dy_pair, + align_size, + residuals, + cotangents, ): - ctx, static = ctx # split tensor residuals from static side info - has_wi_bias = static["has_wi_bias"] - has_wo_bias = static["has_wo_bias"] - has_expert_bias = static["has_expert_bias"] - x_shape = static["x_shape"] - num_experts_local = static["num_experts_local"] - recv_buffer_rows = static["recv_buffer_rows"] + """Backward mirror of :func:`_moe_fwd_rule`.""" + del num_groups, group_topk, dtype, align_size # captured in residuals / unused in bwd + from jax.experimental.shard_map import shard_map - ep_active = ep_axis is not None - mesh = _get_mesh() if ep_active else None - fsdp_sizes: Tuple[int, ...] = ( - tuple(mesh.shape[ax] for ax in data_parallelism_axes) if ep_active else () - ) - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - "fsdp_sizes": fsdp_sizes, - "num_ep": 1 if not ep_active else mesh.shape[ep_axis], - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, - "has_wi_bias": has_wi_bias, - "has_wo_bias": has_wo_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x_shape, - "apply_topk_weights_early": apply_topk_weights_early, - } + d_output, d_aux_loss = cotangents - if not ep_active: - grads = _body_bwd(ctx, dy_pair, ep_active=False, **body_kwargs) - # Apply sharding constraints on grads. - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes - ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + ctx, static = residuals + has_bias = static["has_bias"] + x_shape = static["x_shape"] + recv_pr = static["recv_pr"] - from jax.experimental.shard_map import shard_map + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + num_ep = mesh.shape[ep_axis] + dp_size = 1 + for ax in data_parallelism_axes: + dp_size *= mesh.shape[ax] + B, S, _ = x_shape + K = num_experts_per_tok if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - # ep must be innermost: ep_bootstrap forms NCCL EP comms from - # consecutive global ranks (dp_color = rank // ep_size), so the - # comm only stays within one model replica under (outer_dp, ep). batch_pspec_axis = (*data_parallelism_axes, ep_axis) - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_wi_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + out_partition_spec = (batch_pspec_axis, None, None) + + # ---------------- Combine bwd (global view) ---------------- + d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec)) + grad_pre_combine = tex.ep_combine_bwd(ctx.handle, d_output, recv_pr) + grad_pre_combine = jax.lax.with_sharding_constraint( + grad_pre_combine, NamedSharding(mesh, ep3_spec) ) - dy_specs = (P(batch_pspec_axis, None, None), P()) - grads_spec = _build_grads_specs( - ep_axis, batch_pspec_axis, has_bias=has_wi_bias, has_expert_bias=has_expert_bias + + if apply_topk_weights_early: + # combine_fwd consumed already-weighted expert_outputs; the recv_w + # cotangent flows through the early-weighting step inside the FFN bwd. + d_expert_outputs = grad_pre_combine + d_recv_w_from_combine = jnp.zeros_like(ctx.recv_topk_weights) + else: + # combine_fwd consumed weighted = expert_out * w * mask; + # split the cotangent across both factors. + w = ctx.recv_topk_weights[..., None] + mask = (ctx.recv_topk_weights != 0).astype(grad_pre_combine.dtype)[..., None] + d_expert_outputs = grad_pre_combine * w * mask + d_recv_w_from_combine = (grad_pre_combine * ctx.expert_outputs * mask).sum(axis=-1) + d_recv_w_from_combine = d_recv_w_from_combine.astype(ctx.recv_topk_weights.dtype) + + # ---------------- FFN bwd (per-shard via shard_map) ---------------- + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + + bwd_in_specs = ( + ep3_spec, # d_expert_outputs + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + P(), # local_group_sizes + ep2_spec, # recv_topk_weights + ) + bwd_in_args = [ + d_expert_outputs, + ctx.casted_sorted_x_lhs_trans, + ctx.casted_wi_rhs_trans, + ctx.gate_proj_out, + ctx.up_proj_out, + ctx.casted_intermediate_lhs_trans, + ctx.casted_wo_rhs_trans, + ctx.local_group_sizes, + ctx.recv_topk_weights, + ] + bwd_out_specs = ( + ep3_spec, # d_sorted_x + ep2_spec, # d_recv_w_from_intermediate + kernel_spec, # d_wi_0 + kernel_spec, # d_wi_1 + kernel_spec, # d_wo + bias_spec if has_bias else None, # d_wi_0_bias + bias_spec if has_bias else None, # d_wi_1_bias + bias_spec if has_bias else None, # d_wo_bias ) - def _bwd_body(ctx_local, dy_local): - return _body_bwd(ctx_local, dy_local, ep_active=True, **body_kwargs) + def _bwd_body(*args): + ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = _ffn_bwd_per_shard( + *args, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, + has_bias=has_bias, + ) + # Weight grads accumulate per-DP-shard inside the body; psum across + # DP axes so each replica sees the full sum (matches out_specs + # P(ep_axis, ...) which is DP-replicated). + if data_parallelism_axes: + dp = tuple(data_parallelism_axes) + d_wi_0 = jax.lax.psum(d_wi_0, axis_name=dp) + d_wi_1 = jax.lax.psum(d_wi_1, axis_name=dp) + d_wo = jax.lax.psum(d_wo, axis_name=dp) + if has_bias: + d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=dp) + d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=dp) + d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=dp) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) - grads = shard_map( + ( + d_sorted_x, + d_recv_w_from_intermediate, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = shard_map( _bwd_body, mesh=mesh, - in_specs=(ctx_spec, dy_specs), - out_specs=grads_spec, + in_specs=bwd_in_specs, + out_specs=bwd_out_specs, check_rep=False, - )(ctx, dy_pair) + )(*bwd_in_args) + + d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes + # ---------------- Dispatch bwd (global view) ---------------- + d_sorted_x = jax.lax.with_sharding_constraint(d_sorted_x, NamedSharding(mesh, ep3_spec)) + d_recv_w_total = jax.lax.with_sharding_constraint( + d_recv_w_total, NamedSharding(mesh, ep2_spec) + ) + d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( + ctx.handle, + d_sorted_x, + d_recv_w_total, + top_k=K, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + # ---------------- Routing bwd (global view) ---------------- + # The cotangent on routing_weights is a sparse scatter into sparse_probs + # at the selected_experts indices. + selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -K:] + d_topk_w_flat = d_topk_w.reshape(-1, K) + d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_topk_w_flat.dtype) + d_sparse_probs = d_sparse_probs.at[ + jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts + ].set(d_topk_w_flat) + + d_logits_2d = tex.fused_topk_with_score_function_bwd( + ctx.routing_map, + ctx.saved_scores, + d_sparse_probs.astype(ctx.saved_scores.dtype), + topk=K, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=False, + ) + + # ---------------- Aux loss bwd (global view, replicated) ---------------- + # Reverse the fwd's all-gather/aux pipeline: aux_loss_bwd produces + # d_aux_probs, then topk_bwd(compute_aux_scores=True) produces the + # extra d_logits contribution. The replicated tensor adds into the + # T-sharded routing-side d_logits via JAX's normal broadcast. + if aux_loss_coeff > 0.0: + T_global = ctx.logits_2d.shape[0] + d_aux_loss_scalar = d_aux_loss.reshape(()).astype(jnp.float32) + d_aux_probs = tex.fused_moe_aux_loss_bwd( + ctx.aux_const_buf, + ctx.aux_tokens_per_expert.astype(jnp.int32), + d_aux_loss_scalar, + num_tokens=int(T_global), + ) + # routing_map is ignored by the kernel when compute_aux_scores=True, + # so pass a zero placeholder of the right shape/dtype. + zero_routing_map = jnp.zeros( + ctx.aux_saved_scores.shape, dtype=ctx.routing_map.dtype + ) + d_logits_aux = tex.fused_topk_with_score_function_bwd( + zero_routing_map, + ctx.aux_saved_scores, + d_aux_probs.astype(ctx.aux_saved_scores.dtype), + topk=K, + use_pre_softmax=False, + scaling_factor=1.0, + score_function=score_function, + compute_aux_scores=True, + ) + d_logits_2d = d_logits_2d + d_logits_aux.astype(d_logits_2d.dtype) + + # ---------------- Gate bwd (global view) ---------------- + d_gate_logits = d_logits_2d.reshape(B, S, num_experts) + gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) + d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) + d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) + d_x = d_x_from_gate + d_x_from_dispatch + + # Pin output grads to the declared logical axes so downstream + # optimizers see consistent shardings. + d_x = with_sharding_constraint_by_logical_axes(d_x, input_axes) + d_gate_kernel = with_sharding_constraint_by_logical_axes(d_gate_kernel, gate_kernel_axes) + d_wi_0 = with_sharding_constraint_by_logical_axes(d_wi_0, wi_kernel_axes) + d_wi_1 = with_sharding_constraint_by_logical_axes(d_wi_1, wi_kernel_axes) + d_wo = with_sharding_constraint_by_logical_axes(d_wo, wo_kernel_axes) + + # expert_bias has no learnable bwd path through fused_topk: the + # primitive's bwd returns None for the bias slot. Match that with a + # zero cotangent of the right shape so custom_vjp's arity check + # passes. + d_expert_bias = jnp.zeros_like(ctx.expert_bias) -def _grads_dict_to_tuple( - grads: dict, has_wi_bias: bool, has_wo_bias: bool, has_expert_bias: bool -) -> Tuple: - """Pack the body_bwd's grads dict into the positional tuple JAX expects.""" return ( - grads["inputs"], - grads["gate_kernel"], - grads["wi_0"], - grads["wi_1"], - grads["wo"], - grads.get("wi_0_bias") if has_wi_bias else None, - grads.get("wi_1_bias") if has_wi_bias else None, - grads.get("wo_bias") if has_wo_bias else None, - grads.get("expert_bias") if has_expert_bias else None, + d_x, + d_gate_kernel, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias if has_bias else None, + d_wi_1_bias if has_bias else None, + d_wo_bias if has_bias else None, + d_expert_bias, ) @@ -2017,7 +909,7 @@ def _grads_dict_to_tuple( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 30))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 27))) def _moe( x, gate_kernel, @@ -2037,24 +929,17 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ): - # Call in `_moe`'s own signature order to match what JAX will pass - # the fwd rule via ``_argnums_partial``. See the comment block at - # the top of ``_moe_fwd_rule`` for why this differs from - # ``_moe_bwd_rule``'s convention. - output_pair, _ = _moe_fwd_rule( + primal, _ = _moe_fwd_rule( x, gate_kernel, wi_0, @@ -2073,20 +958,17 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ) - return output_pair + return primal _moe.defvjp(_moe_fwd_rule, _moe_bwd_rule) @@ -2103,84 +985,87 @@ def moe( wo_bias: Optional[jnp.ndarray] = None, expert_bias: Optional[jnp.ndarray] = None, *, - # Architecture num_experts: int, num_experts_per_tok: int, activation_type: str = "silu", - # Routing score_function: Union[str, ScoreFunction] = "softmax", use_pre_softmax: bool = False, num_groups: Optional[int] = None, group_topk: Optional[int] = None, scaling_factor: float = 1.0, aux_loss_coeff: float = 0.0, - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, - align_size: int = 0, - # Gate placement - gate_inside_vjp: bool = True, - # When True, fold per-token top-k weights into the FFN intermediate - # (next to act(gate)*up) instead of into the post-down-projection - # combine. Both placements are mathematically equivalent (down-proj is - # linear); the early placement gives XLA a chance to fuse the multiply - # with the activation. Off by default. apply_topk_weights_early: bool = False, - # Parallelism (resolved by caller from MeshResource) - ep_axis: Optional[str] = None, + align_size: int = 0, + ep_axis: str, data_parallelism_axes: Tuple[str, ...] = (), - # Logical axes for sharding constraints input_axes: Tuple[Optional[str], ...] = (), gate_kernel_axes: Tuple[Optional[str], ...] = (), wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp"), wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed"), - # Quantization - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet] = ( - noop_quantizer_set, - noop_quantizer_set, - noop_quantizer_set, - ), dtype: jnp.dtype = jnp.float32, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Run a full MoE block under a single fused custom_vjp. + """Run a full MoE block under a single fused custom_vjp on the TE EP path. + + Returns ``(output, aux_loss)``. ``aux_loss`` is ``None`` when + ``aux_loss_coeff == 0`` and a 0-d scalar otherwise. - Parameters and return are documented at the call site of - ``_MoEBlock.__call__``. See module docstring for design rationale. + Parameters + ---------- + expert_bias : Optional[jnp.ndarray] + ``[num_experts]`` learnable router bias added before the top-k + when ``score_function='sigmoid'``. Pass ``None`` to disable. + The bias has no gradient through the top-k primitive itself (it + only steers expert selection); a zero cotangent is returned for + it. + aux_loss_coeff : float + Per-step expert-load-balance loss coefficient. ``0.0`` (default) + disables the aux loss entirely. When non-zero, an extra + all-gather over the routing-side logits is inserted so the + ``fused_moe_aux_loss`` kernel sees a global ``[T_global, E]`` + view; this lives off the dispatch critical path. + align_size : int + Minimum per-expert slot alignment passed to ``tex.ep_prepare`` + as ``dispatch_output_per_expert_alignment``. ``0`` (default) + means use the natural slot count + ``ceil((B/dp)*S*K / num_local_experts)``. Any positive value + rounds that count up to the nearest multiple, growing the + per-rank receive buffer accordingly. Set to ``128`` for FP8 + recipes that require 128-aligned grouped-GEMM tiles. + + See module docstring for the rest of the parameter semantics and the + surrounding design rationale. """ - if not isinstance(permutation_backend, PermutationBackend): - raise TypeError( - f"permutation_backend must be a PermutationBackend, got {permutation_backend!r}" - ) - if permutation_backend is PermutationBackend.TRITON: - _require_triton() - # Normalize string score_function ("softmax" / "sigmoid") to the - # ScoreFunction enum once here. The underlying primitive - # ``tex.fused_topk_with_score_function_fwd`` expects an int-coercible - # value (the enum has integer .value), and the public router wrapper - # we bypass also normalizes here. score_function = _validate_score_function(score_function) # Enforce ((outer_dp..., ep), None, None) on inbound activations. The # EP comm groups consecutive global ranks (dp_color = rank // ep_size), # so ep MUST be innermost in the partition spec. Soft re-pin: free if # upstream already matches, single reshard otherwise. - if ep_axis is not None: - mesh = _get_mesh() - if mesh is None or mesh.empty: - raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") - expected_leading: Any = ( - (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + expected_leading: Any = ( + (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + ) + expected_spec = P(expected_leading, None, None) + actual_spec = getattr(getattr(x, "sharding", None), "spec", None) + if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): + warnings.warn( + f"moe(...): inbound x sharding {actual_spec} does not match expected " + f"{expected_spec}; inserting a reshard. Apply " + "jax.lax.with_sharding_constraint upstream to avoid this overhead.", + UserWarning, + stacklevel=2, ) - expected_spec = P(expected_leading, None, None) - actual_spec = getattr(getattr(x, "sharding", None), "spec", None) - if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): - warnings.warn( - f"moe(...): inbound x sharding {actual_spec} does not match expected " - f"{expected_spec}; inserting a reshard. Apply " - "jax.lax.with_sharding_constraint upstream to avoid this overhead.", - UserWarning, - stacklevel=2, - ) - x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + + # custom_vjp can't trace through None args; lower expert_bias to an + # empty shape-(0,) tensor that fused_topk_with_score_function treats + # as "no bias". + if expert_bias is None: + expert_bias_arg = jnp.zeros((0,), dtype=jnp.float32) + else: + expert_bias_arg = expert_bias output, aux_loss = _moe( x, @@ -2191,28 +1076,25 @@ def moe( wi_0_bias, wi_1_bias, wo_bias, - expert_bias, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - activation_type=activation_type, - score_function=score_function, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups, - group_topk=group_topk, - scaling_factor=scaling_factor, - aux_loss_coeff=aux_loss_coeff, - permutation_backend=permutation_backend, - align_size=align_size, - gate_inside_vjp=gate_inside_vjp, - ep_axis=ep_axis, - data_parallelism_axes=data_parallelism_axes, - input_axes=input_axes, - gate_kernel_axes=gate_kernel_axes, - wi_kernel_axes=wi_kernel_axes, - wo_kernel_axes=wo_kernel_axes, - quantizer_sets=quantizer_sets, - dtype=dtype, - apply_topk_weights_early=apply_topk_weights_early, + expert_bias_arg, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + float(aux_loss_coeff), + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + dtype, + apply_topk_weights_early, + align_size, ) if aux_loss_coeff <= 0.0: aux_loss = None From 776c5effd8238b029064ba83577b931fa2b5c8dd Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 15:57:37 -0700 Subject: [PATCH 18/29] [JAX] MoE: bootstrap TE EP eagerly outside jit; assert compatibility per-call ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under jax.jit and fails with TracerArrayConversionError. Move the bootstrap to the test fixture (matching the test_multi_process_ep.py pattern from the TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule now only asserts that the recorded bootstrap signature is wide enough (num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank and recv_capacity_per_rank <= bootstrap values). Test mesh fixture bootstraps with the worst-case recv_pr across _CONFIGS so every parametrized config is compatible with a single per-process bootstrap. --- tests/jax/test_te_ep_moe.py | 51 ++++++++++++++++++++- transformer_engine/jax/moe.py | 84 +++++++++++++++++++++++------------ 2 files changed, 104 insertions(+), 31 deletions(-) diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index cc878e0bd1..9373873bec 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -130,7 +130,8 @@ def _read_mp_options(): ) from transformer_engine.jax.flax import _MoEBlock as MoEBlock -from transformer_engine.jax.moe import moe +from transformer_engine.jax.moe import moe, record_ep_bootstrap_signature_for_moe +from transformer_engine.jax.ep import ep_bootstrap from transformer_engine.jax.sharding import MeshResource, global_shard_guard @@ -191,6 +192,23 @@ def _read_mp_options(): # ----------------------------------------------------------------------------- +def _compute_worst_case_recv_pr(): + """Worst-case per-rank recv buffer across every config in _CONFIGS. + + Bootstrap reserves NCCL EP buffers; per-call recv_pr <= bootstrap + recv_pr is fine. We size with the largest align_size in _CONFIGS so + the align128 config still fits the same singleton bootstrap. + """ + num_procs = jax.device_count() + dp_size = num_procs // EP_SIZE + num_local_experts = NUM_EXPERTS // EP_SIZE + natural_recv_pr = (BATCH // dp_size) * SEQ * TOPK + natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts + worst_align = 128 + worst_spe = ((natural_spe + worst_align - 1) // worst_align) * worst_align + return num_local_experts * worst_spe + + @pytest.fixture(scope="module") def mesh(): if jax.device_count() < NUM_DEVICES_REQUIRED: @@ -202,7 +220,36 @@ def mesh(): # from consecutive global ranks via ``dp_color = rank // ep_size``, so # only an (outer_fsdp, inner_ep) device layout groups ranks correctly. devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE)) - return Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + mesh_obj = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + + num_procs = jax.process_count() + max_tokens_per_rank = (BATCH // num_procs) * SEQ + recv_capacity_per_rank = _compute_worst_case_recv_pr() + + # Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather + # and cannot run from inside jax.jit. Sized to the worst-case recv_pr + # across _CONFIGS so every parametrized config is bootstrap-compatible. + with mesh_obj, global_shard_guard( + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + ): + ep_bootstrap( + world_size=num_procs, + rank=jax.process_index(), + ep_size=EP_SIZE, + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + allow_handle_mem_reloc=True, + ) + record_ep_bootstrap_signature_for_moe( + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + ep_size=EP_SIZE, + ) + return mesh_obj # ----------------------------------------------------------------------------- diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 162ea8f7e5..da6f0ac1e4 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -58,49 +58,75 @@ # ============================================================================= -# Process-level NCCL EP bootstrap +# Process-level NCCL EP bootstrap (must run eagerly, outside jax.jit) # ============================================================================= # -# ``tex.ep_bootstrap`` initialises the NCCL EP communicator exactly once per -# process and stashes its state in a C++ singleton. Subsequent calls with the -# same signature are a no-op; calls with a different signature raise. +# ``tex.ep_bootstrap`` does a NCCL UID allgather over the JAX runtime, which +# cannot run from inside a jit-traced function. The caller must bootstrap +# eagerly once per process before any jitted MoE call, then record the +# bootstrap signature via ``record_ep_bootstrap_signature_for_moe``. The +# per-call check below verifies the recorded signature is wide enough for +# the current MoE invocation (smaller per-call usage is fine since the C++ +# backend reserves worst-case buffers at bootstrap time). _te_ep_bootstrap_signature: Optional[Tuple[int, int, int, int, int]] = None -def _te_ep_bootstrap_if_needed( +def record_ep_bootstrap_signature_for_moe( num_experts: int, max_tokens_per_rank: int, recv_capacity_per_rank: int, hidden_dim: int, ep_size: int, ) -> None: - """Bootstrap the NCCL EP communicator on first use within a process.""" + """Record the params passed to ``ep_bootstrap`` so the per-call check + in ``_moe_fwd_rule`` can verify compatibility. Call this once per + process immediately after ``ep_bootstrap``. + """ global _te_ep_bootstrap_signature - sig = (num_experts, max_tokens_per_rank, recv_capacity_per_rank, hidden_dim, ep_size) - if _te_ep_bootstrap_signature == sig: - return - if _te_ep_bootstrap_signature is not None: + _te_ep_bootstrap_signature = ( + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + ep_size, + ) + + +def _te_ep_assert_compatible_bootstrap( + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Verify a prior eager ``ep_bootstrap`` is wide enough for this call.""" + if _te_ep_bootstrap_signature is None: + raise RuntimeError( + "TE EP was not bootstrapped. Call" + " transformer_engine.jax.ep.ep_bootstrap(...) eagerly (outside" + " any jax.jit) once per process, then" + " transformer_engine.jax.moe.record_ep_bootstrap_signature_for_moe(...)" + " with the same params, before invoking moe()." + ) + b_num_experts, b_max_tpr, b_recv_pr, b_hidden, b_ep_size = _te_ep_bootstrap_signature + if ( + num_experts != b_num_experts + or hidden_dim != b_hidden + or ep_size != b_ep_size + or max_tokens_per_rank > b_max_tpr + or recv_capacity_per_rank > b_recv_pr + ): raise ValueError( - "TE EP was already bootstrapped with signature " - f"{_te_ep_bootstrap_signature}; got {sig}. Re-bootstrap with" - " different params is not supported within a single process." + "TE EP was already bootstrapped with signature" + f" (num_experts={b_num_experts}, max_tokens_per_rank={b_max_tpr}," + f" recv_capacity_per_rank={b_recv_pr}, hidden_dim={b_hidden}," + f" ep_size={b_ep_size}); this moe() call needs" + f" (num_experts={num_experts}, max_tokens_per_rank={max_tokens_per_rank}," + f" recv_capacity_per_rank={recv_capacity_per_rank}, hidden_dim={hidden_dim}," + f" ep_size={ep_size}). Re-bootstrap with wider params (or matching exact" + " sizes) is required." ) - from transformer_engine.jax.ep import ep_bootstrap # local: avoids import cycle - - ep_bootstrap( - world_size=jax.process_count(), - rank=jax.process_index(), - ep_size=ep_size, - num_experts=num_experts, - max_tokens_per_rank=max_tokens_per_rank, - recv_capacity_per_rank=recv_capacity_per_rank, - hidden_dim=hidden_dim, - # XLA may relocate the C++ handle buffer between JIT executables; - # allow it rather than asserting on handle aliasing. - allow_handle_mem_reloc=True, - ) - _te_ep_bootstrap_signature = sig # ============================================================================= @@ -413,7 +439,7 @@ def _moe_fwd_rule( # per-rank receive buffer. max_tokens_per_rank = (B // num_procs) * S - _te_ep_bootstrap_if_needed( + _te_ep_assert_compatible_bootstrap( num_experts=num_experts, max_tokens_per_rank=max_tokens_per_rank, recv_capacity_per_rank=recv_pr, From acb610ff1028f85614a803efc675e2135088d29e Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:15:10 -0700 Subject: [PATCH 19/29] [JAX] MoE: thread EpHandle + handle_mem through dispatch / combine The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in e927903c) expects an EpHandle object plus a separate handle_mem buffer for every dispatch/combine call. The MoE wrapper was still passing the raw slots_per_expert int as the second positional and unpacking ep_dispatch_fwd as a 3-tuple, which now blows up with "AttributeError: 'int' object has no attribute 'handle_id'". Changes: - Cache one EpHandle per (top_k, alignment) at module scope so repeated jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool. - _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle) -> (token_counts, handle_mem), and pass (handle, handle_mem) into the fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple. - _Ctx: stash handle_mem alongside handle so the bwd can hand both back to ep_combine_bwd and ep_dispatch_bwd. - _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine calls. --- transformer_engine/jax/moe.py | 37 ++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index da6f0ac1e4..067be5a60b 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -37,7 +37,7 @@ from dataclasses import dataclass from functools import partial -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import warnings import jax @@ -93,6 +93,25 @@ def record_ep_bootstrap_signature_for_moe( ) +# Per-(top_k, alignment) EpHandle cache. ``tex.ep_make_handle`` mints a +# fresh handle_id from a singleton pool capped at NVTE_EP_HANDLE_CACHE_SIZE +# (default 8192); caching here keeps the pool steady across many jit traces +# of the same MoE block configuration. +_te_ep_handle_cache: Dict[Tuple[int, int], Any] = {} + + +def _get_or_make_ep_handle(top_k: int, dispatch_output_per_expert_alignment: int): + key = (int(top_k), int(dispatch_output_per_expert_alignment)) + h = _te_ep_handle_cache.get(key) + if h is None: + h = tex.ep_make_handle( + top_k=key[0], + dispatch_output_per_expert_alignment=key[1], + ) + _te_ep_handle_cache[key] = h + return h + + def _te_ep_assert_compatible_bootstrap( num_experts: int, max_tokens_per_rank: int, @@ -145,6 +164,7 @@ class _Ctx: saved_scores: jnp.ndarray routing_map: jnp.ndarray handle: Any + handle_mem: Any token_counts: jnp.ndarray recv_topk_weights: jnp.ndarray casted_sorted_x_lhs_trans: Any @@ -538,9 +558,12 @@ def _moe_fwd_rule( topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) # ---------------- TE EP dispatch (global view) ---------------- - token_counts, handle = tex.ep_prepare(topk_idx_3d, slots_per_expert) - recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( - handle, topk_idx_3d, x, topk_w_3d, recv_pr + handle = _get_or_make_ep_handle( + top_k=K, dispatch_output_per_expert_alignment=slots_per_expert + ) + token_counts, handle_mem = tex.ep_prepare(topk_idx_3d, handle) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + handle, handle_mem, topk_idx_3d, x, topk_w_3d, recv_pr ) recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3_spec)) recv_topk_weights = jax.lax.with_sharding_constraint( @@ -608,6 +631,7 @@ def _body(*args): # expert_outputs is already weighted upstream. output = tex.ep_combine_fwd( handle, + handle_mem, expert_outputs, num_local_tokens=(B, S), out_partition_spec=out_partition_spec, @@ -618,6 +642,7 @@ def _body(*args): weighted = expert_outputs * w * mask output = tex.ep_combine_fwd( handle, + handle_mem, weighted, num_local_tokens=(B, S), out_partition_spec=out_partition_spec, @@ -641,6 +666,7 @@ def _body(*args): saved_scores=saved_scores, routing_map=routing_map, handle=handle, + handle_mem=handle_mem, token_counts=token_counts, recv_topk_weights=recv_topk_weights, casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, @@ -716,7 +742,7 @@ def _moe_bwd_rule( # ---------------- Combine bwd (global view) ---------------- d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec)) - grad_pre_combine = tex.ep_combine_bwd(ctx.handle, d_output, recv_pr) + grad_pre_combine = tex.ep_combine_bwd(ctx.handle, ctx.handle_mem, d_output, recv_pr) grad_pre_combine = jax.lax.with_sharding_constraint( grad_pre_combine, NamedSharding(mesh, ep3_spec) ) @@ -837,6 +863,7 @@ def _bwd_body(*args): ) d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( ctx.handle, + ctx.handle_mem, d_sorted_x, d_recv_w_total, top_k=K, From 458d1c4cf26fb20a714f0f49284b6efe23a883f7 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:56:16 -0700 Subject: [PATCH 20/29] [JAX] MoE: pass bf16 as max_token_dtype to test fixture's ep_bootstrap te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap. Tests dispatch bf16 tokens; without this arg the group lands with the legacy kByte default (1 byte) and every dispatch aborts at the ep_backend.cpp:349 dtype check. --- tests/jax/test_te_ep_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index 9373873bec..d56dad2c92 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -241,6 +241,7 @@ def mesh(): recv_capacity_per_rank=recv_capacity_per_rank, hidden_dim=HIDDEN, allow_handle_mem_reloc=True, + max_token_dtype=DTYPE, ) record_ep_bootstrap_signature_for_moe( num_experts=NUM_EXPERTS, From 3d6825c7093a82a6ecc4fbfa618dc81c4fd1f6b4 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 4 Jun 2026 10:39:15 -0700 Subject: [PATCH 21/29] patching the sharding stripped by flattening logits input to topk, will fix for real in later commits Signed-off-by: tdophung --- transformer_engine/jax/moe.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 067be5a60b..a96a77b991 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -243,8 +243,10 @@ def _ffn_fwd_per_shard( if apply_topk_weights_early: # Fold the per-token combine weights into the FFN intermediate; # the downstream wo GEMM is linear so this is equivalent to the - # late-weighting path, modulo elementwise op fusion gains. - w_b = recv_w_flat[:, None] + # late-weighting path, modulo elementwise op fusion gains. w_b is + # cast to intermediate.dtype so the multiply doesn't promote + # expert_outputs to f32 (NCCL EP combine hard-asserts bf16). + w_b = recv_w_flat[:, None].astype(intermediate.dtype) mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] intermediate = intermediate * w_b * mask_b @@ -317,7 +319,9 @@ def _ffn_bwd_per_shard( if apply_topk_weights_early: # intermediate' = intermediate * w * mask. Split the cotangent # across both factors before the activation bwd consumes it. - w_b = recv_w_flat[:, None] + # Cast w_b so the multiply stays in d_intermediate.dtype and + # d_sorted_x (downstream into ep_dispatch_bwd) stays bf16. + w_b = recv_w_flat[:, None].astype(d_intermediate.dtype) mask_b = (recv_w_flat != 0).astype(d_intermediate.dtype)[:, None] intermediate_unweighted = act_fn(gate_proj_out) * up_proj_out d_recv_w_from_intermediate = jnp.sum( @@ -556,6 +560,17 @@ def _moe_fwd_rule( routing_weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) topk_idx_3d = selected_experts.reshape(B, S, K).astype(jnp.int32) topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) + # tex.ep_prepare/dispatch's partition only folds ep_axis into a replicated + # leading dim, not the outer dp/fsdp axes, so a replicated topk_idx makes + # each rank see B/ep rows (not B/num_procs) and overrun the bootstrap-sized + # send buffer. Pin both routing tensors to the (outer, ep) leading sharding + # so per-rank token counts match max_tokens_per_rank. + topk_idx_3d = jax.lax.with_sharding_constraint( + topk_idx_3d, NamedSharding(mesh, ep3_spec) + ) + topk_w_3d = jax.lax.with_sharding_constraint( + topk_w_3d, NamedSharding(mesh, ep3_spec) + ) # ---------------- TE EP dispatch (global view) ---------------- handle = _get_or_make_ep_handle( @@ -637,7 +652,7 @@ def _body(*args): out_partition_spec=out_partition_spec, ) else: - w = recv_topk_weights[..., None] + w = recv_topk_weights[..., None].astype(expert_outputs.dtype) mask = (recv_topk_weights != 0).astype(expert_outputs.dtype)[..., None] weighted = expert_outputs * w * mask output = tex.ep_combine_fwd( @@ -754,8 +769,10 @@ def _moe_bwd_rule( d_recv_w_from_combine = jnp.zeros_like(ctx.recv_topk_weights) else: # combine_fwd consumed weighted = expert_out * w * mask; - # split the cotangent across both factors. - w = ctx.recv_topk_weights[..., None] + # split the cotangent across both factors. w is cast to + # grad_pre_combine.dtype so the multiply stays bf16 and + # d_sorted_x (downstream into ep_dispatch_bwd) stays bf16. + w = ctx.recv_topk_weights[..., None].astype(grad_pre_combine.dtype) mask = (ctx.recv_topk_weights != 0).astype(grad_pre_combine.dtype)[..., None] d_expert_outputs = grad_pre_combine * w * mask d_recv_w_from_combine = (grad_pre_combine * ctx.expert_outputs * mask).sum(axis=-1) From 0236467805a39c04a32554b53046133a87ad9ce2 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 4 Jun 2026 14:36:55 -0700 Subject: [PATCH 22/29] MoEBlock tutorial Signed-off-by: Jeremy Berchtold --- docs/examples/jax/moe.out | 11 + docs/examples/jax/moe.py | 312 ++++++++++++++++++++ docs/examples/jax/moe.rst | 163 +++++++++++ docs/examples/jax/moe_native.py | 422 +++++++++++++++++++++++++++ docs/examples/jax/test_moe.py | 129 ++++++++ docs/examples/te_jax_integration.rst | 4 + 6 files changed, 1041 insertions(+) create mode 100644 docs/examples/jax/moe.out create mode 100644 docs/examples/jax/moe.py create mode 100644 docs/examples/jax/moe.rst create mode 100644 docs/examples/jax/moe_native.py create mode 100644 docs/examples/jax/test_moe.py diff --git a/docs/examples/jax/moe.out b/docs/examples/jax/moe.out new file mode 100644 index 0000000000..27aaacb677 --- /dev/null +++ b/docs/examples/jax/moe.out @@ -0,0 +1,11 @@ +# Numbers below were captured on 4x NVIDIA GB200. Regenerate with: +# python3 docs/examples/jax/moe.py > moe.out + +# MOE_OUTPUT_START +max |native BF16 - TE BF16|: 0.0604 +native JAX BF16: +Mean time: 19.545 ms + +TE _MoEBlock BF16: +Mean time: 13.632 ms +# MOE_OUTPUT_END diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py new file mode 100644 index 0000000000..f7e21a1acf --- /dev/null +++ b/docs/examples/jax/moe.py @@ -0,0 +1,312 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""JAX: BF16 Mixture-of-Experts with TransformerEngine. + +Companion source for ``moe.rst``. Code blocks between ``# MOE_*_START`` / +``# MOE_*_END`` markers are pulled into the RST via ``literalinclude``. + +Run as a script to exercise the example end-to-end: + + python docs/examples/jax/moe.py + +The example uses a 2x2 expert-parallel/FSDP mesh and therefore requires four +visible GPUs. Both the native baseline and TransformerEngine path run in BF16; +the current ``_MoEBlock`` wrapper uses no-op quantizer sets. +""" + +# MOE_IMPORTS_START +from dataclasses import dataclass +from typing import Any + +import jax +import jax.numpy as jnp +from flax.linen import partitioning as nn_partitioning +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P + +from moe_native import NativeMoEBlock + +# MOE_IMPORTS_END + + +# MOE_CONFIG_START +EP_AXIS = "ep" +FSDP_AXIS = "fsdp" +EP_SIZE = 2 +FSDP_SIZE = 2 + +NUM_EXPERTS = 8 +TOPK = 2 +BATCH = 8 +SEQ = 2048 +HIDDEN = 1024 +INTERMEDIATE = 4096 +DTYPE = jnp.bfloat16 + +LOGICAL_AXIS_RULES = ( + ("exp", EP_AXIS), + ("embed", FSDP_AXIS), + ("mlp", None), + ("batch", (EP_AXIS, FSDP_AXIS)), +) +# MOE_CONFIG_END + + +@dataclass +class DemoState: + mesh: Mesh + mesh_resource: Any + native_model: NativeMoEBlock + te_model: Any + variables: Any + x: jax.Array + dy: jax.Array + + +def _ensure_writable_triton_cache(): + import os + import tempfile + + os.environ.setdefault( + "TRITON_CACHE_DIR", + os.path.join(tempfile.gettempdir(), "transformer_engine_triton_cache"), + ) + + +# MOE_MESH_SETUP_START +def build_ep_fsdp_mesh(): + from transformer_engine.jax.sharding import MeshResource + + required_devices = EP_SIZE * FSDP_SIZE + if len(jax.devices()) < required_devices: + raise RuntimeError( + f"MoE tutorial requires {required_devices} GPUs; " + f"only {len(jax.devices())} visible" + ) + + devices = mesh_utils.create_device_mesh( + (EP_SIZE, FSDP_SIZE), + devices=jax.devices()[:required_devices], + ) + mesh = Mesh(devices, axis_names=(EP_AXIS, FSDP_AXIS)) + mesh_resource = MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + return mesh, mesh_resource + + +# MOE_MESH_SETUP_END + + +# MOE_MODEL_SETUP_START +def build_models(mesh, *, hidden=HIDDEN, intermediate=INTERMEDIATE): + _ensure_writable_triton_cache() + + from transformer_engine.jax.flax import _MoEBlock as TEMoEBlock + from transformer_engine.jax.moe import PermutationBackend + + native_model = NativeMoEBlock( + mesh=mesh, + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=intermediate, + ep_axis=EP_AXIS, + data_parallelism_axes=(FSDP_AXIS,), + dtype=DTYPE, + ) + te_model = TEMoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=intermediate, + data_parallelism_axes=(FSDP_AXIS,), + permutation_backend=PermutationBackend.PURE_JAX, + dtype=DTYPE, + ) + return native_model, te_model + + +# MOE_MODEL_SETUP_END + + +# MOE_INPUTS_SETUP_START +def make_inputs(*, batch=BATCH, seq=SEQ, hidden=HIDDEN): + key = jax.random.PRNGKey(0) + k_init, k_x, k_dy = jax.random.split(key, 3) + x = jax.random.normal(k_x, (batch, seq, hidden), dtype=DTYPE) + dy = jax.random.normal(k_dy, (batch, seq, hidden), dtype=DTYPE) + return k_init, x, dy + + +def shard_inputs_and_variables(mesh, variables, x, dy): + input_sharding = NamedSharding(mesh, P((EP_AXIS, FSDP_AXIS), None, None)) + gate_sharding = NamedSharding(mesh, P()) + expert_sharding = NamedSharding(mesh, P(EP_AXIS, None, None)) + + params = variables["params"] + sharded_params = { + "gate_kernel": jax.device_put(params["gate_kernel"], gate_sharding), + "wi_0": jax.device_put(params["wi_0"], expert_sharding), + "wi_1": jax.device_put(params["wi_1"], expert_sharding), + "wo": jax.device_put(params["wo"], expert_sharding), + } + return { + "variables": {**variables, "params": sharded_params}, + "x": jax.device_put(x, input_sharding), + "dy": jax.device_put(dy, input_sharding), + } + + +# MOE_INPUTS_SETUP_END + + +def _te_apply(te_model): + def apply_fn(variables, x, **kwargs): + out, _ = te_model.apply(variables, x, **kwargs) + return out + + return apply_fn + + +def setup_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE): + from transformer_engine.jax.sharding import global_shard_guard + + mesh, mesh_resource = build_ep_fsdp_mesh() + native_model, te_model = build_models(mesh, hidden=hidden, intermediate=intermediate) + k_init, x, dy = make_inputs(batch=batch, seq=seq, hidden=hidden) + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource), nn_partitioning.axis_rules( + LOGICAL_AXIS_RULES + ): + variables = jax.jit(native_model.init)(k_init, x) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) + + sharded = shard_inputs_and_variables(mesh, variables, x, dy) + return DemoState( + mesh=mesh, + mesh_resource=mesh_resource, + native_model=native_model, + te_model=te_model, + variables=sharded["variables"], + x=sharded["x"], + dy=sharded["dy"], + ) + + +def te_moe_supported(): + try: + import importlib + import sys + + _ensure_writable_triton_cache() + + import transformer_engine.jax # noqa: F401 + + transformer_engine_jax = sys.modules["transformer_engine_jax"] + flax_mod = importlib.import_module("transformer_engine.jax.flax") + getattr(flax_mod, "_MoEBlock") + if transformer_engine_jax.get_device_compute_capability(0) < 100: + return False, "TE MoE grouped GEMM currently requires Blackwell (sm_100+)" + except Exception as exc: # pylint: disable=broad-exception-caught + return False, str(exc) + return True, "" + + +# MOE_CORRECTNESS_START +def compare_forward(demo): + from transformer_engine.jax.sharding import global_shard_guard + + te_apply = _te_apply(demo.te_model) + with jax.set_mesh(demo.mesh), global_shard_guard( + demo.mesh_resource + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + native_out = jax.jit(demo.native_model.apply)(demo.variables, demo.x) + te_out = jax.jit(te_apply)(demo.variables, demo.x) + native_out, te_out = jax.block_until_ready((native_out, te_out)) + + max_abs = jnp.max( + jnp.abs(native_out.astype(jnp.float32) - te_out.astype(jnp.float32)) + ) + print(f"max |native BF16 - TE BF16|: {float(max_abs):.4f}") + return native_out, te_out + + +# MOE_CORRECTNESS_END + + +# MOE_BENCH_START +def _block_until_ready_tree(tree): + leaves = jax.tree_util.tree_leaves(tree) + if leaves: + jax.block_until_ready(leaves[0]) + + +def _time_fwd_bwd(apply_fn, demo, *, warmup_iters, timing_iters): + import time + + autocast_kwargs = {"enabled": False, "mesh_resource": demo.mesh_resource} + + def loss_fn(variables, inp, grad_target): + import transformer_engine.jax as te + + with te.autocast(**autocast_kwargs): + out = apply_fn(variables, inp) + return jnp.vdot(out, grad_target) + + train_step = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1))) + + for _ in range(warmup_iters): + _block_until_ready_tree(train_step(demo.variables, demo.x, demo.dy)) + + start = time.perf_counter() + for _ in range(timing_iters): + _block_until_ready_tree(train_step(demo.variables, demo.x, demo.dy)) + return (time.perf_counter() - start) * 1000.0 / timing_iters + + +def run_benchmarks(demo, *, warmup_iters=5, timing_iters=10): + from transformer_engine.jax.sharding import global_shard_guard + + te_apply = _te_apply(demo.te_model) + with jax.set_mesh(demo.mesh), global_shard_guard( + demo.mesh_resource + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + print("native JAX BF16:") + native_ms = _time_fwd_bwd( + demo.native_model.apply, + demo, + warmup_iters=warmup_iters, + timing_iters=timing_iters, + ) + print(f"Mean time: {native_ms:.3f} ms") + + print("\nTE _MoEBlock BF16:") + te_ms = _time_fwd_bwd( + te_apply, + demo, + warmup_iters=warmup_iters, + timing_iters=timing_iters, + ) + print(f"Mean time: {te_ms:.3f} ms") + return native_ms, te_ms + + +# MOE_BENCH_END + + +def main(): + if len(jax.devices()) < EP_SIZE * FSDP_SIZE: + print(f"[skipped: need {EP_SIZE * FSDP_SIZE} GPUs for EP=2/FSDP=2]") + return + + te_supported, te_reason = te_moe_supported() + if not te_supported: + print(f"[skipped TE comparison: {te_reason}]") + return + + demo = setup_demo() + compare_forward(demo) + run_benchmarks(demo) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/jax/moe.rst b/docs/examples/jax/moe.rst new file mode 100644 index 0000000000..43deeee7f7 --- /dev/null +++ b/docs/examples/jax/moe.rst @@ -0,0 +1,163 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: BF16 Mixture-of-Experts with TransformerEngine +=================================================== + +This document walks through replacing a native JAX/Flax expert-parallel MoE +block with TransformerEngine's experimental Flax ``_MoEBlock``. + +**Baseline.** The reference path is pure JAX/Flax BF16. It uses +``jax.lax.ragged_all_to_all`` for expert-parallel token exchange and +``jax.lax.ragged_dot`` for the grouped expert FFNs. The low-level ragged +all-to-all setup lives in ``moe_native.py`` so the snippets below stay focused +on model-level code. + +**TransformerEngine path.** This tutorial uses ``_MoEBlock`` in BF16 with the +wrapper's current no-op quantizer sets. Quantized MoE recipes are intentionally +out of scope here. + +`<- Back to the JAX integration overview <../te_jax_integration.html>`_ + +1. Baseline: native JAX BF16 EP MoE +----------------------------------- + +The example uses a 2x2 mesh: expert parallelism on ``ep`` and FSDP-style batch +parallelism on ``fsdp``. The batch dimension is sharded over both axes, and +expert weights are sharded over ``ep``. + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_IMPORTS_START + :end-before: # MOE_IMPORTS_END + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_CONFIG_START + :end-before: # MOE_CONFIG_END + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_MESH_SETUP_START + :end-before: # MOE_MESH_SETUP_END + +The native baseline is exposed as a normal Flax module. Its implementation in +``moe_native.py`` performs softmax top-k routing, forward +``ragged_all_to_all`` over ``ep``, local source-major to expert-major chunk +reordering, three ``ragged_dot`` expert GEMMs, reverse ``ragged_all_to_all``, +and weighted token combine. + +2. TransformerEngine ``_MoEBlock`` +---------------------------------- + +The TE replacement registers the same gate and expert parameter names as the +baseline, then delegates routing, dispatch, grouped FFN, combine, +expert-parallel collectives, and VJP to ``transformer_engine.jax.moe.moe``. + +``_MoEBlock`` is intentionally underscore-prefixed while the API stabilizes. Use +it as an experimental integration point. + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_MODEL_SETUP_START + :end-before: # MOE_MODEL_SETUP_END + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_INPUTS_SETUP_START + :end-before: # MOE_INPUTS_SETUP_END + +3. Correctness check +-------------------- + +Both models use the same Flax variable dictionary, so the gate and expert +weights are identical. The comparison checks the BF16 forward result on the same +sharded input. + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_CORRECTNESS_START + :end-before: # MOE_CORRECTNESS_END + +The two paths may not be bit-identical because the router and grouped matmul +implementations differ, but they should stay within ordinary BF16 tolerance for +the default no-bias, softmax top-k, ``silu`` path. + +4. Performance comparison +------------------------- + +``run_benchmarks`` runs a blocking JIT-compiled forward+backward loop with +warmup. The same sharded input, output gradient, and variables are used for +native BF16 and TE BF16. Even though quantization is disabled, the benchmark +passes the active ``MeshResource`` through TE's autocast context so +``_MoEBlock`` can resolve the ``ep`` axis. + +.. literalinclude:: moe.py + :language: python + :start-after: # MOE_BENCH_START + :end-before: # MOE_BENCH_END + +Run the full example with: + +.. code-block:: bash + + python docs/examples/jax/moe.py + +Measured on four NVIDIA GB200 GPUs with the default tutorial shape +``batch=8``, ``seq=2048``, ``hidden=1024``, ``intermediate=4096``, +``num_experts=8``, and ``topk=2``: + +.. csv-table:: + :header: "Path", "Mean fwd+bwd time", "Relative time" + :widths: 35, 25, 25 + + "Native JAX BF16", "19.545 ms", "1.00x" + "TE ``_MoEBlock`` BF16", "13.632 ms", "0.70x" + +The same run reported ``max |native BF16 - TE BF16| = 0.0604`` for the forward +correctness check. For this no-op-quantizer BF16 configuration, TE measured +``1.43x`` the native baseline throughput on this tutorial shape. + +A larger-shape sweep with the same blocking timing loop found TE ahead for each +shape tried: + +.. csv-table:: + :header: "Batch", "Seq", "Hidden", "Intermediate", "Native BF16", "TE BF16", "TE speedup" + :widths: 10, 10, 12, 16, 16, 16, 14 + + "8", "1024", "1024", "4096", "9.173 ms", "7.377 ms", "1.24x" + "8", "2048", "1024", "4096", "19.545 ms", "13.632 ms", "1.43x" + "8", "4096", "1024", "4096", "39.179 ms", "33.570 ms", "1.17x" + "16", "2048", "1024", "4096", "39.211 ms", "33.595 ms", "1.17x" + "8", "1024", "2048", "8192", "19.313 ms", "14.846 ms", "1.30x" + "8", "2048", "2048", "8192", "42.629 ms", "32.657 ms", "1.31x" + "16", "2048", "2048", "8192", "86.957 ms", "68.643 ms", "1.27x" + +Across the sweep, the forward max-absolute difference stayed between +``0.0598`` and ``0.0704``. The result depends on token distribution, hidden +size, intermediate size, and the target stack. Keep +``PermutationBackend.PURE_JAX`` until correctness is established; then compare +it with ``PermutationBackend.TRITON`` separately. + +.. raw:: html + +
+ Output: +
+ +.. container:: program-output + + .. literalinclude:: moe.out + :language: text + :start-after: # MOE_OUTPUT_START + :end-before: # MOE_OUTPUT_END + +Next steps +---------- + +* `Dense GEMMs `_: quantizing a single ``flax.linen.Dense`` GEMM. +* `Collective GEMM `_: further speedups by communicating + between devices inside the GEMM. +* `<- Hub <../te_jax_integration.html>`_ diff --git a/docs/examples/jax/moe_native.py b/docs/examples/jax/moe_native.py new file mode 100644 index 0000000000..85ab48b13c --- /dev/null +++ b/docs/examples/jax/moe_native.py @@ -0,0 +1,422 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Native JAX/Flax MoE baseline used by ``moe.rst``. + +This file intentionally contains the lower-level reference mechanics so the +tutorial can focus on model-level code. It does not import TransformerEngine: +the router, expert-parallel ragged all-to-all, local ragged chunk reorder, and +ragged expert matmuls are implemented with JAX and Flax only. +""" + +import inspect +from functools import partial +from typing import Any, Callable, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from jax.sharding import PartitionSpec as P + + +def _forward_a2a_params( + all_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_ep: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Build ``ragged_all_to_all`` offsets/sizes for dispatch.""" + num_experts = all_tokens_per_expert.shape[1] + experts_per_shard = num_experts // num_ep + + local_tokens_per_expert = jax.lax.dynamic_slice( + all_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_by_destination = local_tokens_per_expert.reshape(num_ep, experts_per_shard) + send_sizes = jnp.sum(local_by_destination, axis=1).astype(jnp.int32) + input_offsets = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]] + ) + + local_expert_start = shard_id * experts_per_shard + local_expert_columns = jax.lax.dynamic_slice( + all_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_ep, experts_per_shard), + ) + recv_sizes = jnp.sum(local_expert_columns, axis=1).astype(jnp.int32) + + sends_to_destination = jnp.sum( + all_tokens_per_expert.reshape(num_ep, num_ep, experts_per_shard), + axis=2, + ).astype(jnp.int32) + cumulative = jnp.cumsum( + jnp.concatenate( + [jnp.zeros((1, num_ep), dtype=jnp.int32), sends_to_destination], + axis=0, + ), + axis=0, + ) + output_offsets = jax.lax.dynamic_slice( + cumulative, + start_indices=(shard_id, 0), + slice_sizes=(1, num_ep), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +def _reverse_a2a_params( + all_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_ep: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Build ``ragged_all_to_all`` offsets/sizes for combine.""" + num_experts = all_tokens_per_expert.shape[1] + experts_per_shard = num_experts // num_ep + local_expert_start = shard_id * experts_per_shard + + local_expert_columns = jax.lax.dynamic_slice( + all_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_ep, experts_per_shard), + ) + send_sizes = jnp.sum(local_expert_columns, axis=1).astype(jnp.int32) + input_offsets = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]] + ) + + local_tokens_per_expert = jax.lax.dynamic_slice( + all_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_by_destination = local_tokens_per_expert.reshape(num_ep, experts_per_shard) + recv_sizes = jnp.sum(local_by_destination, axis=1).astype(jnp.int32) + + forward_sends_to = jnp.sum( + all_tokens_per_expert.reshape(num_ep, num_ep, experts_per_shard), + axis=2, + ).astype(jnp.int32) + reverse_sends_to = jnp.transpose(forward_sends_to) + cumulative = jnp.cumsum( + jnp.concatenate( + [jnp.zeros((1, num_ep), dtype=jnp.int32), reverse_sends_to], + axis=0, + ), + axis=0, + ) + output_offsets = jax.lax.dynamic_slice( + cumulative, + start_indices=(shard_id, 0), + slice_sizes=(1, num_ep), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +def _reorder_ragged_chunks( + x: jnp.ndarray, + chunk_sizes: jnp.ndarray, + source_order: jnp.ndarray, + target_order: jnp.ndarray, +) -> jnp.ndarray: + """Reorder a fixed-size ragged buffer from one chunk order to another.""" + source_sizes = chunk_sizes[source_order] + source_starts = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(source_sizes)[:-1]] + ) + source_ends = source_starts + source_sizes + + target_sizes = chunk_sizes[target_order] + target_starts_by_position = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(target_sizes)[:-1]] + ) + target_position_by_chunk = jnp.argsort(target_order) + target_start_by_chunk = target_starts_by_position[target_position_by_chunk] + + rows = jnp.arange(x.shape[0], dtype=jnp.int32) + in_source_chunk = (rows[:, None] >= source_starts[None, :]) & ( + rows[:, None] < source_ends[None, :] + ) + valid = jnp.any(in_source_chunk, axis=1) + source_position = jnp.argmax(in_source_chunk, axis=1) + chunk_id = source_order[source_position] + row_in_chunk = rows - source_starts[source_position] + target_rows = target_start_by_chunk[chunk_id] + row_in_chunk + target_rows = jnp.where(valid, target_rows, 0) + + updates = jnp.where(valid[:, None], x, jnp.zeros_like(x)) + return jnp.zeros_like(x).at[target_rows].add(updates) + + +def _route_tokens( + x_2d: jnp.ndarray, + gate_kernel: jnp.ndarray, + num_experts_per_tok: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Softmax top-k router matching the tutorial's default TE path.""" + logits = x_2d.astype(jnp.float32) @ gate_kernel.astype(jnp.float32) + probs = jax.nn.softmax(logits, axis=-1) + weights, experts = jax.lax.top_k(probs, num_experts_per_tok) + weights = weights / jnp.sum(weights, axis=-1, keepdims=True) + return experts.astype(jnp.int32), weights.astype(x_2d.dtype) + + +def _native_moe_local( + captured: dict, + *, + ep_axis: str, + num_experts: int, + num_experts_per_tok: int, + recv_buffer_rows: int, + dtype: jnp.dtype, +) -> jnp.ndarray: + """One shard of the native EP MoE forward pass.""" + x = captured["x"] + gate_kernel = captured["gate_kernel"] + wi_0 = captured["wi_0"] + wi_1 = captured["wi_1"] + wo = captured["wo"] + + batch, sequence, hidden = x.shape + tokens = batch * sequence + x_2d = x.reshape(tokens, hidden) + + selected_experts, routing_weights = _route_tokens( + x_2d, gate_kernel, num_experts_per_tok + ) + flat_experts = selected_experts.reshape(-1) + flat_token_ids = jnp.repeat( + jnp.arange(tokens, dtype=jnp.int32), num_experts_per_tok + ) + flat_weights = routing_weights.reshape(-1) + + sort_order = jnp.argsort(flat_experts, stable=True) + sorted_experts = flat_experts[sort_order] + sorted_x = x_2d[flat_token_ids][sort_order] + tokens_per_expert = jnp.bincount( + sorted_experts, + length=num_experts, + minlength=num_experts, + ).astype(jnp.int32) + + shard_id = jax.lax.axis_index(ep_axis) + num_ep = jax.lax.psum(1, ep_axis) + experts_per_shard = num_experts // num_ep + + all_tokens_per_expert = jax.lax.all_gather( + tokens_per_expert[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) + + in_off, send_sz, out_off, recv_sz = _forward_a2a_params( + all_tokens_per_expert, shard_id, num_ep + ) + x_recv = jax.lax.ragged_all_to_all( + sorted_x, + jnp.zeros((recv_buffer_rows, hidden), dtype=sorted_x.dtype), + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + local_expert_start = shard_id * experts_per_shard + local_counts_by_source = jax.lax.dynamic_slice( + all_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_ep, experts_per_shard), + ).astype(jnp.int32) + local_chunk_sizes = local_counts_by_source.reshape(-1) + source_major_order = jnp.arange(num_ep * experts_per_shard, dtype=jnp.int32) + expert_major_order = source_major_order.reshape( + num_ep, experts_per_shard + ).T.reshape(-1) + local_group_sizes = jnp.sum(local_counts_by_source, axis=0).astype(jnp.int32) + + x_expert_major = _reorder_ragged_chunks( + x_recv, + local_chunk_sizes, + source_major_order, + expert_major_order, + ) + hidden_0 = jax.lax.ragged_dot(x_expert_major, wi_0, local_group_sizes) + hidden_1 = jax.lax.ragged_dot(x_expert_major, wi_1, local_group_sizes) + activated = jax.nn.silu(hidden_0) * hidden_1 + expert_output = jax.lax.ragged_dot(activated, wo, local_group_sizes).astype(dtype) + + source_major_output = _reorder_ragged_chunks( + expert_output, + local_chunk_sizes, + expert_major_order, + source_major_order, + ) + in_off, send_sz, out_off, recv_sz = _reverse_a2a_params( + all_tokens_per_expert, shard_id, num_ep + ) + returned = jax.lax.ragged_all_to_all( + source_major_output, + jnp.zeros_like(sorted_x), + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + unsorted = jnp.zeros_like(returned).at[sort_order].set(returned) + token_outputs = unsorted.reshape(tokens, num_experts_per_tok, hidden) + weighted = token_outputs * flat_weights.reshape(tokens, num_experts_per_tok, 1) + return jnp.sum(weighted, axis=1).reshape(batch, sequence, hidden).astype(dtype) + + +def native_moe_ep( + x: jnp.ndarray, + gate_kernel: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + *, + mesh: Any, + ep_axis: str, + data_parallelism_axes: Tuple[str, ...], + num_experts: int, + num_experts_per_tok: int, + dtype: jnp.dtype, +) -> jnp.ndarray: + """Run the native BF16 EP MoE baseline on an active JAX mesh.""" + if num_experts % mesh.shape[ep_axis] != 0: + raise ValueError( + f"num_experts={num_experts} must be divisible by " + f"EP size={mesh.shape[ep_axis]}" + ) + + if data_parallelism_axes: + batch_axis = (ep_axis, *data_parallelism_axes) + else: + batch_axis = ep_axis + + dp_size = 1 + for axis in data_parallelism_axes: + dp_size *= mesh.shape[axis] + + batch, sequence, _ = x.shape + required_batch_multiple = mesh.shape[ep_axis] * dp_size + if batch % required_batch_multiple != 0: + raise ValueError( + f"batch={batch} must be divisible by ep*dp=" + f"{required_batch_multiple}" + ) + + recv_buffer_rows = (batch // dp_size) * sequence * num_experts_per_tok + captured = { + "x": x, + "gate_kernel": gate_kernel, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + in_specs = ( + { + "x": P(batch_axis, None, None), + "gate_kernel": P(), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + }, + ) + + body = partial( + _native_moe_local, + ep_axis=ep_axis, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + recv_buffer_rows=recv_buffer_rows, + dtype=dtype, + ) + shard_map_kwargs = { + "mesh": mesh, + "in_specs": in_specs, + "out_specs": P(batch_axis, None, None), + } + shard_map_params = inspect.signature(jax.shard_map).parameters + if "check_rep" in shard_map_params: + shard_map_kwargs["check_rep"] = False + elif "check_vma" in shard_map_params: + shard_map_kwargs["check_vma"] = False + + return jax.shard_map(body, **shard_map_kwargs)(captured) + + +class NativeMoEBlock(nn.Module): + """Native JAX/Flax BF16 EP MoE block used as the tutorial baseline.""" + + mesh: Any + num_experts: int = 8 + num_experts_per_tok: int = 2 + intermediate_size: int = 2048 + ep_axis: str = "ep" + data_parallelism_axes: Tuple[str, ...] = ("fsdp",) + dtype: jnp.dtype = jnp.bfloat16 + kernel_init: Optional[Callable] = None + + def __post_init__(self): + if self.kernel_init is None: + object.__setattr__( + self, + "kernel_init", + nn.initializers.variance_scaling( + 1.0, + "fan_in", + "truncated_normal", + dtype=self.dtype, + ), + ) + super().__post_init__() + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + hidden = x.shape[-1] + gate_kernel = self.param( + "gate_kernel", + self.kernel_init, + (hidden, self.num_experts), + self.dtype, + ) + wi_0 = self.param( + "wi_0", + self.kernel_init, + (self.num_experts, hidden, self.intermediate_size), + self.dtype, + ) + wi_1 = self.param( + "wi_1", + self.kernel_init, + (self.num_experts, hidden, self.intermediate_size), + self.dtype, + ) + wo = self.param( + "wo", + self.kernel_init, + (self.num_experts, self.intermediate_size, hidden), + self.dtype, + ) + return native_moe_ep( + x, + gate_kernel, + wi_0, + wi_1, + wo, + mesh=self.mesh, + ep_axis=self.ep_axis, + data_parallelism_axes=self.data_parallelism_axes, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + dtype=self.dtype, + ) diff --git a/docs/examples/jax/test_moe.py b/docs/examples/jax/test_moe.py new file mode 100644 index 0000000000..85b5604ac1 --- /dev/null +++ b/docs/examples/jax/test_moe.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pytest entry points for ``moe.py``. + +Run with: + + pytest -v docs/examples/jax/test_moe.py + +The tutorial uses a 2x2 EP/FSDP mesh, so tests skip when fewer than four GPUs +are visible. TransformerEngine MoE tests also skip when the installed TE build +does not expose the experimental ``_MoEBlock`` or when hardware support is +missing. +""" + +import importlib +import os +import sys +import tempfile + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + + +requires_4gpu = pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs") + + +os.environ.setdefault( + "TRITON_CACHE_DIR", + os.path.join(tempfile.gettempdir(), "transformer_engine_triton_cache"), +) + + +def _te_moe_available(): + try: + import transformer_engine.jax # noqa: F401 + + mod = importlib.import_module("transformer_engine.jax.flax") + getattr(mod, "_MoEBlock") + transformer_engine_jax = sys.modules["transformer_engine_jax"] + + if transformer_engine_jax.get_device_compute_capability(0) < 100: + return False, "TE MoE grouped GEMM requires Blackwell (sm_100+)" + except Exception as exc: # pylint: disable=broad-exception-caught + return False, str(exc) + return True, "" + + +_te_supported, _te_reason = _te_moe_available() +requires_te_moe = pytest.mark.skipif(not _te_supported, reason=_te_reason) + + +def _small_native_state(): + from jax.experimental import mesh_utils + from jax.sharding import Mesh + from moe import EP_AXIS, FSDP_AXIS, NativeMoEBlock + + mesh = Mesh( + mesh_utils.create_device_mesh((2, 2), devices=jax.devices()[:4]), + (EP_AXIS, FSDP_AXIS), + ) + model = NativeMoEBlock( + mesh=mesh, + num_experts=8, + num_experts_per_tok=2, + intermediate_size=64, + ep_axis=EP_AXIS, + data_parallelism_axes=(FSDP_AXIS,), + dtype=jnp.bfloat16, + ) + x = jax.random.normal(jax.random.PRNGKey(1), (4, 16, 32), dtype=jnp.bfloat16) + dy = jax.random.normal(jax.random.PRNGKey(2), x.shape, dtype=jnp.bfloat16) + return mesh, model, x, dy + + +@requires_4gpu +def test_native_baseline_runs(): + mesh, model, x, _ = _small_native_state() + with jax.set_mesh(mesh): + variables = jax.jit(model.init)(jax.random.PRNGKey(0), x) + out = jax.jit(model.apply)(variables, x) + out.block_until_ready() + + assert out.shape == x.shape + assert out.dtype == x.dtype + assert np.all(np.isfinite(np.asarray(out))) + + +@requires_4gpu +def test_native_baseline_grads_are_finite(): + mesh, model, x, dy = _small_native_state() + + def loss_fn(variables, x): + return jnp.vdot(model.apply(variables, x), dy) + + with jax.set_mesh(mesh): + variables = jax.jit(model.init)(jax.random.PRNGKey(0), x) + grads = jax.jit(jax.grad(loss_fn))(variables, x) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + grad = np.asarray(grads["params"][name]) + assert np.all(np.isfinite(grad)), f"{name} grad has NaN/Inf" + assert np.any(grad != 0.0), f"{name} grad is identically zero" + + +@requires_4gpu +@requires_te_moe +def test_te_moe_matches_native_shape_and_dtype(): + import moe + + demo = moe.setup_demo(batch=4, seq=16, hidden=32, intermediate=64) + native_out, te_out = moe.compare_forward(demo) + + assert native_out.shape == te_out.shape == demo.x.shape + assert native_out.dtype == te_out.dtype == demo.x.dtype + assert np.all(np.isfinite(np.asarray(te_out))) + + +@requires_4gpu +@requires_te_moe +def test_benchmark_entrypoint_runs(): + import moe + + demo = moe.setup_demo(batch=4, seq=16, hidden=32, intermediate=64) + moe.run_benchmarks(demo, warmup_iters=1, timing_iters=1) diff --git a/docs/examples/te_jax_integration.rst b/docs/examples/te_jax_integration.rst index a15a10e0b3..492d1c21b5 100644 --- a/docs/examples/te_jax_integration.rst +++ b/docs/examples/te_jax_integration.rst @@ -24,6 +24,9 @@ Pick a topic * - `Dense GEMMs `_ - **Available** - ``nn.Dense`` → quantized GEMM; single-GPU speedup; multi-GPU speedup; + * - `Mixture-of-Experts `_ + - **Available** + - Native BF16 EP MoE → experimental ``_MoEBlock``; BF16 performance; * - `Collective GEMMs `_ - *Coming soon* - @@ -90,6 +93,7 @@ Conventions used across these documents :hidden: jax/dense + jax/moe jax/collective_gemm jax/attention jax/expert_parallelism From e748567630ff68054948c39382c62fab7e047268 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jun 2026 21:38:56 +0000 Subject: [PATCH 23/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/jax/moe.py | 7 ++----- docs/examples/jax/moe_native.py | 36 +++++++++------------------------ 2 files changed, 11 insertions(+), 32 deletions(-) diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index f7e21a1acf..df4fe69c8b 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -82,8 +82,7 @@ def build_ep_fsdp_mesh(): required_devices = EP_SIZE * FSDP_SIZE if len(jax.devices()) < required_devices: raise RuntimeError( - f"MoE tutorial requires {required_devices} GPUs; " - f"only {len(jax.devices())} visible" + f"MoE tutorial requires {required_devices} GPUs; only {len(jax.devices())} visible" ) devices = mesh_utils.create_device_mesh( @@ -223,9 +222,7 @@ def compare_forward(demo): te_out = jax.jit(te_apply)(demo.variables, demo.x) native_out, te_out = jax.block_until_ready((native_out, te_out)) - max_abs = jnp.max( - jnp.abs(native_out.astype(jnp.float32) - te_out.astype(jnp.float32)) - ) + max_abs = jnp.max(jnp.abs(native_out.astype(jnp.float32) - te_out.astype(jnp.float32))) print(f"max |native BF16 - TE BF16|: {float(max_abs):.4f}") return native_out, te_out diff --git a/docs/examples/jax/moe_native.py b/docs/examples/jax/moe_native.py index 85ab48b13c..0bd884f304 100644 --- a/docs/examples/jax/moe_native.py +++ b/docs/examples/jax/moe_native.py @@ -36,9 +36,7 @@ def _forward_a2a_params( ).squeeze(0) local_by_destination = local_tokens_per_expert.reshape(num_ep, experts_per_shard) send_sizes = jnp.sum(local_by_destination, axis=1).astype(jnp.int32) - input_offsets = jnp.concatenate( - [jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]] - ) + input_offsets = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]]) local_expert_start = shard_id * experts_per_shard local_expert_columns = jax.lax.dynamic_slice( @@ -84,9 +82,7 @@ def _reverse_a2a_params( slice_sizes=(num_ep, experts_per_shard), ) send_sizes = jnp.sum(local_expert_columns, axis=1).astype(jnp.int32) - input_offsets = jnp.concatenate( - [jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]] - ) + input_offsets = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(send_sizes)[:-1]]) local_tokens_per_expert = jax.lax.dynamic_slice( all_tokens_per_expert, @@ -185,13 +181,9 @@ def _native_moe_local( tokens = batch * sequence x_2d = x.reshape(tokens, hidden) - selected_experts, routing_weights = _route_tokens( - x_2d, gate_kernel, num_experts_per_tok - ) + selected_experts, routing_weights = _route_tokens(x_2d, gate_kernel, num_experts_per_tok) flat_experts = selected_experts.reshape(-1) - flat_token_ids = jnp.repeat( - jnp.arange(tokens, dtype=jnp.int32), num_experts_per_tok - ) + flat_token_ids = jnp.repeat(jnp.arange(tokens, dtype=jnp.int32), num_experts_per_tok) flat_weights = routing_weights.reshape(-1) sort_order = jnp.argsort(flat_experts, stable=True) @@ -214,9 +206,7 @@ def _native_moe_local( tiled=True, ) - in_off, send_sz, out_off, recv_sz = _forward_a2a_params( - all_tokens_per_expert, shard_id, num_ep - ) + in_off, send_sz, out_off, recv_sz = _forward_a2a_params(all_tokens_per_expert, shard_id, num_ep) x_recv = jax.lax.ragged_all_to_all( sorted_x, jnp.zeros((recv_buffer_rows, hidden), dtype=sorted_x.dtype), @@ -235,9 +225,7 @@ def _native_moe_local( ).astype(jnp.int32) local_chunk_sizes = local_counts_by_source.reshape(-1) source_major_order = jnp.arange(num_ep * experts_per_shard, dtype=jnp.int32) - expert_major_order = source_major_order.reshape( - num_ep, experts_per_shard - ).T.reshape(-1) + expert_major_order = source_major_order.reshape(num_ep, experts_per_shard).T.reshape(-1) local_group_sizes = jnp.sum(local_counts_by_source, axis=0).astype(jnp.int32) x_expert_major = _reorder_ragged_chunks( @@ -257,9 +245,7 @@ def _native_moe_local( expert_major_order, source_major_order, ) - in_off, send_sz, out_off, recv_sz = _reverse_a2a_params( - all_tokens_per_expert, shard_id, num_ep - ) + in_off, send_sz, out_off, recv_sz = _reverse_a2a_params(all_tokens_per_expert, shard_id, num_ep) returned = jax.lax.ragged_all_to_all( source_major_output, jnp.zeros_like(sorted_x), @@ -293,8 +279,7 @@ def native_moe_ep( """Run the native BF16 EP MoE baseline on an active JAX mesh.""" if num_experts % mesh.shape[ep_axis] != 0: raise ValueError( - f"num_experts={num_experts} must be divisible by " - f"EP size={mesh.shape[ep_axis]}" + f"num_experts={num_experts} must be divisible by EP size={mesh.shape[ep_axis]}" ) if data_parallelism_axes: @@ -309,10 +294,7 @@ def native_moe_ep( batch, sequence, _ = x.shape required_batch_multiple = mesh.shape[ep_axis] * dp_size if batch % required_batch_multiple != 0: - raise ValueError( - f"batch={batch} must be divisible by ep*dp=" - f"{required_batch_multiple}" - ) + raise ValueError(f"batch={batch} must be divisible by ep*dp={required_batch_multiple}") recv_buffer_rows = (batch // dp_size) * sequence * num_experts_per_tok captured = { From 3e6f958599b4c323d45fe6d5b9d735dc305e6ac9 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:09:23 -0700 Subject: [PATCH 24/29] Update docs/examples/jax/moe.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- docs/examples/jax/moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index df4fe69c8b..06b0301745 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -234,10 +234,10 @@ def compare_forward(demo): def _block_until_ready_tree(tree): leaves = jax.tree_util.tree_leaves(tree) if leaves: - jax.block_until_ready(leaves[0]) - - -def _time_fwd_bwd(apply_fn, demo, *, warmup_iters, timing_iters): +def _block_until_ready_tree(tree): + leaves = jax.tree_util.tree_leaves(tree) + if leaves: + jax.block_until_ready(leaves) import time autocast_kwargs = {"enabled": False, "mesh_resource": demo.mesh_resource} From 566c5fdf8980bc09416491179fcca030f4b055a6 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:09:38 -0700 Subject: [PATCH 25/29] Update docs/examples/jax/moe.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- docs/examples/jax/moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index 06b0301745..de21b6c423 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -177,8 +177,8 @@ def setup_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE LOGICAL_AXIS_RULES ): variables = jax.jit(native_model.init)(k_init, x) - jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) - + variables = jax.jit(native_model.init)(k_init, x) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)) sharded = shard_inputs_and_variables(mesh, variables, x, dy) return DemoState( mesh=mesh, From 96831df3e930f9603c536c874c1cc5a8769a9f51 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:09:49 -0700 Subject: [PATCH 26/29] Update docs/examples/jax/test_moe.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- docs/examples/jax/test_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/jax/test_moe.py b/docs/examples/jax/test_moe.py index 85b5604ac1..0443a38c17 100644 --- a/docs/examples/jax/test_moe.py +++ b/docs/examples/jax/test_moe.py @@ -99,7 +99,7 @@ def loss_fn(variables, x): with jax.set_mesh(mesh): variables = jax.jit(model.init)(jax.random.PRNGKey(0), x) grads = jax.jit(jax.grad(loss_fn))(variables, x) - jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): grad = np.asarray(grads["params"][name]) From 313bbf4f356eb44262764adc9c37b01a3f2e39a4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 4 Jun 2026 17:14:03 -0700 Subject: [PATCH 27/29] Diagrams Signed-off-by: Jeremy Berchtold --- .../media/jax_moe_native_vs_te_flow.drawio | 43 +++++++ .../jax/media/jax_moe_native_vs_te_flow.svg | 120 ++++++++++++++++++ docs/examples/jax/moe.out | 4 +- docs/examples/jax/moe.py | 6 +- docs/examples/jax/moe.rst | 41 ++++-- docs/examples/jax/moe_native.py | 5 +- 6 files changed, 199 insertions(+), 20 deletions(-) create mode 100644 docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio create mode 100644 docs/examples/jax/media/jax_moe_native_vs_te_flow.svg diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio new file mode 100644 index 0000000000..9f9c9556d9 --- /dev/null +++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg new file mode 100644 index 0000000000..586e008114 --- /dev/null +++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg @@ -0,0 +1,120 @@ + + Simplified native JAX and TransformerEngine JAX MoE forward data flow + Side-by-side flow chart comparing native JAX and TransformerEngine JAX MoE blocks at router, dispatch, expert compute, and combine granularity. + + + + + + + + + JAX MoE forward data flow, native block vs TE block + Simplified view: router -> dispatch -> expert compute -> combine + + + + Native JAX BF16 EP MoE + JAX router, ragged collectives, fused ragged_dot FFN + TE _MoEBlock BF16 + TE router, dispatch/combine, grouped GEMM FFN + + + single TE MoE custom_vjp boundary + + + + Input shard + x [B,S,H], expert weights sharded over ep + + + Router + gate GEMM, softmax, top-k experts and routing weights + + + Dispatch + sort routes, gather counts, ragged_all_to_all, local reorder + + + Expert FFN + ragged_dot(wi_0|wi_1), activation, ragged_dot(wo) + + + Combine + reverse reorder, reverse ragged_all_to_all, unsort and weight + + + + + Input shard + same x [B,S,H] and same parameter names + + + Router + gate GEMM, tex.fused_topk_with_score_function_fwd + + + Dispatch + token_dispatch, gather counts, ragged_all_to_all, local sort + + + Expert FFN + grouped_gemm(wi_0|wi_1), activation, grouped_gemm(wo) + + + Combine + inverse local sort, reverse ragged_all_to_all, token combine + + + + + + + + + + + + + + + + fused router + + ragged_dot FFN -> grouped GEMM FFN + + + + + router + + dispatch / EP exchange + + expert compute + + TE primitive + + combine + + diff --git a/docs/examples/jax/moe.out b/docs/examples/jax/moe.out index 27aaacb677..e9dffd7099 100644 --- a/docs/examples/jax/moe.out +++ b/docs/examples/jax/moe.out @@ -4,8 +4,8 @@ # MOE_OUTPUT_START max |native BF16 - TE BF16|: 0.0604 native JAX BF16: -Mean time: 19.545 ms +Mean time: 17.320 ms TE _MoEBlock BF16: -Mean time: 13.632 ms +Mean time: 13.601 ms # MOE_OUTPUT_END diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index de21b6c423..e6ca9af30a 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -231,13 +231,13 @@ def compare_forward(demo): # MOE_BENCH_START -def _block_until_ready_tree(tree): - leaves = jax.tree_util.tree_leaves(tree) - if leaves: def _block_until_ready_tree(tree): leaves = jax.tree_util.tree_leaves(tree) if leaves: jax.block_until_ready(leaves) + + +def _time_fwd_bwd(apply_fn, demo, *, warmup_iters=5, timing_iters=10): import time autocast_kwargs = {"enabled": False, "mesh_resource": demo.mesh_resource} diff --git a/docs/examples/jax/moe.rst b/docs/examples/jax/moe.rst index 43deeee7f7..7a74b93eeb 100644 --- a/docs/examples/jax/moe.rst +++ b/docs/examples/jax/moe.rst @@ -21,6 +21,19 @@ out of scope here. `<- Back to the JAX integration overview <../te_jax_integration.html>`_ +The forward path below summarizes the data flow for the native baseline and the +TE replacement. + +.. figure:: media/jax_moe_native_vs_te_flow.svg + :alt: Side-by-side forward data flow for native JAX and TransformerEngine JAX MoE blocks. + :align: center + :width: 100% + + Forward data flow for the tutorial's BF16 MoE block. TE keeps the same + sharded inputs and weights, but routes through TE fused router and grouped + GEMM primitives while keeping dispatch, expert compute, and combine inside + one MoE VJP. + 1. Baseline: native JAX BF16 EP MoE ----------------------------------- @@ -46,8 +59,9 @@ expert weights are sharded over ``ep``. The native baseline is exposed as a normal Flax module. Its implementation in ``moe_native.py`` performs softmax top-k routing, forward ``ragged_all_to_all`` over ``ep``, local source-major to expert-major chunk -reordering, three ``ragged_dot`` expert GEMMs, reverse ``ragged_all_to_all``, -and weighted token combine. +reordering, a concatenated ``wi_0|wi_1`` ``ragged_dot`` input projection, +activation, ``wo`` ``ragged_dot`` output projection, reverse +``ragged_all_to_all``, and weighted token combine. 2. TransformerEngine ``_MoEBlock`` ---------------------------------- @@ -113,27 +127,28 @@ Measured on four NVIDIA GB200 GPUs with the default tutorial shape :header: "Path", "Mean fwd+bwd time", "Relative time" :widths: 35, 25, 25 - "Native JAX BF16", "19.545 ms", "1.00x" - "TE ``_MoEBlock`` BF16", "13.632 ms", "0.70x" + "Native JAX BF16", "17.320 ms", "1.00x" + "TE ``_MoEBlock`` BF16", "13.601 ms", "0.79x" The same run reported ``max |native BF16 - TE BF16| = 0.0604`` for the forward correctness check. For this no-op-quantizer BF16 configuration, TE measured -``1.43x`` the native baseline throughput on this tutorial shape. +``1.27x`` the native baseline throughput on this tutorial shape. A larger-shape sweep with the same blocking timing loop found TE ahead for each -shape tried: +shape tried. The default shape appears in both tables; the values differ +slightly because the standalone tutorial run and sweep were timed separately. .. csv-table:: :header: "Batch", "Seq", "Hidden", "Intermediate", "Native BF16", "TE BF16", "TE speedup" :widths: 10, 10, 12, 16, 16, 16, 14 - "8", "1024", "1024", "4096", "9.173 ms", "7.377 ms", "1.24x" - "8", "2048", "1024", "4096", "19.545 ms", "13.632 ms", "1.43x" - "8", "4096", "1024", "4096", "39.179 ms", "33.570 ms", "1.17x" - "16", "2048", "1024", "4096", "39.211 ms", "33.595 ms", "1.17x" - "8", "1024", "2048", "8192", "19.313 ms", "14.846 ms", "1.30x" - "8", "2048", "2048", "8192", "42.629 ms", "32.657 ms", "1.31x" - "16", "2048", "2048", "8192", "86.957 ms", "68.643 ms", "1.27x" + "8", "1024", "1024", "4096", "8.369 ms", "7.346 ms", "1.14x" + "8", "2048", "1024", "4096", "17.413 ms", "13.554 ms", "1.28x" + "8", "4096", "1024", "4096", "34.809 ms", "32.878 ms", "1.06x" + "16", "2048", "1024", "4096", "35.102 ms", "32.773 ms", "1.07x" + "8", "1024", "2048", "8192", "19.656 ms", "14.566 ms", "1.35x" + "8", "2048", "2048", "8192", "38.630 ms", "32.057 ms", "1.21x" + "16", "2048", "2048", "8192", "85.549 ms", "66.793 ms", "1.28x" Across the sweep, the forward max-absolute difference stayed between ``0.0598`` and ``0.0704``. The result depends on token distribution, hidden diff --git a/docs/examples/jax/moe_native.py b/docs/examples/jax/moe_native.py index 0bd884f304..cf1f6db7d8 100644 --- a/docs/examples/jax/moe_native.py +++ b/docs/examples/jax/moe_native.py @@ -234,8 +234,9 @@ def _native_moe_local( source_major_order, expert_major_order, ) - hidden_0 = jax.lax.ragged_dot(x_expert_major, wi_0, local_group_sizes) - hidden_1 = jax.lax.ragged_dot(x_expert_major, wi_1, local_group_sizes) + wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) + hidden_combined = jax.lax.ragged_dot(x_expert_major, wi_combined, local_group_sizes) + hidden_0, hidden_1 = jnp.split(hidden_combined, 2, axis=-1) activated = jax.nn.silu(hidden_0) * hidden_1 expert_output = jax.lax.ragged_dot(activated, wo, local_group_sizes).astype(dtype) From d8def791576aec0718e8d8c9e2f131bf7cfb91a5 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 5 Jun 2026 16:04:21 -0700 Subject: [PATCH 28/29] Merge in TE EP and update tutorial accordingly Signed-off-by: Jeremy Berchtold --- .../media/jax_moe_native_vs_te_flow.drawio | 6 +- .../jax/media/jax_moe_native_vs_te_flow.svg | 6 +- docs/examples/jax/moe.out | 15 +- docs/examples/jax/moe.py | 173 ++++++++++++++++-- docs/examples/jax/moe.rst | 95 +++++----- docs/examples/jax/test_moe.py | 4 + .../jax/cpp_extensions/router.py | 2 +- transformer_engine/jax/moe.py | 76 ++++++++ 8 files changed, 307 insertions(+), 70 deletions(-) diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio index 9f9c9556d9..446fb340d8 100644 --- a/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio +++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.drawio @@ -9,7 +9,7 @@ - + @@ -21,9 +21,9 @@ - + - + diff --git a/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg index 586e008114..13b5d65183 100644 --- a/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg +++ b/docs/examples/jax/media/jax_moe_native_vs_te_flow.svg @@ -38,7 +38,7 @@ Native JAX BF16 EP MoE JAX router, ragged collectives, fused ragged_dot FFN TE _MoEBlock BF16 - TE router, dispatch/combine, grouped GEMM FFN + TE router, NCCL EP dispatch/combine, grouped GEMM FFN single TE MoE custom_vjp boundary @@ -76,7 +76,7 @@ Dispatch - token_dispatch, gather counts, ragged_all_to_all, local sort + tex.ep_dispatch via NCCL EP, TE handle state Expert FFN @@ -84,7 +84,7 @@ Combine - inverse local sort, reverse ragged_all_to_all, token combine + tex.ep_combine via NCCL EP, output reshard diff --git a/docs/examples/jax/moe.out b/docs/examples/jax/moe.out index e9dffd7099..e79009eedf 100644 --- a/docs/examples/jax/moe.out +++ b/docs/examples/jax/moe.out @@ -1,11 +1,12 @@ -# Numbers below were captured on 4x NVIDIA GB200. Regenerate with: -# python3 docs/examples/jax/moe.py > moe.out +# Numbers below were captured on 4x NVIDIA GB200. +# Native JAX BF16 uses the ragged A2A baseline in single-process 4-GPU mode. +# TE BF16 uses NCCL EP in 4-process mode with one GPU per process. # MOE_OUTPUT_START -max |native BF16 - TE BF16|: 0.0604 -native JAX BF16: -Mean time: 17.320 ms +native JAX BF16 ragged A2A: +Mean time: 17.085 ms -TE _MoEBlock BF16: -Mean time: 13.601 ms +TE _MoEBlock BF16 with NCCL EP: +TE _MoEBlock BF16 output: shape=(8, 2048, 1024), dtype=bfloat16 +Mean time: 3.156 ms # MOE_OUTPUT_END diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index e6ca9af30a..0e5fc9a2a2 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -10,15 +10,21 @@ Run as a script to exercise the example end-to-end: python docs/examples/jax/moe.py + python docs/examples/jax/moe.py --num-process=4 --process-id=0 -The example uses a 2x2 expert-parallel/FSDP mesh and therefore requires four -visible GPUs. Both the native baseline and TransformerEngine path run in BF16; -the current ``_MoEBlock`` wrapper uses no-op quantizer sets. +Launch one process for each ``process-id`` in ``[0, 4)``. + +The TransformerEngine path uses NCCL-backed EP and therefore requires a +multi-process launch with one GPU per process. Both the native baseline and +TransformerEngine path run in BF16; the current ``_MoEBlock`` wrapper uses +no-op quantizer sets. """ # MOE_IMPORTS_START from dataclasses import dataclass from typing import Any +import os +import sys import jax import jax.numpy as jnp @@ -66,7 +72,6 @@ class DemoState: def _ensure_writable_triton_cache(): - import os import tempfile os.environ.setdefault( @@ -75,7 +80,40 @@ def _ensure_writable_triton_cache(): ) +def _register_te_ffi_targets(): + _ensure_writable_triton_cache() + import transformer_engine.jax.cpp_extensions # noqa: F401 + + # MOE_MESH_SETUP_START +def _read_mp_options(): + num_process = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + process_id = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, arg in enumerate(sys.argv): + if arg.startswith("--num-process="): + num_process = int(arg.split("=", 1)[1]) + elif arg == "--num-process" and i + 1 < len(sys.argv): + num_process = int(sys.argv[i + 1]) + elif arg.startswith("--process-id="): + process_id = int(arg.split("=", 1)[1]) + elif arg == "--process-id" and i + 1 < len(sys.argv): + process_id = int(sys.argv[i + 1]) + return num_process, process_id + + +def maybe_initialize_distributed(): + num_process, process_id = _read_mp_options() + if num_process <= 1: + return + coordinator = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457") + jax.distributed.initialize( + coordinator_address=coordinator, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + + def build_ep_fsdp_mesh(): from transformer_engine.jax.sharding import MeshResource @@ -86,10 +124,10 @@ def build_ep_fsdp_mesh(): ) devices = mesh_utils.create_device_mesh( - (EP_SIZE, FSDP_SIZE), + (FSDP_SIZE, EP_SIZE), devices=jax.devices()[:required_devices], ) - mesh = Mesh(devices, axis_names=(EP_AXIS, FSDP_AXIS)) + mesh = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) mesh_resource = MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) return mesh, mesh_resource @@ -102,7 +140,6 @@ def build_models(mesh, *, hidden=HIDDEN, intermediate=INTERMEDIATE): _ensure_writable_triton_cache() from transformer_engine.jax.flax import _MoEBlock as TEMoEBlock - from transformer_engine.jax.moe import PermutationBackend native_model = NativeMoEBlock( mesh=mesh, @@ -118,7 +155,7 @@ def build_models(mesh, *, hidden=HIDDEN, intermediate=INTERMEDIATE): num_experts_per_tok=TOPK, intermediate_size=intermediate, data_parallelism_axes=(FSDP_AXIS,), - permutation_backend=PermutationBackend.PURE_JAX, + apply_topk_weights_early=True, dtype=DTYPE, ) return native_model, te_model @@ -137,7 +174,7 @@ def make_inputs(*, batch=BATCH, seq=SEQ, hidden=HIDDEN): def shard_inputs_and_variables(mesh, variables, x, dy): - input_sharding = NamedSharding(mesh, P((EP_AXIS, FSDP_AXIS), None, None)) + input_sharding = NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None)) gate_sharding = NamedSharding(mesh, P()) expert_sharding = NamedSharding(mesh, P(EP_AXIS, None, None)) @@ -158,6 +195,45 @@ def shard_inputs_and_variables(mesh, variables, x, dy): # MOE_INPUTS_SETUP_END +def _recv_capacity_per_rank(batch, seq): + num_procs = jax.process_count() + dp_size = num_procs // EP_SIZE + num_local_experts = NUM_EXPERTS // EP_SIZE + natural_recv_pr = (batch // dp_size) * seq * TOPK + slots_per_expert = (natural_recv_pr + num_local_experts - 1) // num_local_experts + return num_local_experts * slots_per_expert + + +def bootstrap_te_ep(mesh, mesh_resource, *, batch=BATCH, seq=SEQ, hidden=HIDDEN): + from transformer_engine.jax.ep import ep_bootstrap + from transformer_engine.jax.moe import record_ep_bootstrap_signature_for_moe + from transformer_engine.jax.sharding import global_shard_guard + + world_size = jax.process_count() + max_tokens_per_rank = (batch // world_size) * seq + recv_capacity_per_rank = _recv_capacity_per_rank(batch, seq) + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): + ep_bootstrap( + world_size=world_size, + rank=jax.process_index(), + ep_size=EP_SIZE, + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden, + allow_handle_mem_reloc=True, + max_token_dtype=DTYPE, + ) + record_ep_bootstrap_signature_for_moe( + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden, + ep_size=EP_SIZE, + ) + + def _te_apply(te_model): def apply_fn(variables, x, **kwargs): out, _ = te_model.apply(variables, x, **kwargs) @@ -170,6 +246,7 @@ def setup_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE from transformer_engine.jax.sharding import global_shard_guard mesh, mesh_resource = build_ep_fsdp_mesh() + bootstrap_te_ep(mesh, mesh_resource, batch=batch, seq=seq, hidden=hidden) native_model, te_model = build_models(mesh, hidden=hidden, intermediate=intermediate) k_init, x, dy = make_inputs(batch=batch, seq=seq, hidden=hidden) @@ -191,10 +268,34 @@ def setup_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE ) +def setup_te_demo(*, batch=BATCH, seq=SEQ, hidden=HIDDEN, intermediate=INTERMEDIATE): + from transformer_engine.jax.sharding import global_shard_guard + + mesh, mesh_resource = build_ep_fsdp_mesh() + bootstrap_te_ep(mesh, mesh_resource, batch=batch, seq=seq, hidden=hidden) + _, te_model = build_models(mesh, hidden=hidden, intermediate=intermediate) + k_init, x, dy = make_inputs(batch=batch, seq=seq, hidden=hidden) + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource), nn_partitioning.axis_rules( + LOGICAL_AXIS_RULES + ): + variables = jax.jit(te_model.init)(k_init, x) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)) + sharded = shard_inputs_and_variables(mesh, variables, x, dy) + return DemoState( + mesh=mesh, + mesh_resource=mesh_resource, + native_model=None, + te_model=te_model, + variables=sharded["variables"], + x=sharded["x"], + dy=sharded["dy"], + ) + + def te_moe_supported(): try: import importlib - import sys _ensure_writable_triton_cache() @@ -203,6 +304,16 @@ def te_moe_supported(): transformer_engine_jax = sys.modules["transformer_engine_jax"] flax_mod = importlib.import_module("transformer_engine.jax.flax") getattr(flax_mod, "_MoEBlock") + if jax.process_count() < EP_SIZE * FSDP_SIZE: + return False, ( + "TE EP requires a multi-process launch with one GPU per process; " + f"got process_count={jax.process_count()}" + ) + if jax.local_device_count() != 1: + return False, ( + "TE EP requires one local GPU per process; " + f"got local_device_count={jax.local_device_count()}" + ) if transformer_engine_jax.get_device_compute_capability(0) < 100: return False, "TE MoE grouped GEMM currently requires Blackwell (sm_100+)" except Exception as exc: # pylint: disable=broad-exception-caught @@ -210,7 +321,6 @@ def te_moe_supported(): return True, "" -# MOE_CORRECTNESS_START def compare_forward(demo): from transformer_engine.jax.sharding import global_shard_guard @@ -227,6 +337,21 @@ def compare_forward(demo): return native_out, te_out +# MOE_CORRECTNESS_START +def run_te_forward(demo): + from transformer_engine.jax.sharding import global_shard_guard + + te_apply = _te_apply(demo.te_model) + with jax.set_mesh(demo.mesh), global_shard_guard( + demo.mesh_resource + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + te_out = jax.jit(te_apply)(demo.variables, demo.x) + te_out.block_until_ready() + + print(f"TE _MoEBlock BF16 output: shape={te_out.shape}, dtype={te_out.dtype}") + return te_out + + # MOE_CORRECTNESS_END @@ -286,11 +411,31 @@ def run_benchmarks(demo, *, warmup_iters=5, timing_iters=10): print(f"Mean time: {te_ms:.3f} ms") return native_ms, te_ms +def run_te_benchmark(demo, *, warmup_iters=5, timing_iters=10): + from transformer_engine.jax.sharding import global_shard_guard + + te_apply = _te_apply(demo.te_model) + with jax.set_mesh(demo.mesh), global_shard_guard( + demo.mesh_resource + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + print("TE _MoEBlock BF16:") + te_ms = _time_fwd_bwd( + te_apply, + demo, + warmup_iters=warmup_iters, + timing_iters=timing_iters, + ) + print(f"Mean time: {te_ms:.3f} ms") + return te_ms + # MOE_BENCH_END def main(): + _register_te_ffi_targets() + maybe_initialize_distributed() + if len(jax.devices()) < EP_SIZE * FSDP_SIZE: print(f"[skipped: need {EP_SIZE * FSDP_SIZE} GPUs for EP=2/FSDP=2]") return @@ -300,9 +445,9 @@ def main(): print(f"[skipped TE comparison: {te_reason}]") return - demo = setup_demo() - compare_forward(demo) - run_benchmarks(demo) + demo = setup_te_demo() + run_te_forward(demo) + run_te_benchmark(demo) if __name__ == "__main__": diff --git a/docs/examples/jax/moe.rst b/docs/examples/jax/moe.rst index 7a74b93eeb..33ef33bbbd 100644 --- a/docs/examples/jax/moe.rst +++ b/docs/examples/jax/moe.rst @@ -15,9 +15,11 @@ block with TransformerEngine's experimental Flax ``_MoEBlock``. all-to-all setup lives in ``moe_native.py`` so the snippets below stay focused on model-level code. -**TransformerEngine path.** This tutorial uses ``_MoEBlock`` in BF16 with the -wrapper's current no-op quantizer sets. Quantized MoE recipes are intentionally -out of scope here. +**TransformerEngine path.** This tutorial uses ``_MoEBlock`` in BF16 with +NCCL-backed TE EP and the wrapper's current no-op quantizer sets. TE EP replaces +the tutorial's previous TE-side ragged A2A exchange with ``tex.ep_dispatch`` and +``tex.ep_combine`` over NCCL EP. Quantized MoE recipes are intentionally out of +scope here. `<- Back to the JAX integration overview <../te_jax_integration.html>`_ @@ -29,17 +31,19 @@ TE replacement. :align: center :width: 100% - Forward data flow for the tutorial's BF16 MoE block. TE keeps the same - sharded inputs and weights, but routes through TE fused router and grouped - GEMM primitives while keeping dispatch, expert compute, and combine inside - one MoE VJP. + Forward data flow for the tutorial's BF16 MoE block. The native baseline + keeps JAX ``ragged_all_to_all`` and ``ragged_dot``. TE keeps the same sharded + inputs and weights, but routes through TE fused router, NCCL EP + dispatch/combine, and grouped GEMM primitives while keeping dispatch, expert + compute, and combine inside one MoE VJP. 1. Baseline: native JAX BF16 EP MoE ----------------------------------- The example uses a 2x2 mesh: expert parallelism on ``ep`` and FSDP-style batch parallelism on ``fsdp``. The batch dimension is sharded over both axes, and -expert weights are sharded over ``ep``. +expert weights are sharded over ``ep``. TE EP requires ``ep`` to be the inner +axis and currently runs in multi-process mode with one GPU per process. .. literalinclude:: moe.py :language: python @@ -69,6 +73,10 @@ activation, ``wo`` ``ragged_dot`` output projection, reverse The TE replacement registers the same gate and expert parameter names as the baseline, then delegates routing, dispatch, grouped FFN, combine, expert-parallel collectives, and VJP to ``transformer_engine.jax.moe.moe``. +On this branch, the TE-side expert exchange is NCCL EP: ``_MoEBlock`` calls +``tex.ep_dispatch`` before the grouped FFNs and ``tex.ep_combine`` after them. +The native baseline remains unchanged and continues to use +``jax.lax.ragged_all_to_all`` for the comparison numbers. ``_MoEBlock`` is intentionally underscore-prefixed while the API stabilizes. Use it as an experimental integration point. @@ -83,30 +91,32 @@ it as an experimental integration point. :start-after: # MOE_INPUTS_SETUP_START :end-before: # MOE_INPUTS_SETUP_END -3. Correctness check +3. TE EP smoke check -------------------- -Both models use the same Flax variable dictionary, so the gate and expert -weights are identical. The comparison checks the BF16 forward result on the same -sharded input. +The direct script path initializes the TE EP communicator, creates the +``_MoEBlock`` variables, runs a BF16 forward pass, and reports the output shape +and dtype. .. literalinclude:: moe.py :language: python :start-after: # MOE_CORRECTNESS_START :end-before: # MOE_CORRECTNESS_END -The two paths may not be bit-identical because the router and grouped matmul -implementations differ, but they should stay within ordinary BF16 tolerance for -the default no-bias, softmax top-k, ``silu`` path. +The native ragged A2A baseline remains in ``moe_native.py`` and is used for the +baseline timings below. Because the native ragged A2A path runs in +single-process 4-GPU mode while TE EP runs in one-process-per-GPU mode, the +benchmark sweep times the two paths separately. 4. Performance comparison ------------------------- -``run_benchmarks`` runs a blocking JIT-compiled forward+backward loop with -warmup. The same sharded input, output gradient, and variables are used for -native BF16 and TE BF16. Even though quantization is disabled, the benchmark -passes the active ``MeshResource`` through TE's autocast context so -``_MoEBlock`` can resolve the ``ep`` axis. +``run_te_benchmark`` runs a blocking JIT-compiled forward+backward loop with +warmup. Even though quantization is disabled, the benchmark passes the active +``MeshResource`` through TE's autocast context so ``_MoEBlock`` can resolve the +``ep`` axis. The TE block folds top-k weights into the per-expert FFN +intermediate with ``apply_topk_weights_early=True``; this is mathematically +equivalent for the BF16 path because the down projection is linear. .. literalinclude:: moe.py :language: python @@ -117,7 +127,10 @@ Run the full example with: .. code-block:: bash - python docs/examples/jax/moe.py + for i in 0 1 2 3; do + python docs/examples/jax/moe.py --num-process=4 --process-id=$i > proc_$i.log 2>&1 & + done + wait Measured on four NVIDIA GB200 GPUs with the default tutorial shape ``batch=8``, ``seq=2048``, ``hidden=1024``, ``intermediate=4096``, @@ -127,34 +140,32 @@ Measured on four NVIDIA GB200 GPUs with the default tutorial shape :header: "Path", "Mean fwd+bwd time", "Relative time" :widths: 35, 25, 25 - "Native JAX BF16", "17.320 ms", "1.00x" - "TE ``_MoEBlock`` BF16", "13.601 ms", "0.79x" + "Native JAX BF16 ragged A2A", "17.085 ms", "1.00x" + "TE ``_MoEBlock`` BF16 with NCCL EP", "3.156 ms", "0.18x" -The same run reported ``max |native BF16 - TE BF16| = 0.0604`` for the forward -correctness check. For this no-op-quantizer BF16 configuration, TE measured -``1.27x`` the native baseline throughput on this tutorial shape. +For this no-op-quantizer BF16 configuration, TE EP measured ``5.41x`` the +native ragged A2A baseline throughput on this tutorial shape. -A larger-shape sweep with the same blocking timing loop found TE ahead for each -shape tried. The default shape appears in both tables; the values differ -slightly because the standalone tutorial run and sweep were timed separately. +A larger-shape sweep with the same blocking timing loop found TE EP ahead for +each shape tried. The native column uses the unchanged ragged A2A baseline; the +TE column uses NCCL EP. The default shape appears in both tables; the values +differ slightly because the standalone tutorial run and sweep were timed +separately. .. csv-table:: :header: "Batch", "Seq", "Hidden", "Intermediate", "Native BF16", "TE BF16", "TE speedup" :widths: 10, 10, 12, 16, 16, 16, 14 - "8", "1024", "1024", "4096", "8.369 ms", "7.346 ms", "1.14x" - "8", "2048", "1024", "4096", "17.413 ms", "13.554 ms", "1.28x" - "8", "4096", "1024", "4096", "34.809 ms", "32.878 ms", "1.06x" - "16", "2048", "1024", "4096", "35.102 ms", "32.773 ms", "1.07x" - "8", "1024", "2048", "8192", "19.656 ms", "14.566 ms", "1.35x" - "8", "2048", "2048", "8192", "38.630 ms", "32.057 ms", "1.21x" - "16", "2048", "2048", "8192", "85.549 ms", "66.793 ms", "1.28x" - -Across the sweep, the forward max-absolute difference stayed between -``0.0598`` and ``0.0704``. The result depends on token distribution, hidden -size, intermediate size, and the target stack. Keep -``PermutationBackend.PURE_JAX`` until correctness is established; then compare -it with ``PermutationBackend.TRITON`` separately. + "8", "1024", "1024", "4096", "8.543 ms", "2.075 ms", "4.12x" + "8", "2048", "1024", "4096", "17.085 ms", "3.217 ms", "5.31x" + "8", "4096", "1024", "4096", "38.811 ms", "5.349 ms", "7.26x" + "16", "2048", "1024", "4096", "39.194 ms", "5.355 ms", "7.32x" + "8", "1024", "2048", "8192", "19.329 ms", "4.110 ms", "4.70x" + "8", "2048", "2048", "8192", "42.505 ms", "6.254 ms", "6.80x" + "16", "2048", "2048", "8192", "88.134 ms", "10.542 ms", "8.36x" + +The result depends on token distribution, hidden size, intermediate size, and +the target stack. .. raw:: html diff --git a/docs/examples/jax/test_moe.py b/docs/examples/jax/test_moe.py index 0443a38c17..24840af85e 100644 --- a/docs/examples/jax/test_moe.py +++ b/docs/examples/jax/test_moe.py @@ -44,6 +44,10 @@ def _te_moe_available(): if transformer_engine_jax.get_device_compute_capability(0) < 100: return False, "TE MoE grouped GEMM requires Blackwell (sm_100+)" + if jax.process_count() < 4: + return False, "TE EP requires a multiprocess launch" + if jax.local_device_count() != 1: + return False, "TE EP requires one local GPU per process" except Exception as exc: # pylint: disable=broad-exception-caught return False, str(exc) return True, "" diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 3245439689..8cc94fcaaf 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -412,7 +412,7 @@ def partition( arg_infos, result_infos, ): - del result_infos, routing_map_format + del result_infos grad_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index a96a77b991..144fb79629 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -153,6 +153,7 @@ def _te_ep_assert_compatible_bootstrap( # ============================================================================= +@jax.tree_util.register_pytree_node_class @dataclass class _Ctx: """Residuals carried from the fwd rule into the bwd rule.""" @@ -180,6 +181,81 @@ class _Ctx: aux_tokens_per_expert: Any = None aux_saved_scores: Any = None + def tree_flatten(self): + children = ( + self.x, + self.gate_kernel, + self.expert_bias, + self.logits_2d, + self.saved_scores, + self.routing_map, + self.handle_mem, + self.token_counts, + self.recv_topk_weights, + self.casted_sorted_x_lhs_trans, + self.casted_wi_rhs_trans, + self.gate_proj_out, + self.up_proj_out, + self.casted_intermediate_lhs_trans, + self.casted_wo_rhs_trans, + self.expert_outputs, + self.local_group_sizes, + self.aux_const_buf, + self.aux_tokens_per_expert, + self.aux_saved_scores, + ) + aux_data = (self.handle,) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + (handle,) = aux_data + ( + x, + gate_kernel, + expert_bias, + logits_2d, + saved_scores, + routing_map, + handle_mem, + token_counts, + recv_topk_weights, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + expert_outputs, + local_group_sizes, + aux_const_buf, + aux_tokens_per_expert, + aux_saved_scores, + ) = children + return cls( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + handle=handle, + handle_mem=handle_mem, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, + ) + # ============================================================================= # Per-shard FFN body (runs inside shard_map) From 82826a9cb9b04c64296a7bb96b4f24ab61a92a0c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jun 2026 23:05:33 +0000 Subject: [PATCH 29/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/jax/moe.py | 19 +++-- tests/jax/test_multi_process_ep.py | 41 +++++++--- tests/jax/test_te_ep_moe.py | 39 +++------- transformer_engine/jax/cpp_extensions/ep.py | 7 +- transformer_engine/jax/ep.py | 30 ++++++-- transformer_engine/jax/moe.py | 84 +++++++++------------ 6 files changed, 114 insertions(+), 106 deletions(-) diff --git a/docs/examples/jax/moe.py b/docs/examples/jax/moe.py index 0e5fc9a2a2..b93080e0ec 100644 --- a/docs/examples/jax/moe.py +++ b/docs/examples/jax/moe.py @@ -305,14 +305,20 @@ def te_moe_supported(): flax_mod = importlib.import_module("transformer_engine.jax.flax") getattr(flax_mod, "_MoEBlock") if jax.process_count() < EP_SIZE * FSDP_SIZE: - return False, ( - "TE EP requires a multi-process launch with one GPU per process; " - f"got process_count={jax.process_count()}" + return ( + False, + ( + "TE EP requires a multi-process launch with one GPU per process; " + f"got process_count={jax.process_count()}" + ), ) if jax.local_device_count() != 1: - return False, ( - "TE EP requires one local GPU per process; " - f"got local_device_count={jax.local_device_count()}" + return ( + False, + ( + "TE EP requires one local GPU per process; " + f"got local_device_count={jax.local_device_count()}" + ), ) if transformer_engine_jax.get_device_compute_capability(0) < 100: return False, "TE MoE grouped GEMM currently requires Blackwell (sm_100+)" @@ -411,6 +417,7 @@ def run_benchmarks(demo, *, warmup_iters=5, timing_iters=10): print(f"Mean time: {te_ms:.3f} ms") return native_ms, te_ms + def run_te_benchmark(demo, *, warmup_iters=5, timing_iters=10): from transformer_engine.jax.sharding import global_shard_guard diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index abdbcd32ec..1472ab49fe 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -245,11 +245,13 @@ def test_two_layer_dispatch_no_handle_aliasing(self): w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) def one_layer(hk, idx, toks, w_): - recv_t, recv_w, hm, tc = ep_dispatch( - hk, idx, toks, w_, self.recv_capacity_per_rank + recv_t, recv_w, hm, tc = ep_dispatch(hk, idx, toks, w_, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) ) - recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) - recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) return ep_combine( hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) ) @@ -269,12 +271,14 @@ def run(idx, ta_, tb_, w_): np.testing.assert_allclose( np.asarray(out_a_g.astype(jnp.float32)), np.asarray(tokens.astype(jnp.float32)), - atol=5e-2, rtol=5e-2, + atol=5e-2, + rtol=5e-2, ) np.testing.assert_allclose( np.asarray(out_b_g.astype(jnp.float32)), np.asarray(tokens_b.astype(jnp.float32)), - atol=5e-2, rtol=5e-2, + atol=5e-2, + rtol=5e-2, ) def test_primitive_prepare(self): @@ -328,7 +332,10 @@ def run(idx, toks, w): weighted, NamedSharding(self.mesh, ep_spec_3d) ) out = ep_combine_fwd( - self.hk, hm, weighted, T_global, + self.hk, + hm, + weighted, + T_global, out_partition_spec=(("dp", "ep"), None), ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -372,7 +379,9 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -420,7 +429,9 @@ def test_dispatch_combine_3d_input_output(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( @@ -463,7 +474,9 @@ def test_dispatch_combine_dp_only_first_dim(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( @@ -641,7 +654,9 @@ def run(idx, toks, w): idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -688,7 +703,9 @@ def fwd(eo, toks, idx, w): w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) - combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + combined = ep_combine( + self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None) + ) return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index d56dad2c92..febae4bd2e 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -113,8 +113,7 @@ def _read_mp_options(): if not _MP_ACTIVE: pytest.skip( - "test_te_ep_moe.py requires the multiprocess launcher " - "(run_te_ep_moe.sh). Skipping.", + "test_te_ep_moe.py requires the multiprocess launcher (run_te_ep_moe.sh). Skipping.", allow_module_level=True, ) @@ -229,9 +228,7 @@ def mesh(): # Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather # and cannot run from inside jax.jit. Sized to the worst-case recv_pr # across _CONFIGS so every parametrized config is bootstrap-compatible. - with mesh_obj, global_shard_guard( - MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) - ): + with mesh_obj, global_shard_guard(MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)): ep_bootstrap( world_size=num_procs, rank=jax.process_index(), @@ -323,9 +320,7 @@ def _pure_jax_moe_reference( raise ValueError(f"Unsupported score_function={score_function!r}") routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32) - routing_weights_full = routing_weights_full.at[ - jnp.arange(T)[:, None], top_indices - ].set(weights) + routing_weights_full = routing_weights_full.at[jnp.arange(T)[:, None], top_indices].set(weights) # FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't # change the math (wo is linear), so the reference is identical for @@ -335,9 +330,7 @@ def _pure_jax_moe_reference( intermediate = jax.nn.silu(layer_w0.astype(jnp.float32)) * layer_w1.astype(jnp.float32) intermediate = intermediate.astype(x.dtype) expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H] - output_2d = jnp.einsum( - "te,teh->th", routing_weights_full.astype(x.dtype), expert_out - ) + output_2d = jnp.einsum("te,teh->th", routing_weights_full.astype(x.dtype), expert_out) output = output_2d.reshape(B, S, H).astype(x.dtype) if aux_loss_coeff > 0.0: @@ -352,9 +345,7 @@ def _pure_jax_moe_reference( else: # sigmoid aux_scores = jax.nn.sigmoid(logits) if K > 1: - aux_scores = aux_scores / ( - aux_scores.sum(axis=-1, keepdims=True) + 1e-20 - ) + aux_scores = aux_scores / (aux_scores.sum(axis=-1, keepdims=True) + 1e-20) routing_map = (routing_weights_full > 0).astype(jnp.int32) tokens_per_expert = jnp.sum(routing_map, axis=0) # [E] sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E] @@ -562,9 +553,7 @@ def _reference_kwargs_from_config(config, params_np): return dict( score_function=config.get("score_function", "softmax"), expert_bias=( - jnp.asarray(params_np["expert_bias"]) - if config.get("use_expert_bias", False) - else None + jnp.asarray(params_np["expert_bias"]) if config.get("use_expert_bias", False) else None ), ) @@ -715,9 +704,7 @@ def test_aux_loss(self, mesh): # wired. aux_grads = _grad_aux_only(block, variables, mesh, x) g_gate = np.asarray( - jax.device_get( - _unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0) - ) + jax.device_get(_unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0)) ) assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss" assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel" @@ -730,9 +717,7 @@ def test_combined_loss_grads(self, mesh): variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23)) grads = _grad_step(block, variables, mesh, x, include_aux=True) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g_local = np.asarray( - jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) - ) + g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0))) assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux" assert np.any(g_local != 0.0), f"{name} grad zero under main+aux" @@ -774,9 +759,7 @@ def test_init_apply_parity(self, mesh): grads = _grad_step(block, variables, mesh, x) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g_local = np.asarray( - jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) - ) + g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0))) assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf" assert np.any(g_local != 0.0), f"{name} grad zero" @@ -796,9 +779,7 @@ def test_bootstrap_signature_mismatch_raises(self, mesh): # Different hidden dim → different bootstrap signature. bigger_hidden = HIDDEN * 2 - x_b = jax.random.normal( - jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE - ) + x_b = jax.random.normal(jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE) block_b = MoEBlock( num_experts=NUM_EXPERTS, num_experts_per_tok=TOPK, diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 8fb0d90f8a..7f8a05dbb8 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -924,7 +924,12 @@ def ep_combine_fwd(handle, handle_mem, expert_out, num_local_tokens, out_partiti @compute_on("gpu_stream:collective") def ep_dispatch_bwd( - handle, handle_mem, grad, g_recv_topk_weights, top_k, num_local_tokens, + handle, + handle_mem, + grad, + g_recv_topk_weights, + top_k, + num_local_tokens, out_partition_spec=None, ): """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 17d00bef87..71cfdbb246 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -49,8 +49,7 @@ def _allgather_uid(uid_arr, world_size, uid_size): devices = np.asarray(jax.devices()) if devices.size != world_size: raise RuntimeError( - f"_allgather_uid fallback expected {world_size} global devices," - f" got {devices.size}." + f"_allgather_uid fallback expected {world_size} global devices, got {devices.size}." ) mesh = jax.sharding.Mesh(devices, ("_uid_all",)) sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) @@ -258,8 +257,13 @@ def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs): @partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6)) def ep_combine( - handle, handle_mem, token_counts, expert_out, recv_topk_weights, - num_local_tokens, out_sharding=None, + handle, + handle_mem, + token_counts, + expert_out, + recv_topk_weights, + num_local_tokens, + out_sharding=None, ): """Reduce weighted expert outputs back to source ranks. @@ -281,8 +285,13 @@ def ep_combine( ``[..., H]`` combined output shaped per ``num_local_tokens``. """ return _combine_fwd( - handle, handle_mem, token_counts, expert_out, recv_topk_weights, - num_local_tokens, out_sharding, + handle, + handle_mem, + token_counts, + expert_out, + recv_topk_weights, + num_local_tokens, + out_sharding, )[0] @@ -292,8 +301,13 @@ def _make_valid_mask(recv_topk_weights, dtype): def _combine_fwd( - handle, handle_mem, token_counts, expert_out, recv_topk_weights, - num_local_tokens, out_sharding, + handle, + handle_mem, + token_counts, + expert_out, + recv_topk_weights, + num_local_tokens, + out_sharding, ): del token_counts w = recv_topk_weights[..., None] diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 144fb79629..b71c8ccf1b 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -295,9 +295,7 @@ def _ffn_fwd_per_shard( wo = wo.astype(sorted_x.dtype) wi_combined = jnp.stack([wi_0, wi_1], axis=-2) - wi_combined_bias = ( - jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None - ) + wi_combined_bias = jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None q_set = noop_quantizer_set casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1) @@ -588,9 +586,7 @@ def _moe_fwd_rule( # single all-gather over (*dp, ep) and lives off the dispatch # critical path. if aux_loss_coeff > 0.0: - global_logits_2d = jax.lax.with_sharding_constraint( - logits_2d, NamedSharding(mesh, P()) - ) + global_logits_2d = jax.lax.with_sharding_constraint(logits_2d, NamedSharding(mesh, P())) _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( global_logits_2d, topk=K, @@ -641,17 +637,11 @@ def _moe_fwd_rule( # each rank see B/ep rows (not B/num_procs) and overrun the bootstrap-sized # send buffer. Pin both routing tensors to the (outer, ep) leading sharding # so per-rank token counts match max_tokens_per_rank. - topk_idx_3d = jax.lax.with_sharding_constraint( - topk_idx_3d, NamedSharding(mesh, ep3_spec) - ) - topk_w_3d = jax.lax.with_sharding_constraint( - topk_w_3d, NamedSharding(mesh, ep3_spec) - ) + topk_idx_3d = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(mesh, ep3_spec)) + topk_w_3d = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(mesh, ep3_spec)) # ---------------- TE EP dispatch (global view) ---------------- - handle = _get_or_make_ep_handle( - top_k=K, dispatch_output_per_expert_alignment=slots_per_expert - ) + handle = _get_or_make_ep_handle(top_k=K, dispatch_output_per_expert_alignment=slots_per_expert) token_counts, handle_mem = tex.ep_prepare(topk_idx_3d, handle) recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( handle, handle_mem, topk_idx_3d, x, topk_w_3d, recv_pr @@ -674,13 +664,13 @@ def _moe_fwd_rule( # FFN residuals live entirely on the local ep rank, so the leading # "experts" / "rows" dims map to P() (already shard-local). residuals_spec = ( - P(), # casted_sorted_x_lhs_trans - P(ep_axis, None, None), # casted_wi_rhs_trans - P(), # gate_proj_out - P(), # up_proj_out - P(), # casted_intermediate_lhs_trans - P(ep_axis, None, None), # casted_wo_rhs_trans - P(), # local_group_sizes + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + P(), # local_group_sizes ) out_specs = (ep3_spec, residuals_spec) @@ -712,9 +702,7 @@ def _body(*args): out_specs=out_specs, check_rep=False, )(*ffn_in_args) - expert_outputs = jax.lax.with_sharding_constraint( - expert_outputs, NamedSharding(mesh, ep3_spec) - ) + expert_outputs = jax.lax.with_sharding_constraint(expert_outputs, NamedSharding(mesh, ep3_spec)) # ---------------- TE EP combine (global view) ---------------- out_partition_spec = (batch_pspec_axis, None, None) @@ -859,15 +847,15 @@ def _moe_bwd_rule( bias_spec = P(ep_axis, None) if has_bias else None bwd_in_specs = ( - ep3_spec, # d_expert_outputs - P(), # casted_sorted_x_lhs_trans + ep3_spec, # d_expert_outputs + P(), # casted_sorted_x_lhs_trans P(ep_axis, None, None), # casted_wi_rhs_trans - P(), # gate_proj_out - P(), # up_proj_out - P(), # casted_intermediate_lhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans P(ep_axis, None, None), # casted_wo_rhs_trans - P(), # local_group_sizes - ep2_spec, # recv_topk_weights + P(), # local_group_sizes + ep2_spec, # recv_topk_weights ) bwd_in_args = [ d_expert_outputs, @@ -881,14 +869,14 @@ def _moe_bwd_rule( ctx.recv_topk_weights, ] bwd_out_specs = ( - ep3_spec, # d_sorted_x - ep2_spec, # d_recv_w_from_intermediate - kernel_spec, # d_wi_0 - kernel_spec, # d_wi_1 - kernel_spec, # d_wo - bias_spec if has_bias else None, # d_wi_0_bias - bias_spec if has_bias else None, # d_wi_1_bias - bias_spec if has_bias else None, # d_wo_bias + ep3_spec, # d_sorted_x + ep2_spec, # d_recv_w_from_intermediate + kernel_spec, # d_wi_0 + kernel_spec, # d_wi_1 + kernel_spec, # d_wo + bias_spec if has_bias else None, # d_wi_0_bias + bias_spec if has_bias else None, # d_wi_1_bias + bias_spec if has_bias else None, # d_wo_bias ) def _bwd_body(*args): @@ -945,15 +933,15 @@ def _bwd_body(*args): in_specs=bwd_in_specs, out_specs=bwd_out_specs, check_rep=False, - )(*bwd_in_args) + )( + *bwd_in_args + ) d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate # ---------------- Dispatch bwd (global view) ---------------- d_sorted_x = jax.lax.with_sharding_constraint(d_sorted_x, NamedSharding(mesh, ep3_spec)) - d_recv_w_total = jax.lax.with_sharding_constraint( - d_recv_w_total, NamedSharding(mesh, ep2_spec) - ) + d_recv_w_total = jax.lax.with_sharding_constraint(d_recv_w_total, NamedSharding(mesh, ep2_spec)) d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( ctx.handle, ctx.handle_mem, @@ -1001,9 +989,7 @@ def _bwd_body(*args): ) # routing_map is ignored by the kernel when compute_aux_scores=True, # so pass a zero placeholder of the right shape/dtype. - zero_routing_map = jnp.zeros( - ctx.aux_saved_scores.shape, dtype=ctx.routing_map.dtype - ) + zero_routing_map = jnp.zeros(ctx.aux_saved_scores.shape, dtype=ctx.routing_map.dtype) d_logits_aux = tex.fused_topk_with_score_function_bwd( zero_routing_map, ctx.aux_saved_scores, @@ -1190,9 +1176,7 @@ def moe( mesh = _get_mesh() if mesh is None or mesh.empty: raise ValueError("moe(...) requires an active jax.sharding.Mesh.") - expected_leading: Any = ( - (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis - ) + expected_leading: Any = (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis expected_spec = P(expected_leading, None, None) actual_spec = getattr(getattr(x, "sharding", None), "spec", None) if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec):