
feat: add Muon optimizer support#78122

Open
xxyux wants to merge 1 commit into PaddlePaddle:develop from xxyux:feature/add-muon-optimizer

Conversation


@xxyux xxyux commented Mar 3, 2026

PR Category

Execute Infrastructure

PR Types

New features

Description

Summary

This PR introduces the Muon optimizer and wires it into PaddleFleet's DygraphShardingOptimizerV2 pipeline.

Muon applies orthogonal gradient updates (via Newton-Schulz iteration) to 2-D weight matrices, and falls back to standard AdamW for embeddings, biases, and other non-matrix parameters. It is designed for large-scale distributed training with Sharding Stage1 V2 parameter sharding and Tensor Parallelism.
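The orthogonalisation step described above can be sketched as follows. This is a NumPy stand-in for illustration (the PR operates on Paddle tensors on-device); the quintic coefficients match those used in public Muon reference implementations, and the function name is hypothetical:

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2-D matrix with a quintic
    Newton-Schulz iteration. NumPy sketch for illustration only;
    the real optimizer runs this on Paddle tensors."""
    a, b, c = 3.4445, -4.7750, 2.0315   # quintic iteration coefficients
    x = g / (np.linalg.norm(g) + eps)   # scale so singular values <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                      # iterate on the wide orientation
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x
```

Each iteration applies the polynomial `f(s) = a·s + b·s³ + c·s⁵` to the singular values, driving them toward 1 so the update has (approximately) orthonormal rows.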

Key changes

paddle/optimizer/muon.py (new)

  • Implements the Muon optimizer class, inheriting from Optimizer.
  • _muon_update: gathers the full 2-D weight matrix across sharding / TP ranks, applies Newton-Schulz orthogonalisation, then scatters the local shard back.
  • _adamw_update: in-place AdamW fallback for non-matrix parameters.
  • sharded_state_dict: flex-checkpoint compatible save/load, mirroring the AdamW layout.
  • Supports multi_precision (BF16/FP16 training with FP32 master weights) and per-head QKV orthogonalisation.
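The per-head QKV handling might look like the following sketch. The fused-weight layout (`[3 * num_heads * head_dim, hidden]`) and both function names are assumptions for illustration, and an exact SVD polar factor stands in for the Newton-Schulz iteration:

```python
import numpy as np

def orthogonalize(m):
    # Exact polar-factor orthogonalization via SVD, standing in for
    # the Newton-Schulz iteration the optimizer uses in practice.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

def per_head_qkv_update(fused_grad, num_heads):
    """Hypothetical per-head treatment of a fused QKV gradient of
    shape [3 * num_heads * head_dim, hidden]: each head's Q, K and V
    block is orthogonalized independently rather than as one large
    matrix. Paddle's actual fused-QKV layout may differ."""
    blocks = np.split(fused_grad, 3 * num_heads, axis=0)
    return np.concatenate([orthogonalize(b) for b in blocks], axis=0)
```

Orthogonalising per head keeps each head's update well-conditioned instead of letting one dominant head skew the spectrum of the whole fused matrix.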

paddle/distributed/fleet/utils/muon_comm_utils.py (new)

  • gather_varlen: variable-length all-to-one gather that pre-allocates a single contiguous buffer on the destination rank to avoid memory fragmentation.
  • get_sharding_info: computes per-rank element counts and local offset for a sharded parameter stored in a FusedCommBuffer.
  • should_use_muon: predicate that selects 2-D weight matrices while excluding embeddings, biases, and LM-head weights.
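The two utilities could be approximated as below. This is a single-process sketch: the name-matching keywords in `should_use_muon` are assumed, and `gather_varlen` here copies local arrays where the real utility fills the buffer via collective communication:

```python
import numpy as np

def should_use_muon(name, shape):
    """Hypothetical version of the selection predicate: only 2-D
    weight matrices qualify, and embedding, bias and LM-head
    parameters always fall back to AdamW."""
    if len(shape) != 2:
        return False
    return not any(k in name.lower() for k in ("embedding", "bias", "lm_head"))

def gather_varlen(shards):
    """Single-process stand-in for the variable-length gather: the
    destination pre-allocates one contiguous buffer sized to the sum
    of per-rank counts, then copies each shard into its slot, which
    avoids per-shard allocations and memory fragmentation."""
    counts = [s.size for s in shards]
    buf = np.empty(sum(counts), dtype=shards[0].dtype)
    offset = 0
    for shard, n in zip(shards, counts):
        buf[offset:offset + n] = shard.ravel()
        offset += n
    return buf
```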

DygraphShardingOptimizerV2.step()

  • Detects a Muon inner optimizer by walking the _inner_opt wrapper chain (name comparison avoids a circular import).
  • Annotates each 1-D slice-param with the metadata Muon needs: original_shape, is_sharded_gather, sharding_indices, split_axis, QKV flags.
  • Uses get_sharding_info to compute per-rank element counts for parameters that span multiple sharding ranks.
  • Sorts params_grads so that large, fully-owned parameters are processed before smaller sharded ones, improving allocator locality during the Newton-Schulz pass.
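The sharding-layout computation and the ordering heuristic can be sketched like this. The even-split layout is an assumption (the real `FusedCommBuffer` may pad or split differently), and `order_params` is a hypothetical helper illustrating the sort described above:

```python
def get_sharding_info(total_numel, num_ranks, rank):
    """Hypothetical even-split layout: per-rank element counts and
    the calling rank's offset into the flattened parameter."""
    base, rem = divmod(total_numel, num_ranks)
    counts = [base + (1 if r < rem else 0) for r in range(num_ranks)]
    return counts, sum(counts[:rank])

def order_params(params):
    """Sort (name, numel, fully_owned) records so large, fully-owned
    parameters come before smaller sharded ones, mirroring the
    allocator-locality ordering of params_grads."""
    return sorted(params, key=lambda p: (not p[2], -p[1]))
```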

Tests

hybrid_parallel_sharding_mp_model.py + test_hybrid_parallel_sharding_mp_logic: an end-to-end accuracy test that runs a 4-GPU sharding=2 × mp=2 job, compares Muon-updated MP model parameters against a single-process reference, and exercises the full annotation → gather → orthogonalise → scatter pipeline.

Whether this introduces precision changes

@paddle-bot

paddle-bot bot commented Mar 3, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@codecov-commenter

Codecov Report

❌ Patch coverage is 20.00000% with 28 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@a2ec403). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...rs/dygraph_optimizer/dygraph_sharding_optimizer.py 20.00% 28 Missing ⚠️

❌ Your patch status has failed because the patch coverage (20.00%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78122   +/-   ##
==========================================
  Coverage           ?   20.00%           
==========================================
  Files              ?        1           
  Lines              ?       35           
  Branches           ?        0           
==========================================
  Hits               ?        7           
  Misses             ?       28           
  Partials           ?        0           


@xxyux xxyux force-pushed the feature/add-muon-optimizer branch from 65f5619 to 1e75080 Compare March 4, 2026 03:39

xxyux commented Mar 4, 2026

/re-run all-failed

1 similar comment

@xxyux xxyux force-pushed the feature/add-muon-optimizer branch 2 times, most recently from 50097f4 to e9dfa92 Compare March 4, 2026 11:41
…ding MP parallel support

- Add Muon optimizer (python/paddle/optimizer/muon.py) with Newton-Schulz
  orthogonalisation and per-head QKV split support
- Add _create_accumulators for checkpoint-resume compatibility under AMP O2
- Add muon_comm_utils for distributed Muon gradient communication
- Update DygraphShardingOptimizer to support Muon optimizer integration
- Add sharding MP parallel test with both standard and qkv_split models
- Update CMakeLists.txt with test_parallel_dygraph_sharding_mp_parallel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@xxyux xxyux force-pushed the feature/add-muon-optimizer branch from e9dfa92 to 7eb8579 Compare March 4, 2026 11:51
