
Conversation

@yanfeich
Contributor

@yanfeich yanfeich commented Jan 4, 2026

Motivation

Enable MoE Expert Parallelism (EP) for HPU with loader_v1.

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

  • fastdeploy/model_executor/layers/moe/moe.py
    HPU calls forward_normal regardless of EP or TP, and never falls into forward_split_allgather or forward_chunked_moe.

  • fused_moe_hpu_backend.py
    Change down_proj_in_scale from a list to a tensor.

  • hpu_model_runner.py
    Convert the scale list to a tensor and pad it to meet the HPU 0x80 alignment requirement (see the sketch after this list).

  • fastdeploy/model_executor/load_weight_utils.py
    up_gate_proj.activation_scale is needed for EP in loader v0.

  • fastdeploy/model_executor/models/ernie4_5_moe.py
    Add attention-related activation_scale name conversions:

      - self_attn.

        | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
        | --- | --- | --- |
        | qkv_proj.activation_scale | qkv_proj.in_scale | qkv_proj.act_scale |
        | o_proj.activation_scale | o_proj.in_scale | o_proj.act_scale |
        | cachek_matmul.activation_scale | cachek_matmul.in_scale | attn.cache_k_scale |
        | cachev_matmul.activation_scale | cachev_matmul.in_scale | attn.cache_v_scale |
        | q_matmul.activation_scale | q_matmul.in_scale | attn.q_scale |
        | s_matmul.activation_scale | s_matmul.in_scale | attn.s_scale |

      - mlp. & mlp.shared_experts.

        | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
        | --- | --- | --- |
        | down_proj.activation_scale | down_proj.in_scale | down_proj.act_scale |
        | up_gate_proj.activation_scale | up_gate_proj.in_scale | up_gate_proj.act_scale |

      - mlp.experts. (all experts share the same activation_scale)

        | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
        | --- | --- | --- |
        | up_gate_proj.activation_scale | up_gate_proj.in_scale | up_gate_proj_in_scale |

      - mlp.experts.{exp_id}.

        | loaded_weight_name | checkpoint_to_fd_key_fn | all_param_mapping |
        | --- | --- | --- |
        | experts.{exp_id}.down_proj.activation_scale | experts.{exp_id}.down_proj.in_scale | experts.down_proj_in_scale |
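
To illustrate the list-to-tensor and 0x80 alignment change described above, here is a minimal sketch. It is not the PR's hpu_model_runner.py code; the helper name and the assumption that alignment is counted in elements of the packed scale tensor are illustrative.

```python
import paddle

# Hedged sketch (not the actual hpu_model_runner.py code): pack per-expert
# activation scales from a Python list into one tensor and zero-pad the
# trailing dimension up to a multiple of 0x80, matching the alignment
# requirement mentioned in the Modifications above.
ALIGN = 0x80


def pack_and_pad_scales(scales_list):
    scales = paddle.to_tensor(scales_list, dtype="float32")  # list -> tensor
    pad = (-scales.shape[-1]) % ALIGN  # elements missing to the next 0x80 boundary
    if pad:
        scales = paddle.concat([scales, paddle.zeros([pad], dtype=scales.dtype)])
    return scales
```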

Usage or Command

Set enable_expert_parallel=True and disable_sequence_parallel_moe=True to enable HPU MoE EP, as in the sketch below.
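
A minimal offline sketch, loosely modeled on examples/intel_hpu/offline_demo.py and assuming the offline LLM API forwards these flags as keyword arguments; the model path, parallel size, and sampling settings are placeholders.

```python
from fastdeploy import LLM, SamplingParams

# Hedged sketch: only enable_expert_parallel and disable_sequence_parallel_moe
# come from this PR's usage note; everything else is a placeholder.
llm = LLM(
    model="/path/to/ERNIE-4.5-MoE",      # placeholder checkpoint path
    tensor_parallel_size=8,               # placeholder parallel size
    enable_expert_parallel=True,          # enable HPU MoE EP
    disable_sequence_parallel_moe=True,   # required together with EP on HPU
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32)))
```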

Accuracy Tests

Checklist

  • [ Done ] Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • [ Done ] Format your code, run pre-commit before commit.
  • [ Done ] Add unit tests. Please write the reason in this PR if no unit tests.
    Verified by local tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings January 4, 2026 06:34
@paddle-bot

paddle-bot bot commented Jan 4, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Jan 4, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR enables MoE (Mixture of Experts) Expert Parallelism (EP) for Intel HPU by modifying the execution path and weight handling to accommodate HPU-specific requirements.

Key changes:

  • Modified MoE forward logic to route HPU through forward_normal regardless of EP/TP configuration
  • Converted down_proj_in_scale from list to tensor and added padding alignment for HPU's 0x80 byte alignment requirement
  • Added up_gate_proj.activation_scale weight loading support for EP mode

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| fastdeploy/model_executor/layers/moe/moe.py | Routes the HPU platform through the forward_normal path for both EP and TP modes |
| fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py | Changes down_proj_in_scale handling from list to tensor and renames apply_tp to apply |
| fastdeploy/worker/hpu_model_runner.py | Adds an alignment padding function for scales and implements an early return for EP mode |
| fastdeploy/model_executor/load_weight_utils.py | Adds up_gate_proj_in_scale_key to weight loading for EP support |
| examples/intel_hpu/offline_demo.py | Enables the EP configuration in the demo script |

@codecov-commenter

codecov-commenter commented Jan 4, 2026

Codecov Report

❌ Patch coverage is 18.18182% with 9 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@decbbb3). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| fastdeploy/model_executor/layers/moe/moe.py | 16.66% | 4 Missing and 1 partial ⚠️ |
| fastdeploy/model_executor/load_weight_utils.py | 0.00% | 2 Missing ⚠️ |
| fastdeploy/model_executor/utils.py | 33.33% | 2 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             develop    #5855   +/-   ##
==========================================
  Coverage           ?   67.04%
==========================================
  Files              ?      348
  Lines              ?    44643
  Branches           ?     6862
==========================================
  Hits               ?    29932
  Misses             ?    12507
  Partials           ?     2204
```

| Flag | Coverage Δ |
| --- | --- |
| GPU | 67.04% <18.18%> (?) |


@yanfeich
Contributor Author

yanfeich commented Jan 8, 2026

Adding @LeoZhao-Intel @fmiao2372.
Please help review this patch, thanks!

```python
    tensor_model_parallel_all_reduce_custom(out)
else:
    out = tensor_model_parallel_all_reduce(out, self.tp_group)
```


Why not use tensor_model_parallel_all_reduce_custom here?


Labels

contributor External developers
