2121import paddle
2222import paddle .distributed as dist
2323from paddle .distributed import fleet
24- from paddle .distributed .fleet .utils .log_util import logger
2524from paddle .distributed .fleet .utils .mix_precision_utils import (
2625 MixPrecisionLayer ,
2726 MixPrecisionOptimizer ,
4544
4645
4746class SimpleMPNet (paddle .nn .Layer ):
47+ """Tensor-parallel model: linear1 is column-split, linear2 is row-split."""
48+
4849 def __init__ (
4950 self ,
5051 vocab_size ,
@@ -70,31 +71,32 @@ def __init__(
7071 ),
7172 )
7273
74+ # Each TP rank owns a contiguous column slice of fc1 and row slice of fc2.
7375 inner_per_rank = inner_size // mp_degree
7476 fc1_start = mp_id * inner_per_rank
7577 fc1_end = fc1_start + inner_per_rank
7678
77- init_fc1_slice = np_fc1 [:, fc1_start :fc1_end ]
78-
7979 self .linear1 = fleet .meta_parallel .ColumnParallelLinear (
8080 hidden_size ,
8181 inner_size ,
8282 weight_attr = paddle .framework .ParamAttr (
83- initializer = paddle .nn .initializer .Assign (init_fc1_slice )
83+ initializer = paddle .nn .initializer .Assign (
84+ np_fc1 [:, fc1_start :fc1_end ]
85+ )
8486 ),
85- gather_output = False , # 关键:输出保持切分状态,不聚合,直接喂给下一层 RowParallel
87+ gather_output = False , # keep output sharded for RowParallel input
8688 has_bias = True ,
8789 )
8890
89- init_fc2_slice = np_fc2 [fc1_start :fc1_end , :]
90-
9191 self .linear2 = fleet .meta_parallel .RowParallelLinear (
9292 inner_size ,
9393 hidden_size ,
9494 weight_attr = paddle .framework .ParamAttr (
95- initializer = paddle .nn .initializer .Assign (init_fc2_slice )
95+ initializer = paddle .nn .initializer .Assign (
96+ np_fc2 [fc1_start :fc1_end , :]
97+ )
9698 ),
97- input_is_parallel = True , # 关键:告诉这一层,输入已经是切分过的了
99+ input_is_parallel = True , # input already sharded from ColumnParallel
98100 has_bias = True ,
99101 )
100102
@@ -119,6 +121,8 @@ def forward(self, x):
119121
120122
121123class SimpleDPNet (paddle .nn .Layer ):
124+ """Single-process reference model with identical weight initialisation."""
125+
122126 def __init__ (
123127 self ,
124128 vocab_size ,
@@ -209,8 +213,8 @@ def setUp(self):
def train_batch(self, batch, model, optimizer):
    """Run one optimization step on *batch* and return the scalar loss.

    The model's output is reduced to its mean, gradients are computed via
    backprop, the optimizer applies the update, and accumulated gradients
    are cleared so the next step starts fresh.
    """
    predictions = model(batch)
    batch_loss = predictions.mean()
    batch_loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    return batch_loss
216220
@@ -239,7 +243,6 @@ def build_optimizer(self, model, strategy=None, Optimizer="adam"):
239243 return optimizer
240244
241245 def build_model_optimizer (self , Optimizer = "adam" , amp_level = None ):
242-
243246 np_fc1 = np .random .random_sample ((hidden_size , inner_size )).astype (
244247 "float32"
245248 )
@@ -253,6 +256,7 @@ def build_model_optimizer(self, Optimizer="adam", amp_level=None):
253256 "float32"
254257 )
255258
259+ # model_a: sharded+TP distributed model; model_b: single-process reference
256260 model_a = SimpleMPNet (
257261 vocab_size ,
258262 hidden_size ,
@@ -302,79 +306,45 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
302306 )
303307
304308 hcg = fleet .get_hybrid_communicate_group ()
305- # degree
306309 mp_degree = self .strategy .hybrid_configs ['mp_degree' ]
307310 sharding_degree = self .strategy .hybrid_configs ['sharding_degree' ]
308- # rank
309311 rank = dist .get_rank ()
310312 sharding_rank = hcg .get_sharding_parallel_rank ()
311313 tp_rank = hcg .get_model_parallel_rank ()
312- # data size
313314 local_batch_size = batch_size // sharding_degree
315+ tp_group = hcg .get_model_parallel_group ()
314316
315317 for idx in range (STEPS ):
316- print (f"Step = { idx } " )
317-
318+ # Each sharding rank processes its own mini-batch slice.
318319 start_index = sharding_rank * local_batch_size
319320 end_index = (sharding_rank + 1 ) * local_batch_size
320-
321321 batch_sharding = paddle .to_tensor (
322322 self .data [idx ][start_index :end_index ]
323323 )
324- logger .info (
325- f"rank = { rank } , sharding_rank = { sharding_rank } , tp_rank = { tp_rank } start_index = { start_index } , end_index = { end_index } "
326- )
327-
328324 batch_single = paddle .to_tensor (self .data [idx ])
325+
329326 loss_a = self .train_batch (batch_sharding , model_a , optimizer_a )
330327 loss_b = self .train_batch (batch_single , model_b , optimizer_b )
331328
332- # Reduce loss
329+ # Average loss across all ranks for a fair global comparison.
333330 loss_a_metric = loss_a .detach ().clone ()
334331 dist .all_reduce (loss_a_metric , op = dist .ReduceOp .SUM )
335- world_size = dist .get_world_size ()
336- loss_a_global = loss_a_metric / world_size
337-
338- abs_err = np .abs (loss_a_global .numpy () - loss_b .numpy ())
339- rel_err = abs_err / (np .abs (loss_b .numpy ()) + 1e-9 )
332+ loss_a_global = loss_a_metric / dist .get_world_size ()
340333
341- print (
342- f"step { idx } , loss_a(local)={ loss_a .numpy ()} , loss_a(global)={ loss_a_global .numpy ()} , loss_b={ loss_b .numpy ()} "
343- )
344- print (f"abs_error = { abs_err } , rel_error = { rel_err } " )
345-
346- print (f"\n --- Checking Parameters at Step { idx } ---" )
347-
348- # 获取参数列表 (假设顺序是一致的)
349- params_a = model_a .parameters ()
350- params_b = model_b .parameters ()
351-
352- tp_group = hcg .get_model_parallel_group ()
353-
354- for i , (param_a , param_b ) in enumerate (zip (params_a , params_b )):
355- name = param_a .name
356-
357- # 1. 获取本地参数值
334+ # Compare each parameter between the distributed and reference model.
335+ for param_a , param_b in zip (
336+ model_a .parameters (), model_b .parameters ()
337+ ):
358338 val_a_local = param_a .numpy ()
359339 val_b = param_b .numpy ()
360340
361- # 2. 判断是否是 TP 参数 (通过形状是否一致判断)
362- is_tp_param = val_a_local .shape != val_b .shape
363-
364- if is_tp_param :
365- # 分布式参数:需要 Gather
341+ # TP parameters are split across TP ranks; gather before comparing.
342+ if val_a_local .shape != val_b .shape :
366343 gathered_list = []
367344 paddle .distributed .all_gather (
368345 gathered_list , param_a , group = tp_group
369346 )
370-
371- # 3. 拼接
372- # 关键:ColumnParallelLinear 是按列切分 (axis=1)
373- # RowParallelLinear 是按行切分 (axis=0)
374- # VocabParallelEmbedding 是按行切分 (axis=0)
375-
376- # 简单的启发式判断拼接维度:
377- # 看看 val_b 的哪个维度是 val_a 的 mp_degree 倍
347+ # Determine the split axis: find which dim is mp_degree times smaller.
378348 concat_axis = - 1
379349 for dim in range (len (val_b .shape )):
380350 if (
@@ -383,43 +353,21 @@ def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
383353 ):
384354 concat_axis = dim
385355 break
386-
387356 if concat_axis == - 1 :
388- print (
389- f"[Warning] Param { name } shape mismatch but axis not found. Skip."
390- )
391357 continue
392-
393- # 拼接
394358 val_a_global = np .concatenate (
395359 [t .numpy () for t in gathered_list ], axis = concat_axis
396360 )
397361 else :
398- # 非分布式参数 (或者 Sharding Only 参数),本地即全量
399362 val_a_global = val_a_local
400363
401- # 4. 计算误差
402- # 注意:由于 O2 精度问题,对比时可能需要 cast 到 float32
403- diff = np .abs (val_a_global - val_b )
404- max_abs_err = diff .max ()
405- max_rel_err = max_abs_err / (np .abs (val_b ).max () + 1e-9 )
406-
407- # 5. 打印与断言
408- # 只在 Rank 0 打印,避免刷屏
409- if dist .get_rank () == 0 :
410- print (
411- f"Param: { name } | Type: { 'TP' if is_tp_param else 'Global' } | Shape: { val_a_global .shape } "
412- )
413- print (f" Max Abs Error: { max_abs_err :.2e} " )
414- print (f" Max Rel Error: { max_rel_err :.2e} " )
415-
416- # 设置相对宽松的阈值 (因为 O2 累积误差)
364+ # Loose tolerance to account for O2 AMP accumulated error.
417365 np .testing .assert_allclose (
418366 val_a_global ,
419367 val_b ,
420368 rtol = 1e-4 ,
421369 atol = 1e-4 ,
422- err_msg = f"Param { name } mismatch!" ,
370+ err_msg = f"Param { param_a . name } mismatch!" ,
423371 )
424372
425373 def test_sharding_muon (self ):
0 commit comments