Commit e69cc21

[cherry-pick][hybrid performance] optim the grad fuse for pipeline mode by sorting the grad by dtype (#35070) (#35300)
1 parent e931cd1 · commit e69cc21

File tree

1 file changed: +21 -0 lines changed


python/paddle/fluid/optimizer.py

Lines changed: 21 additions & 0 deletions
@@ -5223,6 +5223,9 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
         if len(grad_param_pairs) == 0:
             return
 
+        grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
+                                                          grad_param_pairs)
+
         grad_param_segments = []
         merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED'
         dtype = paddle.float16 if fp16 else paddle.float32
@@ -5416,6 +5419,24 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
 
         return fused_merged_gradients
 
+    def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
+        # sort the grad param pairs by dtype
+        fp16_pairs = []
+        fp32_pairs = []
+        other_pairs = []
+        for pairs in grad_param_pairs:
+            dtype = main_block.var(pairs[0]).dtype
+            if dtype == paddle.float32:
+                fp32_pairs.append(pairs)
+            elif dtype == paddle.float16:
+                fp16_pairs.append(pairs)
+            else:
+                other_pairs.append(pairs)
+        sorted_pairs = fp16_pairs
+        sorted_pairs.extend(fp32_pairs)
+        sorted_pairs.extend(other_pairs)
+        return sorted_pairs
+
     def _get_var_size(self, var):
         dtype_to_size = {
             core.VarDesc.VarType.FP16: 2,
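
For orientation, here is a minimal, standalone sketch of the dtype-bucketing idea the new helper implements. It is an illustrative rewrite under assumptions, not Paddle API: a plain var_dtypes dict stands in for the real main_block.var(name).dtype lookup, and the name sort_pairs_by_dtype is hypothetical. Keeping same-dtype gradients contiguous is what lets the fuse pass pack each group together, presumably because a fused buffer holds a single dtype.

# Minimal sketch (hypothetical helper, not Paddle code). var_dtypes
# replaces the real main_block.var(name).dtype lookup.
def sort_pairs_by_dtype(var_dtypes, grad_param_pairs):
    # Bucket (grad_name, param_name) pairs: fp16 grads first, then fp32,
    # then the rest. Order inside each bucket is preserved, so this is a
    # stable regrouping rather than a full sort.
    fp16_pairs, fp32_pairs, other_pairs = [], [], []
    for pair in grad_param_pairs:
        dtype = var_dtypes[pair[0]]  # pair[0] is the gradient's name
        if dtype == "float16":
            fp16_pairs.append(pair)
        elif dtype == "float32":
            fp32_pairs.append(pair)
        else:
            other_pairs.append(pair)
    return fp16_pairs + fp32_pairs + other_pairs

if __name__ == "__main__":
    var_dtypes = {"w0@GRAD": "float32", "w1@GRAD": "float16",
                  "w2@GRAD": "float32", "w3@GRAD": "float16"}
    pairs = [("w0@GRAD", "w0"), ("w1@GRAD", "w1"),
             ("w2@GRAD", "w2"), ("w3@GRAD", "w3")]
    # fp16 grads (w1, w3) now precede fp32 grads (w0, w2), so each dtype
    # group ends up contiguous before fusion.
    print(sort_pairs_by_dtype(var_dtypes, pairs))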
