Skip to content

Commit 65f5619

Browse files
committed
feat: add Muon optimizer support
1 parent d808ea6 commit 65f5619

File tree

1 file changed

+30
-82
lines changed

1 file changed

+30
-82
lines changed

test/collective/fleet/hybrid_parallel_sharding_mp_model.py

Lines changed: 30 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import paddle
2222
import paddle.distributed as dist
2323
from paddle.distributed import fleet
24-
from paddle.distributed.fleet.utils.log_util import logger
2524
from paddle.distributed.fleet.utils.mix_precision_utils import (
2625
MixPrecisionLayer,
2726
MixPrecisionOptimizer,
@@ -45,6 +44,8 @@
4544

4645

4746
class SimpleMPNet(paddle.nn.Layer):
47+
"""Tensor-parallel model: linear1 is column-split, linear2 is row-split."""
48+
4849
def __init__(
4950
self,
5051
vocab_size,
@@ -70,31 +71,32 @@ def __init__(
7071
),
7172
)
7273

74+
# Each TP rank owns a contiguous column slice of fc1 and row slice of fc2.
7375
inner_per_rank = inner_size // mp_degree
7476
fc1_start = mp_id * inner_per_rank
7577
fc1_end = fc1_start + inner_per_rank
7678

77-
init_fc1_slice = np_fc1[:, fc1_start:fc1_end]
78-
7979
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
8080
hidden_size,
8181
inner_size,
8282
weight_attr=paddle.framework.ParamAttr(
83-
initializer=paddle.nn.initializer.Assign(init_fc1_slice)
83+
initializer=paddle.nn.initializer.Assign(
84+
np_fc1[:, fc1_start:fc1_end]
85+
)
8486
),
85-
gather_output=False, # 关键:输出保持切分状态,不聚合,直接喂给下一层 RowParallel
87+
gather_output=False, # keep output sharded for RowParallel input
8688
has_bias=True,
8789
)
8890

89-
init_fc2_slice = np_fc2[fc1_start:fc1_end, :]
90-
9191
self.linear2 = fleet.meta_parallel.RowParallelLinear(
9292
inner_size,
9393
hidden_size,
9494
weight_attr=paddle.framework.ParamAttr(
95-
initializer=paddle.nn.initializer.Assign(init_fc2_slice)
95+
initializer=paddle.nn.initializer.Assign(
96+
np_fc2[fc1_start:fc1_end, :]
97+
)
9698
),
97-
input_is_parallel=True, # 关键:告诉这一层,输入已经是切分过的了
99+
input_is_parallel=True, # input already sharded from ColumnParallel
98100
has_bias=True,
99101
)
100102

@@ -119,6 +121,8 @@ def forward(self, x):
119121

120122

121123
class SimpleDPNet(paddle.nn.Layer):
124+
"""Single-process reference model with identical weight initialisation."""
125+
122126
def __init__(
123127
self,
124128
vocab_size,
@@ -209,8 +213,8 @@ def setUp(self):
209213
def train_batch(self, batch, model, optimizer):
210214
output = model(batch)
211215
loss = output.mean()
212-
loss.backward() # do backward
213-
optimizer.step() # update parameters
216+
loss.backward()
217+
optimizer.step()
214218
optimizer.clear_grad()
215219
return loss
216220

@@ -239,7 +243,6 @@ def build_optimizer(self, model, strategy=None, Optimizer="adam"):
239243
return optimizer
240244

241245
def build_model_optimizer(self, Optimizer="adam", amp_level=None):
242-
243246
np_fc1 = np.random.random_sample((hidden_size, inner_size)).astype(
244247
"float32"
245248
)
@@ -253,6 +256,7 @@ def build_model_optimizer(self, Optimizer="adam", amp_level=None):
253256
"float32"
254257
)
255258

259+
# model_a: sharded+TP distributed model; model_b: single-process reference
256260
model_a = SimpleMPNet(
257261
vocab_size,
258262
hidden_size,
@@ -302,79 +306,45 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
302306
)
303307

304308
hcg = fleet.get_hybrid_communicate_group()
305-
# degree
306309
mp_degree = self.strategy.hybrid_configs['mp_degree']
307310
sharding_degree = self.strategy.hybrid_configs['sharding_degree']
308-
# rank
309311
rank = dist.get_rank()
310312
sharding_rank = hcg.get_sharding_parallel_rank()
311313
tp_rank = hcg.get_model_parallel_rank()
312-
# data size
313314
local_batch_size = batch_size // sharding_degree
315+
tp_group = hcg.get_model_parallel_group()
314316

315317
for idx in range(STEPS):
316-
print(f"Step = {idx}")
317-
318+
# Each sharding rank processes its own mini-batch slice.
318319
start_index = sharding_rank * local_batch_size
319320
end_index = (sharding_rank + 1) * local_batch_size
320-
321321
batch_sharding = paddle.to_tensor(
322322
self.data[idx][start_index:end_index]
323323
)
324-
logger.info(
325-
f"rank = {rank}, sharding_rank = {sharding_rank}, tp_rank = {tp_rank} start_index = {start_index}, end_index = {end_index}"
326-
)
327-
328324
batch_single = paddle.to_tensor(self.data[idx])
325+
329326
loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
330327
loss_b = self.train_batch(batch_single, model_b, optimizer_b)
331328

332-
# Reduce loss
329+
# Average loss across all ranks for a fair global comparison.
333330
loss_a_metric = loss_a.detach().clone()
334331
dist.all_reduce(loss_a_metric, op=dist.ReduceOp.SUM)
335-
world_size = dist.get_world_size()
336-
loss_a_global = loss_a_metric / world_size
337-
338-
abs_err = np.abs(loss_a_global.numpy() - loss_b.numpy())
339-
rel_err = abs_err / (np.abs(loss_b.numpy()) + 1e-9)
332+
loss_a_global = loss_a_metric / dist.get_world_size()
340333

341-
print(
342-
f"step {idx}, loss_a(local)={loss_a.numpy()}, loss_a(global)={loss_a_global.numpy()}, loss_b={loss_b.numpy()}"
343-
)
344-
print(f"abs_error = {abs_err}, rel_error = {rel_err}")
345-
346-
print(f"\n--- Checking Parameters at Step {idx} ---")
347-
348-
# 获取参数列表 (假设顺序是一致的)
349-
params_a = model_a.parameters()
350-
params_b = model_b.parameters()
351-
352-
tp_group = hcg.get_model_parallel_group()
353-
354-
for i, (param_a, param_b) in enumerate(zip(params_a, params_b)):
355-
name = param_a.name
356-
357-
# 1. 获取本地参数值
334+
# Compare each parameter between the distributed and reference model.
335+
for param_a, param_b in zip(
336+
model_a.parameters(), model_b.parameters()
337+
):
358338
val_a_local = param_a.numpy()
359339
val_b = param_b.numpy()
360340

361-
# 2. 判断是否是 TP 参数 (通过形状是否一致判断)
362-
is_tp_param = val_a_local.shape != val_b.shape
363-
364-
if is_tp_param:
365-
# 分布式参数:需要 Gather
341+
# TP parameters are split across TP ranks; gather before comparing.
342+
if val_a_local.shape != val_b.shape:
366343
gathered_list = []
367344
paddle.distributed.all_gather(
368345
gathered_list, param_a, group=tp_group
369346
)
370-
371-
# 3. 拼接
372-
# 关键:ColumnParallelLinear 是按列切分 (axis=1)
373-
# RowParallelLinear 是按行切分 (axis=0)
374-
# VocabParallelEmbedding 是按行切分 (axis=0)
375-
376-
# 简单的启发式判断拼接维度:
377-
# 看看 val_b 的哪个维度是 val_a 的 mp_degree 倍
347+
# Determine the split axis: find the dim along which the local slice is mp_degree times smaller than the reference tensor.
378348
concat_axis = -1
379349
for dim in range(len(val_b.shape)):
380350
if (
@@ -383,43 +353,21 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
383353
):
384354
concat_axis = dim
385355
break
386-
387356
if concat_axis == -1:
388-
print(
389-
f"[Warning] Param {name} shape mismatch but axis not found. Skip."
390-
)
391357
continue
392-
393-
# 拼接
394358
val_a_global = np.concatenate(
395359
[t.numpy() for t in gathered_list], axis=concat_axis
396360
)
397361
else:
398-
# 非分布式参数 (或者 Sharding Only 参数),本地即全量
399362
val_a_global = val_a_local
400363

401-
# 4. 计算误差
402-
# 注意:由于 O2 精度问题,对比时可能需要 cast 到 float32
403-
diff = np.abs(val_a_global - val_b)
404-
max_abs_err = diff.max()
405-
max_rel_err = max_abs_err / (np.abs(val_b).max() + 1e-9)
406-
407-
# 5. 打印与断言
408-
# 只在 Rank 0 打印,避免刷屏
409-
if dist.get_rank() == 0:
410-
print(
411-
f"Param: {name} | Type: {'TP' if is_tp_param else 'Global'} | Shape: {val_a_global.shape}"
412-
)
413-
print(f" Max Abs Error: {max_abs_err:.2e}")
414-
print(f" Max Rel Error: {max_rel_err:.2e}")
415-
416-
# 设置相对宽松的阈值 (因为 O2 累积误差)
364+
# Loose tolerance to account for O2 AMP accumulated error.
417365
np.testing.assert_allclose(
418366
val_a_global,
419367
val_b,
420368
rtol=1e-4,
421369
atol=1e-4,
422-
err_msg=f"Param {name} mismatch!",
370+
err_msg=f"Param {param_a.name} mismatch!",
423371
)
424372

425373
def test_sharding_muon(self):

0 commit comments

Comments
 (0)