
Fix NPU Qwen3.5 grpo bugs #9589

Open
addsubmuldiv wants to merge 9 commits into modelscope:main from addsubmuldiv:qwen35_grpo_adapt

Conversation

@addsubmuldiv

@addsubmuldiv addsubmuldiv commented Jun 17, 2026

Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR enables Qwen3.5 / Qwen3.5-MoE GRPO on NPU, covering both colocate and server vLLM rollout paths, and documents the known limitation of MoE fused expert LoRA under PEFT + Transformers 5.

The changes are grouped into four areas:

1. Qwen3.5 colocate / basic compatibility
2. vLLM-Ascend MoE weight-sync layout compatibility
3. vLLM server-mode weight synchronization compatibility
4. Documentation for PEFT MoE fused expert LoRA limitations

This PR only contains framework code changes and does not include local smoke scripts.

1. Qwen3.5 RoPE config compatibility with vLLM 0.18 + Transformers 5

Files involved:

  • swift/infer_engine/vllm_engine.py

Problem

Qwen3.5 depends on the config / RoPE parameter handling logic in Transformers 5. In the Qwen3.5 config path of vLLM 0.18.x, ignore_keys_at_rope_validation may still be passed in as a list; however, the Transformers 5 side treats it as a set and performs set operations on it, which leads to a type error.

Fix

When creating the vLLM engine, this PR temporarily patches PretrainedConfig.convert_rope_params_to_dict() so that it only converts ignore_keys_at_rope_validation from list to set when needed.
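
The failure mode and the scoped patch can be illustrated without Transformers installed. The sketch below is only an approximation: `FakeConfig` is a hypothetical stand-in whose method unions its argument with a set, which is how the description above says the Transformers 5 side behaves. Only the names `convert_rope_params_to_dict` and `ignore_keys_at_rope_validation` come from this PR; the key strings are arbitrary placeholders.

```python
from contextlib import contextmanager


class FakeConfig:
    """Hypothetical stand-in for PretrainedConfig (not the real class)."""

    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
        # Assumed behaviour: the callee unions the argument with a set,
        # which raises TypeError when a non-empty list is passed.
        default_ignored = {"partial_rotary_factor"}  # placeholder key
        return default_ignored | (ignore_keys_at_rope_validation or set())


@contextmanager
def patch_rope_ignore_keys(config_cls):
    """Temporarily accept list-valued ignore keys, then restore the original."""
    original = config_cls.convert_rope_params_to_dict

    def patched(self, ignore_keys_at_rope_validation=None, **kwargs):
        # Convert only when needed; sets and None pass through untouched.
        if isinstance(ignore_keys_at_rope_validation, list):
            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
        return original(self, ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, **kwargs)

    config_cls.convert_rope_params_to_dict = patched
    try:
        yield
    finally:
        config_cls.convert_rope_params_to_dict = original


def demo():
    cfg = FakeConfig()
    keys = ["mrope_section"]  # placeholder key passed as a list
    try:
        cfg.convert_rope_params_to_dict(ignore_keys_at_rope_validation=keys)
        failed_unpatched = False
    except TypeError:
        failed_unpatched = True
    with patch_rope_ignore_keys(FakeConfig):
        merged = cfg.convert_rope_params_to_dict(ignore_keys_at_rope_validation=keys)
    restored = FakeConfig.convert_rope_params_to_dict.__name__ == "convert_rope_params_to_dict"
    return failed_unpatched, merged, restored
```

Keeping the patch inside a context manager means the original method is restored even if engine creation raises, so the change does not leak into later config loading.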

Upstream status

2. MoE weight synchronization layout compatibility for vLLM-Ascend

Files involved:

  • swift/model/npu_patch/vllm_ascend.py
  • swift/model/npu_patch/vllm_ascend_moe.py
  • swift/rlhf_trainers/rollout_mixin.py
  • swift/rlhf_trainers/utils.py

Problem

During GRPO training, training-side weights need to be synchronized to the vLLM rollout side. For Qwen MoE on vLLM-Ascend, MoE expert weights have backend-specific layouts:

  • training-side weights may come from HF / FSDP2 / Megatron layouts;
  • vLLM-Ascend may convert MoE weights into a runtime layout suitable for NPU grouped matmul;
  • FSDP2 can expose Qwen MoE expert weights as fused gate_up_proj / down_proj tensors;
  • directly writing these tensors into vLLM-Ascend w13_weight / w2_weight can leave the runtime weight layout in the wrong direction.

One observed failure was the first rollout hitting npu_grouped_matmul with a hidden-size mismatch, where the activation K dimension was hidden_size but the synced w13_weight dimension still matched the TP-sharded expert intermediate dimension.
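
The mismatch is easier to see with concrete shapes. This is a minimal sketch, assuming a grouped matmul that contracts the activation's last dimension against dimension 1 of a per-expert weight shaped `[experts, K, N]`. The helper `check_grouped_matmul` and all sizes are made up for illustration; it is not the `npu_grouped_matmul` API.

```python
def check_grouped_matmul(activation_shape, weight_shape):
    """Return the output shape, or raise if the contraction (K) dims disagree."""
    tokens, k_act = activation_shape
    _experts, k_weight, n_out = weight_shape
    if k_act != k_weight:
        raise ValueError(f"K mismatch: activation K={k_act}, weight K={k_weight}")
    return (tokens, n_out)


# Illustrative sizes only.
hidden_size = 2048
intermediate_per_tp = 192
local_experts = 8

activation = (16, hidden_size)
# Processed runtime layout: K is the hidden size.
processed_w13 = (local_experts, hidden_size, 2 * intermediate_per_tp)
# Wrong direction: K is the TP-sharded intermediate dimension.
unprocessed_w13 = (local_experts, 2 * intermediate_per_tp, hidden_size)

ok_shape = check_grouped_matmul(activation, processed_w13)
try:
    check_grouped_matmul(activation, unprocessed_w13)
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)
```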

Fix

  • Add an NPU-only patch for the vLLM-Ascend MoE weight loader.
  • Treat both qwen3_moe and qwen3_5_moe as the Qwen MoE sync family.
  • Expand FSDP2 fused Qwen MoE expert names into vLLM checkpoint-style gate_proj / up_proj / down_proj names before calling the vLLM loader.
  • For FSDP2 colocate runtime sync, write the processed runtime layout directly:
    • w13_weight: [local_experts, hidden, 2 * intermediate_per_tp]
    • w2_weight: [local_experts, intermediate_per_tp, hidden]
  • Skip process_weights_after_loading() only for this FSDP2 Qwen MoE colocate runtime-sync path, because calling post-load again would transpose the just-synced runtime weights back to the wrong direction.
  • Keep server full-weight reload on the preprocessed-layout path and still run post-load processing after reload.
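
The name expansion and the target runtime shapes listed above can be sketched as follows. Both helpers are hypothetical and are not the functions added by this PR; the per-expert name pattern `experts.{i}.gate_proj.weight` is an assumption about the checkpoint-style names the vLLM loader expects, while the `w13_weight` / `w2_weight` shapes are copied from the list above.

```python
def expand_fused_expert_names(fused_name, num_experts):
    """Map one fused expert parameter name to per-expert checkpoint-style names."""
    prefix, _, leaf = fused_name.rpartition(".")
    if not prefix.endswith("experts"):
        return [fused_name]  # not under an experts container; pass through
    if leaf == "gate_up_proj":
        parts = ("gate_proj", "up_proj")
    elif leaf == "down_proj":
        parts = ("down_proj",)
    else:
        return [fused_name]  # already per-expert or unrelated; pass through
    return [f"{prefix}.{expert}.{part}.weight" for expert in range(num_experts) for part in parts]


def runtime_layout_shapes(local_experts, hidden, intermediate_per_tp):
    """Processed runtime shapes as listed in the fix description."""
    return {
        "w13_weight": (local_experts, hidden, 2 * intermediate_per_tp),
        "w2_weight": (local_experts, intermediate_per_tp, hidden),
    }


expanded = expand_fused_expert_names("model.layers.0.mlp.experts.gate_up_proj", num_experts=2)
shapes = runtime_layout_shapes(local_experts=2, hidden=2048, intermediate_per_tp=192)
```

Slicing the fused tensors themselves (per expert, and into gate and up halves) is left out on purpose, because the fused tensor orientation depends on the training backend.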

Upstream status

3. Weight synchronization compatibility for vLLM server mode

Files involved:

  • swift/pipelines/infer/rollout.py
  • swift/rlhf_trainers/vllm_client.py

Problem

In server mode, the training process synchronizes weights to an independent vLLM rollout server through a client. For the Qwen3.5-MoE + vLLM-Ascend path, weight synchronization needs to ensure that:

  • HCCL broadcast uses the correct device context / stream on NPU;
  • full weight reload triggers the required post-load processing in vLLM / vLLM-Ascend;
  • LoRA adapter sync and full weight sync do not share the same post-processing logic.

Fix

  • Consolidate tensor broadcast logic in both the server client and server worker into _broadcast_tensor().
  • Enter the communicator.device device context and use the current NPU stream only when running on NPU.
  • Keep the original generic broadcast behavior unchanged for non-NPU branches to avoid affecting GPU.
  • Add a /process_weights_after_loading/ call on the server worker side to rebuild the vLLM-Ascend MoE layout after full reload.
  • Skip the full-reload post-load logic for LoRA adapter sync.
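
The control flow described in this list can be sketched with stand-ins. Everything below is hypothetical: `RecordingCommunicator` replaces the real HCCL/NCCL communicator, the `"full"` / `"lora"` labels are invented, and `post_load` stands for whatever the server does when the post-load endpoint is called.

```python
from contextlib import nullcontext


class RecordingCommunicator:
    """Hypothetical stand-in that records broadcasts instead of using HCCL/NCCL."""

    def __init__(self, device):
        self.device = device
        self.calls = []

    def broadcast(self, tensor, src, stream=None):
        self.calls.append((tensor, src, stream))


def broadcast_tensor(communicator, tensor, src, *, on_npu,
                     device_ctx=nullcontext, current_stream=lambda: None):
    """Enter the communicator's device context only on NPU; otherwise broadcast directly."""
    if on_npu:
        with device_ctx(communicator.device):
            communicator.broadcast(tensor, src=src, stream=current_stream())
    else:
        communicator.broadcast(tensor, src=src, stream=current_stream())


def sync_weights(communicator, tensors, *, sync_kind, on_npu, post_load):
    """Broadcast every tensor, then run post-load processing for full reloads only."""
    for tensor in tensors:
        broadcast_tensor(communicator, tensor, src=0, on_npu=on_npu)
    if sync_kind == "full":
        post_load()  # LoRA adapter sync deliberately skips this step
```

In a real run `device_ctx` and `current_stream` would come from the device module; they default to no-ops here so the sketch runs on any machine.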

4. Known limitation: PEFT 0.19 + Transformers 5 fused MoE LoRA

This PR does not attempt to fix the PEFT fused expert LoRA path in SWIFT. Instead, it documents the limitation in the NPU support documentation.

Background

Transformers 5 changes the structure of some MoE experts. Previously, each expert typically exposed independent nn.Linear modules such as gate_proj, up_proj, and down_proj; in the new structure, these expert weights may instead be fused into a single nn.Parameter, for example:

mlp.experts.gate_up_proj
mlp.experts.down_proj

Regular LoRA uses PEFT target_modules to match nn.Modules, while fused MoE experts require PEFT target_parameters to match parameter names directly.
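
A toy illustration of why module matching misses fused experts. The suffix rule below is a simplified imitation of how list-valued targets are matched against names, and the name inventories are invented; how PEFT resolves target_parameters internally is not shown.

```python
def matches_target(name, targets):
    """Simplified suffix match: exact name, or name ending in '.<target>'."""
    return any(name == target or name.endswith(f".{target}") for target in targets)


# Invented inventory of a model whose experts are fused into parameters.
module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.experts",
]
parameter_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.experts.gate_up_proj",
    "model.layers.0.mlp.experts.down_proj",
]

targets = ["q_proj", "gate_up_proj", "down_proj"]
# Matching against modules finds q_proj only: gate_up_proj is not a module.
hit_modules = [n for n in module_names if matches_target(n, targets)]
# Matching against parameter names is what reaches the fused experts.
hit_parameters = [n for n in parameter_names if matches_target(n, targets)]
```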

Current assessment

  • target_parameters is an upstream mechanism provided by PEFT, but there are still unresolved edge cases in combinations involving lora_dropout, ZeRO-3 / FSDP, and multi-adapter setups.
  • This is not an NPU-specific issue, nor is it a vLLM / HCCL issue.
  • To avoid maintaining an unstable upstream capability branch inside SWIFT, this PR reverts the temporary workaround for this issue in swift/tuners/peft.py.
  • This PR adds a known-limitation note to:
    • docs/source/BestPractices/NPU-support.md
    • docs/source_en/BestPractices/NPU-support.md

Recommendation

  • By default, Qwen3.5 GRPO scripts should not explicitly pass model_type, so as to avoid triggering the PEFT Transformers 5 MoE conversion path.
  • If the model configuration itself already triggers that path, prefer full-parameter training or disable the affected LoRA combinations.
  • If users really need to train fused expert parameters, they should wait until the upstream PEFT support becomes stable, or use it only when lora_dropout=0 and the target backend has been separately validated.

Related upstream context

Experiment results

Verified:

  1. Python static checks passed:
python -m py_compile \
  swift/infer_engine/vllm_engine.py \
  swift/model/npu_patch/vllm_ascend.py \
  swift/model/npu_patch/vllm_ascend_moe.py \
  swift/pipelines/infer/rollout.py \
  swift/rlhf_trainers/rollout_mixin.py \
  swift/rlhf_trainers/utils.py \
  swift/rlhf_trainers/vllm_client.py \
  swift/tuners/peft.py
  2. Documentation and diff format checks passed:
git diff --check -- \
  docs/source/BestPractices/NPU-support.md \
  docs/source_en/BestPractices/NPU-support.md
  3. NPU Qwen GRPO smoke coverage:
| Scenario | Model | Backend / rollout | Training mode | Result |
| --- | --- | --- | --- | --- |
| colocate | Qwen3.5 dense | Megatron + vLLM-Ascend | full | 1 step passed, covering the basic dense + colocate path |
| colocate | Qwen3.5 dense | Megatron + vLLM-Ascend | LoRA | 1 step passed, covering dense + LoRA weight synchronization |
| colocate | Qwen3.5-MoE | Megatron + vLLM-Ascend | full | 1 step passed, covering Megatron MoE processed-layout synchronization |
| colocate | Qwen3.5-MoE | Megatron + vLLM-Ascend | LoRA | 1 step passed, covering the Megatron MoE + LoRA path |
| colocate | Qwen3.5-MoE | FSDP2 + vLLM-Ascend | full | 1 step passed, covering FSDP2 fused expert / pre-process layout synchronization |
| colocate | Qwen3-MoE | FSDP2 + vLLM-Ascend, vLLM TP=4 | LoRA | 5/5 steps passed, covering qwen3_moe fused expert runtime sync and processed-layout writeback |
| server | Qwen3.5-MoE | Megatron + standalone vLLM-Ascend server | full | 1 step passed, covering standalone server-mode HCCL broadcast, full-reload post-processing, and MoE inference |

Not covered:

  • GPU regression was not validated locally; however, the NPU / vLLM-Ascend-specific logic is isolated behind runtime branches.
  • PEFT fused expert target_parameters + lora_dropout > 0 is still intentionally unsupported; the current approach is to document the limitation rather than maintain a temporary workaround in SWIFT.

@addsubmuldiv changed the title from "Qwen35 grpo adapt" to "Qwen35 grpo bug fix" Jun 17, 2026
@addsubmuldiv changed the title from "Qwen35 grpo bug fix" to "Fix Qwen3.5 grpo bugs" Jun 17, 2026
@addsubmuldiv changed the title from "Fix Qwen3.5 grpo bugs" to "Fix NPU Qwen3.5 grpo bugs" Jun 17, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces compatibility layers and weight synchronization enhancements for vLLM-Ascend, specifically supporting Qwen3.5 MoE models and handling list-style RoPE validation keys from older configurations. It also patches PEFT to keep regex target modules intact during Transformers-v5 MoE conversion. The review feedback highlights two important improvements: first, safely retrieving the RoPE validation conversion method using getattr to prevent AttributeError on older transformers versions; second, ensuring that PEFT config conversion falls back to the original converter when regex compilation fails or when targeting 'all-linear' modules, while also pre-compiling the regex outside of nested loops to optimize performance.


Comment thread swift/infer_engine/vllm_engine.py Outdated
Comment on lines +104 to +115
    origin_convert = PretrainedConfig.convert_rope_params_to_dict

    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
        if isinstance(ignore_keys_at_rope_validation, list):
            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
        return origin_convert(self, ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, **kwargs)

    PretrainedConfig.convert_rope_params_to_dict = convert_rope_params_to_dict
    try:
        yield
    finally:
        PretrainedConfig.convert_rope_params_to_dict = origin_convert

Contributor

medium

The convert_rope_params_to_dict method was introduced in newer versions of transformers. If an older version of transformers is installed, accessing PretrainedConfig.convert_rope_params_to_dict directly will raise an AttributeError on import or execution. To prevent this, use getattr to safely retrieve the method and yield early if it is not present.

Suggested change
-    origin_convert = PretrainedConfig.convert_rope_params_to_dict
-    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
-        if isinstance(ignore_keys_at_rope_validation, list):
-            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
-        return origin_convert(self, ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, **kwargs)
-    PretrainedConfig.convert_rope_params_to_dict = convert_rope_params_to_dict
-    try:
-        yield
-    finally:
-        PretrainedConfig.convert_rope_params_to_dict = origin_convert
+    origin_convert = getattr(PretrainedConfig, 'convert_rope_params_to_dict', None)
+    if origin_convert is None:
+        yield
+        return
+    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
+        if isinstance(ignore_keys_at_rope_validation, list):
+            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
+        return origin_convert(self, ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, **kwargs)
+    PretrainedConfig.convert_rope_params_to_dict = convert_rope_params_to_dict
+    try:
+        yield
+    finally:
+        PretrainedConfig.convert_rope_params_to_dict = origin_convert
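
The guard can be exercised in isolation. A minimal sketch, assuming nothing about Transformers itself: `OldConfig` and `NewConfig` are hypothetical classes that model "method absent" and "method present", and `patch_if_present` is an invented generic helper, not code from this PR.

```python
from contextlib import contextmanager


@contextmanager
def patch_if_present(cls, method_name, make_wrapper):
    """Swap cls.method_name for make_wrapper(original) inside the block.

    Yields False and changes nothing when the method does not exist.
    """
    original = getattr(cls, method_name, None)
    if original is None:
        yield False
        return
    setattr(cls, method_name, make_wrapper(original))
    try:
        yield True
    finally:
        setattr(cls, method_name, original)


class OldConfig:
    """Hypothetical config class that predates the method."""


class NewConfig:
    """Hypothetical config class that has the method."""

    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None):
        return ignore_keys_at_rope_validation


def accept_list_keys(original):
    """Build a wrapper that converts list-valued keys to a set before delegating."""
    def wrapper(self, ignore_keys_at_rope_validation=None):
        if isinstance(ignore_keys_at_rope_validation, list):
            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
        return original(self, ignore_keys_at_rope_validation=ignore_keys_at_rope_validation)
    return wrapper
```

With this shape, an older install simply runs the wrapped block unpatched instead of failing with AttributeError.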

Comment thread swift/tuners/peft.py Outdated
Comment on lines +363 to +386
        try:
            re.compile(target_modules)
        except re.error:
            return

        if target_modules == 'all-linear':
            return

        target_parameters = _names_to_set(getattr(peft_config, 'target_parameters', None))
        original_target_parameters = target_parameters.copy()
        old_names_by_new_name = {}
        for old_name, new_name in target_module_mapping.items():
            old_names_by_new_name.setdefault(new_name, set()).add(old_name)

        matched_fused_targets = {new_name: set() for new_name in fused_targets}
        for param_name, _ in model.named_parameters():
            for new_name, old_names in old_names_by_new_name.items():
                if not (param_name == new_name or param_name.endswith(f'.{new_name}')):
                    continue
                prefix = param_name[:-len(new_name)]
                matched_old_names = {
                    old_name
                    for old_name in old_names if re.fullmatch(target_modules, f'{prefix}{old_name}')
                }
Contributor

medium

There are two issues in this block:

  1. Correctness/Safety: If re.compile fails or if target_modules == 'all-linear', returning None directly will completely bypass origin_convert, which prevents PEFT from performing its default config conversion. We should delegate these cases back to origin_convert.
  2. Performance: re.fullmatch(target_modules, ...) is called inside nested loops over all model parameters and target modules. This compiles the regex string target_modules repeatedly, which is highly inefficient. Compiling the regex once outside the loops using re.compile and reusing the compiled pattern avoids this overhead.
        try:
            pattern = re.compile(target_modules)
        except re.error:
            return origin_convert(peft_config, model, conversions)

        if target_modules == 'all-linear':
            return origin_convert(peft_config, model, conversions)

        target_parameters = _names_to_set(getattr(peft_config, 'target_parameters', None))
        original_target_parameters = target_parameters.copy()
        old_names_by_new_name = {}
        for old_name, new_name in target_module_mapping.items():
            old_names_by_new_name.setdefault(new_name, set()).add(old_name)

        matched_fused_targets = {new_name: set() for new_name in fused_targets}
        for param_name, _ in model.named_parameters():
            for new_name, old_names in old_names_by_new_name.items():
                if not (param_name == new_name or param_name.endswith(f'.{new_name}')):
                    continue
                prefix = param_name[:-len(new_name)]
                matched_old_names = {
                    old_name
                    for old_name in old_names if pattern.fullmatch(f'{prefix}{old_name}')
                }
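
The matching loop in this suggestion can be run standalone. The sketch below is a simplified, hypothetical extraction (`match_old_names` is not a function in SWIFT or PEFT) that takes plain name lists instead of a model and a PEFT config; the parameter names and the regex in the usage lines are invented.

```python
import re


def match_old_names(param_names, old_names_by_new_name, target_modules):
    """For each fused parameter, collect the pre-fusion module names selected by the regex."""
    pattern = re.compile(target_modules)  # compiled once, reused in the nested loops
    matched = {new_name: set() for new_name in old_names_by_new_name}
    for param_name in param_names:
        for new_name, old_names in old_names_by_new_name.items():
            if not (param_name == new_name or param_name.endswith(f".{new_name}")):
                continue
            # Rebuild the pre-fusion module path and test it against the user's regex.
            prefix = param_name[:-len(new_name)]
            matched[new_name] |= {old for old in old_names if pattern.fullmatch(f"{prefix}{old}")}
    return matched


result = match_old_names(
    ["model.layers.0.mlp.experts.gate_up_proj", "model.layers.0.mlp.experts.down_proj"],
    {"gate_up_proj": {"gate_proj", "up_proj"}, "down_proj": {"down_proj"}},
    r".*\.(gate_proj|down_proj)",
)
```

Python's `re` module keeps an internal cache of compiled patterns, so the saving from pre-compiling is mostly the per-call cache lookup rather than a full recompilation; holding the compiled pattern is still the clearer form.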

@addsubmuldiv addsubmuldiv marked this pull request as ready for review June 18, 2026 06:24
Copilot AI review requested due to automatic review settings June 18, 2026 06:24

Copilot AI left a comment


Pull request overview

This PR improves NPU (vLLM-Ascend) compatibility for running GRPO on Qwen3.5 / Qwen3.5-MoE across both colocate and server rollout modes, and documents a known PEFT + Transformers 5 MoE fused-expert LoRA limitation.

Changes:

  • Add NPU-aware broadcast and post-load processing hooks so server-mode weight sync correctly rebuilds vLLM-Ascend MoE kernel layouts after reload.
  • Extend vLLM-Ascend MoE weight-loader patching to support selecting “processed” vs “preprocessed” expert layouts for different sync paths (Megatron vs FSDP2 reload-style).
  • Add a temporary Transformers 5 RoPE compatibility monkey-patch for vLLM 0.18.x Qwen3.5 configs and document PEFT fused-expert LoRA limitations.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
| File | Description |
| --- | --- |
| swift/rlhf_trainers/vllm_client.py | Centralizes broadcast behavior to be NPU-stream/device-context aware during server sync. |
| swift/rlhf_trainers/utils.py | Adds Qwen3.5 MoE registry entry and threads MoE layout selection into Ascend expert loader patching. |
| swift/rlhf_trainers/rollout_mixin.py | Configures vLLM-Ascend MoE sync layout for colocate mode and refines fused-expert name/weight handling. |
| swift/pipelines/infer/rollout.py | Adds NPU-aware broadcast helper and ensures full-reload paths trigger Ascend MoE post-load processing. |
| swift/model/npu_patch/vllm_ascend.py | Re-exports the Ascend MoE preprocessed-layout selector to support sync logic. |
| swift/model/npu_patch/vllm_ascend_moe.py | Implements layout-mode tracking and enhanced Ascend MoE expert weight-loader behavior for sync/reload. |
| swift/infer_engine/vllm_engine.py | Adds a scoped patch to accept list-based RoPE validation ignore keys for vLLM 0.18.x + Transformers 5. |
| docs/source/BestPractices/NPU-support.md | Documents PEFT Transformers 5 MoE fused-expert LoRA limitations (ZH). |
| docs/source_en/BestPractices/NPU-support.md | Documents PEFT Transformers 5 MoE fused-expert LoRA limitations (EN). |


Comment on lines 197 to 203
  weight = torch.empty(shape, dtype=dtype, device=self.communicator.device)

  # Use NCCL to broadcast the updated weights from the client (src) to all workers.
- self.communicator.broadcast(
-     weight, src=self.client_rank, stream=getattr(get_torch_device(), 'current_stream', lambda: None)())
+ _broadcast_tensor(self.communicator, weight, src=self.client_rank)
  synchronize()
  self.communicator.group.barrier()

@Jintao-Huang

Collaborator

cc @hjh0119

Comment on lines +139 to +141
def should_keep_fused_moe_expert_for_vllm_ascend(model) -> bool:
    """Return whether fused expert names should be kept for vLLM-Ascend sync."""
    return False

Collaborator

What is the purpose of this method?

Collaborator Author

Thanks for pointing this out. This helper was a leftover from an earlier branch where we tried both "keep fused expert names" and "expand fused expert names" paths. After the final fix, vLLM-Ascend always needs the expanded checkpoint-style names for this sync path, so the helper had degenerated into a constant False, which was unnecessary and confusing.

I removed the helper and now call the fused MoE expansion path directly. This keeps the intent explicit and avoids carrying a dead switch in the code.

Comment thread swift/rlhf_trainers/vllm_client.py Outdated
Comment on lines +41 to +47
def _broadcast_tensor(communicator, tensor: torch.Tensor, src: int) -> None:
    if is_torch_npu_available():
        device_module = get_torch_device()
        with device_module.device(communicator.device):
            communicator.broadcast(tensor, src=src, stream=device_module.current_stream())
    else:
        communicator.broadcast(tensor, src=src, stream=getattr(get_torch_device(), 'current_stream', lambda: None)())

Collaborator

Consider reusing the methods from swift/pipelines/infer/rollout.py or extracting them into utils

Collaborator Author

Done. I extracted the shared tensor broadcast logic into broadcast_tensor_for_vllm_weight_sync() in swift/rlhf_trainers/utils.py, and both swift/rlhf_trainers/vllm_client.py and swift/pipelines/infer/rollout.py now reuse it.

Comment thread swift/infer_engine/vllm_engine.py Outdated
    """
    from transformers import PretrainedConfig

    origin_convert = PretrainedConfig.convert_rope_params_to_dict

Collaborator

Agree with Gemini. Is this attribute safe to access on older versions of Transformers?

Collaborator Author

Fixed. The patch now uses getattr(PretrainedConfig, "convert_rope_params_to_dict", None) and skips the monkey patch if the method does not exist.

@hjh0119

hjh0119 commented Jun 21, 2026

Collaborator

Sorry for the late review. I've left a few minor comments

@hjh0119

hjh0119 commented Jun 21, 2026

Collaborator

Thanks! Is this pr ready to merge?

@addsubmuldiv

Collaborator Author

Thanks! Is this pr ready to merge?

Yes, I think it’s ready. Please feel free to merge it if there are no further concerns.

4 participants