Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -738,14 +738,13 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
'''
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op_without_sync(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
assert op.type == 'fill_constant', \
"loss_grad_op must be fill_constant op, " \
"but this op is {}".format(op.type)
assert op.has_attr('value')
loss_scale = float(op.attr('value'))
loss_scale = loss_scale / scale
op._set_attr('value', loss_scale)
break


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _insert_loss_grad_scale_op(self):
global_dp_degree = self.sharding_degree * self.dp_degree
assert int(global_dp_degree) == global_dp_degree
if global_dp_degree > 1:
insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
insert_scale_loss_grad_ops(main_block, scale=global_dp_degree)

main_block._sync_with_cpp()

Expand Down
17 changes: 7 additions & 10 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5019,16 +5019,13 @@ def _insert_loss_scale(self, block):
if self._num_microbatches == 1: return
for index, op in reversed(tuple(enumerate(list(block.ops)))):
if self._is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
index=index + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / self._num_microbatches,
self._op_role_key: self._op_role.Backward
})
assert op.type == 'fill_constant', \
"loss_grad_op must be fill_constant op, " \
"but this op is {}".format(op.type)
assert op.has_attr('value')
loss_scale = float(op.attr('value'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind the potential precision loss here. fill_constant op will cast the value into fp32 and then save as string into its OpDesc. reload & reset this value might cause precision loss when the denominator is odd (3,7, 11, etc)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
image
loss_grad_op(that is fill_constant) use value instead of str_value, which AttrType is float. value will be saved with float32 in protobuf. If we encount precision problem, I this this must be caused by float AttrType, and I think double AttrType is better, which framework does not provide.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
the type of op.attr('value') is already float64 in python, add float(op.attr('value')) is only for explicit.

loss_scale = loss_scale / self._num_microbatches
op._set_attr('value', loss_scale)
break

def _rename_gradient_var_name(self, block):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import unittest

from fleet_meta_optimizer_base import TestFleetMetaOptimizer
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op

paddle.enable_static()

Expand Down Expand Up @@ -77,10 +78,10 @@ def test_opt_sharding_with_pp(self):
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
Expand Down Expand Up @@ -158,10 +159,10 @@ def test_opt_sharding_with_pp_with_allreduce_fuse(self):
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
Expand Down Expand Up @@ -220,8 +221,8 @@ def test_opt_sharding_with_pp_amp_gclip(self):
'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast',
'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'fill_constant', 'elementwise_mul_grad', 'mean_grad',
'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
Expand Down Expand Up @@ -293,23 +294,23 @@ def test_opt_sharding_with_pp_amp_gclip_fuse_gm(self):
'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
'coalesce_tensor', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'c_sync_calc_stream', 'send_v2', 'cast', 'sum', 'cast', 'sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'check_finite_and_unscale', 'cast', 'c_allreduce_max',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'squared_l2_norm',
'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm',
'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
'send_v2', 'cast', 'sum', 'cast', 'sum', 'c_reduce_sum',
'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'squared_l2_norm', 'squared_l2_norm',
'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
'c_allreduce_sum', 'c_allreduce_sum', 'sqrt', 'fill_constant',
'elementwise_max', 'elementwise_div', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum', 'momentum', 'momentum', 'coalesce_tensor',
'c_broadcast', 'coalesce_tensor', 'c_broadcast'
'elementwise_mul', 'momentum', 'momentum', 'momentum', 'momentum',
'momentum', 'coalesce_tensor', 'c_broadcast', 'coalesce_tensor',
'c_broadcast'
])


Expand All @@ -327,7 +328,10 @@ def setUp(self):
self._debug = False

def test_opt_sharding_with_pp_amp_gclip_boundary(self):
""" test optimizer sharding without parameter """
"""
test optimizer sharding without parameter
test loss grad scale value
"""
train_prog, startup_prog = static.Program(), static.Program()
avg_cost, strategy = self.boundary_net(train_prog, startup_prog)

Expand Down Expand Up @@ -357,6 +361,16 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]

# check loss scale for hybrid
for op in main_prog_ops:
if is_loss_grad_op(op):
self.assertEqual(op.type, 'fill_constant')
self.assertTrue(op.has_attr('value'))
scale = strategy.pipeline_configs[
'accumulate_steps'] * strategy.sharding_configs['dp_degree']
loss_scale = 1.0 / scale
self.assertAlmostEqual(float(op.attr('value')), loss_scale)

# global, sharding, pp_send, pp_recv
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
Expand All @@ -367,14 +381,13 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):

self.assertEqual(main_prog_op_types, [
'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'c_broadcast'
'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
'fill_constant', 'c_allreduce_sum', 'c_allreduce_sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div', 'c_broadcast'
])

def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
Expand Down Expand Up @@ -419,14 +432,14 @@ def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):

self.assertEqual(main_prog_op_types, [
'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'squared_l2_norm', 'sum', 'c_allreduce_sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'momentum', 'c_broadcast'
'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'momentum', 'c_broadcast'
])


Expand Down
Loading