
Add AutoEP #7938

Draft
tohtana wants to merge 11 commits into deepspeedai:master from tohtana:tohtana/add_autoep

Conversation

@tohtana (Collaborator) commented Mar 31, 2026

This PR adds AutoEP (Automatic Expert Parallelism) to DeepSpeed training for HuggingFace MoE models.

AutoEP detects MoE blocks during deepspeed.initialize(), builds the required EP/EDP process groups, and replaces supported MoE blocks with an EP-enabled execution path, so expert parallelism can be enabled with DeepSpeed config only and without model code changes.
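As a sketch of what config-only enablement could look like: the `"autoep"` section, its field names, and the preset value below are assumptions for illustration, not the final API; only `autoep_size` is mentioned in this PR's description.

```python
# Hypothetical DeepSpeed config sketch for enabling AutoEP without model
# code changes. The "autoep" section and its field names are assumptions;
# only "autoep_size" appears in this PR's description.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 2},  # ZeRO stages 0-2 are in scope for this PR
    "autoep": {
        "enabled": True,
        "autoep_size": 8,      # ranks in each expert-parallel group
        "preset": "mixtral",   # one of the built-in model-family presets
    },
}

# model would be an unmodified HuggingFace MoE model:
# import deepspeed
# engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
```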

Current scope in this PR is the base AutoEP feature:

  • ZeRO stages 0, 1, and 2 support
  • checkpoint save/load support
  • universal checkpoint conversion support

ZeRO-3 support is intentionally left as follow-up work (#7928 should be merged before that follow-up can proceed).

Supported presets in this PR:

  • Mixtral
  • Qwen3-MoE
  • DeepSeek-V2
  • DeepSeek-V3
  • LLaMA-4

For end-to-end benchmarking and testing, an AutoEP example is available in DeepSpeedExamples.

Attribution

This implementation substantially builds on TorchTitan's MoE / expert-parallel implementation, and we want to explicitly acknowledge that prior work.

The TorchTitan-derived pieces in this PR are primarily:

  • deepspeed/moe/ep_router.py: adapted from TorchTitan's TokenChoiceTopKRouter
  • deepspeed/moe/ep_experts.py: adapted from TorchTitan's GroupedExperts and grouped-GEMM expert execution path
  • deepspeed/moe/ep_kernels.py: adapted from TorchTitan's TokenReorderer, generate_permute_indices, Triton fill-indices kernel, and token-group alignment / padding helpers
  • deepspeed/module_inject/auto_ep_layer.py: adapts the same router -> reorder -> dispatch -> local expert compute -> combine structure used in TorchTitan's MoE / EP flow

Relevant TorchTitan sources:

The DeepSpeed-specific work in this PR is the AutoEP integration layer around those building blocks:

  • HuggingFace MoE detection and structural validation
  • model-family presets and custom-config path
  • weight repacking from HF expert layouts into grouped expert tensors
  • DeepSpeed runtime group setup and module replacement
  • DeepSpeed checkpoint save/load and universal checkpoint support
  • DeepSpeed docs and tests
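The weight-repacking step can be illustrated with a small sketch: per-expert HF projection matrices are stacked into a single grouped tensor so one grouped GEMM can cover all local experts. The shapes and the `w1` naming here are assumptions for illustration, not DeepSpeed's actual layout.

```python
# Illustrative sketch of the repacking idea in deepspeed/moe/ep_repack.py:
# stack per-expert HF weight matrices into one grouped tensor for batched
# expert compute. Shapes and names (w1 etc.) are assumptions.
import torch

num_experts, d_model, d_ff = 4, 8, 16

# HF layout: a list of per-expert projection weights, e.g. experts[i].w1.weight
hf_expert_w1 = [torch.randn(d_ff, d_model) for _ in range(num_experts)]

# Grouped layout: one [num_experts, d_ff, d_model] tensor for grouped GEMM
grouped_w1 = torch.stack(hf_expert_w1, dim=0)
assert grouped_w1.shape == (num_experts, d_ff, d_model)
```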

Design

The implementation is split into a few layers:

  • deepspeed/module_inject/auto_ep_config.py

    • user config parsing
    • built-in model presets
    • validation for EP topology and per-model constraints
  • deepspeed/module_inject/auto_ep.py

    • scans the model for MoE blocks
    • validates the detected structure
    • builds a MoELayerSpec for each supported MoE layer
    • replaces the original HF block with AutoEPMoELayer
  • deepspeed/module_inject/auto_ep_layer.py

    • the drop-in execution wrapper for a detected MoE block
    • implements router execution, token reorder, EP dispatch/combine, local expert compute, and shared-expert merge
  • deepspeed/moe/ep_router.py, deepspeed/moe/ep_experts.py, deepspeed/moe/ep_kernels.py

    • reusable MoE runtime pieces for routing, grouped expert compute, token permutation, and aligned grouped-GEMM execution
  • deepspeed/moe/ep_repack.py

    • converts HF expert weights into the grouped expert layout expected by the runtime
  • deepspeed/runtime/engine.py and checkpoint conversion code

    • wires AutoEP into deepspeed.initialize()
    • handles checkpoint save/load metadata and universal checkpoint integration

At runtime, the execution path is:

  1. detect and replace supported HF MoE blocks during initialization
  2. route tokens with the EP router
  3. reorder tokens by expert assignment
  4. perform all-to-all dispatch across the EP group when autoep_size > 1
  5. run local grouped expert compute
  6. all-to-all combine and restore the original token order
  7. merge shared experts if the model has them
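Steps 2, 3, and 6 above can be sketched in a single process (omitting the all-to-all in steps 4 and 6) with a stable argsort and its inverse; top-1 routing and the tensor names are simplifications for illustration.

```python
# Single-process sketch of route -> reorder -> restore (steps 2, 3, and 6).
# Top-1 routing is used for simplicity; the real path routes top-k and runs
# all-to-all dispatch/combine between the reorder and restore steps.
import torch

num_tokens, num_experts = 6, 4
logits = torch.randn(num_tokens, num_experts)
expert_ids = logits.argmax(dim=-1)               # step 2: pick an expert per token

perm = torch.argsort(expert_ids, stable=True)    # step 3: group tokens by expert
token_ids = torch.arange(num_tokens)
reordered = token_ids[perm]                      # tokens now contiguous per expert

inverse = torch.empty_like(perm)
inverse[perm] = torch.arange(num_tokens)         # step 6: invert the permutation
restored = reordered[inverse]
assert torch.equal(restored, token_ids)
```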

Adding new model support

There are two supported ways to extend AutoEP to a new MoE model family.

  1. Add a preset in PRESET_MODELS.
     This is the preferred path for a model family we want to support out of the box. A preset defines:
     • MoE layer pattern
     • router child name
     • experts child name
     • expert weight names / layout
     • num_experts and top_k config attributes
     • routing defaults
     • optional shared-expert structure
  2. Use the custom config path.
     For models that are not yet built into DeepSpeed, AutoEP can be driven from config with:
     • moe_layer_pattern
     • router_pattern
     • expert_pattern
     • expert_w1, expert_w2, expert_w3
     • num_experts_attr
     • top_k_attr
     • optional shared-expert fields

Once detection can produce a valid MoELayerSpec, the replacement, execution, and checkpoint paths are shared.
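Using the fields listed above, a custom-config entry could look like the following sketch. The pattern strings and attribute values are illustrative, borrowed from common HF MoE layouts, not a tested configuration.

```python
# Hypothetical custom AutoEP config for a model family without a preset.
# Field names come from the PR description; all values are illustrative.
custom_autoep_config = {
    "moe_layer_pattern": r"model\.layers\.\d+\.mlp",  # which modules are MoE blocks
    "router_pattern": "gate",                         # router child name
    "expert_pattern": "experts",                      # experts container child name
    "expert_w1": "gate_proj.weight",
    "expert_w2": "down_proj.weight",
    "expert_w3": "up_proj.weight",
    "num_experts_attr": "num_local_experts",
    "top_k_attr": "num_experts_per_tok",
}
```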

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
@tohtana (Collaborator, Author) commented Mar 31, 2026

This feature is still experimental. The next steps are:

We welcome help testing and validating this on large-scale models.

@jiosephlee commented Mar 31, 2026

@tohtana I wish I could be of help, but I haven't written code at this level. Could you clarify what you mean by a preset for gpt-oss? If there is other good-first-issue work I could take on, I would gladly look into it.

@tohtana (Collaborator, Author) commented Mar 31, 2026

Hi @jiosephlee,
Thank you for offering to help! It would be great if you could try gpt-oss once it is supported by this AutoEP work.

tohtana added 4 commits March 31, 2026 18:33
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>