62 changes: 49 additions & 13 deletions python/paddle/fluid/clip.py
@@ -19,18 +19,46 @@
import warnings

import functools
import paddle
from . import layers
from . import framework
from . import core
from . import name_scope
from .dygraph import base as imperative_base
from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode
from .layer_helper import LayerHelper

__all__ = [
'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
'ClipGradByNorm', 'ClipGradByGlobalNorm'
]


def _squared_l2_norm(x):
r"""
This OP returns the squared L2 norm of a tensor.
"""

if core.is_compiled_with_npu() or core.is_compiled_with_xpu():
square = layers.square(x)
sum_square = layers.reduce_sum(square)
return sum_square

if in_dygraph_mode():
return core.ops.squared_l2_norm(x)

op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32'], op_type)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)

inputs = {"X": x}
outputs = {'Out': out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
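For orientation (this example is not part of the diff), the new helper fuses the old two-op pattern into a single kernel; on NPU/XPU builds, where the fused op is assumed unavailable, it falls back to the original square plus reduce_sum pair. A minimal dygraph sketch, assuming a Paddle 2.x build with this change applied:

import numpy as np
import paddle
from paddle.fluid import layers
from paddle.fluid.clip import _squared_l2_norm

paddle.disable_static()  # run the comparison eagerly
x = paddle.to_tensor(np.array([3.0, 4.0], dtype='float32'))
fused = _squared_l2_norm(x)                    # single squared_l2_norm op
manual = layers.reduce_sum(layers.square(x))   # old square -> reduce_sum path
print(float(fused), float(manual))             # both print 25.0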


class BaseErrorClipAttr(object):
def __str__(self):
raise NotImplementedError()
@@ -416,8 +444,8 @@ def _dygraph_clip(self, params_grads):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)

sum_square = _squared_l2_norm(merge_grad)
sum_square_list.append(sum_square)

# all parameters have been filtered out
@@ -439,6 +467,7 @@ def _dygraph_clip(self, params_grads):
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
# TODO(wangxi): use inplace elementwise_mul
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))
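As context for the elementwise_mul above (illustrative only, not code from this file): clip_var is a one-element tensor derived from the squared norms collected into sum_square_list earlier in the method, and every surviving gradient is multiplied by the same scalar. A rough numpy sketch of that formula:

import numpy as np

def global_norm_clip_scale(sum_square_list, clip_norm):
    # one squared L2 norm per gradient, as produced by _squared_l2_norm
    global_norm = np.sqrt(np.sum(sum_square_list))
    # gradients are untouched while global_norm <= clip_norm, shrunk proportionally otherwise
    return clip_norm / np.maximum(clip_norm, global_norm)

print(global_norm_clip_scale(np.array([9.0, 16.0]), clip_norm=1.0))  # -> 0.2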

@@ -460,8 +489,7 @@ def _static_clip(self, params_grads):
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad)

square = layers.square(merge_grad)
sum_square = layers.reduce_sum(input=square)
sum_square = _squared_l2_norm(merge_grad)
sum_square_list.append(sum_square)

# all parameters have been filtered out
@@ -489,9 +517,14 @@ def _static_clip(self, params_grads):
continue

with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
# inplace
p.block.append_op(
type='elementwise_mul',
inputs={'X': g,
'Y': scale_var},
outputs={'Out': g})
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))

_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
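The switch above from layers.elementwise_mul to a hand-appended op is what makes the scaling in-place: layers.elementwise_mul would create a fresh output variable per gradient, while wiring the op's Out to g reuses the gradient's own buffer. A hedged static-graph sketch of the pattern (variable names are invented for illustration):

import paddle
import paddle.fluid as fluid

paddle.enable_static()
block = fluid.default_main_program().global_block()
g = block.create_var(name='grad_tmp', shape=[8], dtype='float32')
scale = block.create_var(name='scale_tmp', shape=[1], dtype='float32')
block.append_op(
    type='elementwise_mul',
    inputs={'X': g, 'Y': scale},
    outputs={'Out': g})  # Out aliases X, so no extra gradient-sized tensor is allocated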
@@ -513,8 +546,7 @@ def _process_context(self, context, param, grad):
merge_grad = layers.merge_selected_rows(grad)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

square = layers.square(merge_grad)
local_norm_var = layers.reduce_sum(input=square)
local_norm_var = _squared_l2_norm(merge_grad)
context[self.group_name].append(local_norm_var)

self.context = context
@@ -532,10 +564,14 @@ def _create_operators(self, param, grad):
assert group_scale_var.shape == (1, )
self.context[group_scale_name] = group_scale_var

new_grad = layers.elementwise_mul(
x=grad, y=self.context[group_scale_name])
# inplace
param.block.append_op(
type='elementwise_mul',
inputs={'X': grad,
'Y': self.context[group_scale_name]},
outputs={'Out': grad})

return param, new_grad
return param, grad
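This is the legacy set_gradient_clip path; the same in-place multiplication is applied here, and the scale is computed once per group_name and cached in self.context so every gradient in the group shares it. A small illustrative sketch of that caching idea (names invented, not the actual implementation):

def get_group_scale(context, group_name, local_sq_norms, clip_norm):
    # compute the shared scale lazily on the first call for this group
    key = group_name + "_scale"
    if key not in context:
        group_norm = sum(local_sq_norms) ** 0.5
        context[key] = clip_norm / max(clip_norm, group_norm)
    return context[key]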


@framework.dygraph_not_support
@@ -709,7 +745,7 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
"op_namescope") and op.attr('op_role_var'):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
@@ -264,8 +264,8 @@ def test_sharding_gradient_clip(self):
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'momentum', 'momentum', 'momentum'
19 changes: 13 additions & 6 deletions python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -22,6 +22,8 @@
import six
from fake_reader import fake_imdb_reader

paddle.enable_static()


def bow_net(data,
label,
@@ -149,7 +151,7 @@ def clip_gradient(self, params_grads):
def check_clip_result(self, out, out_clip):
global_norm = 0
for v in out:
global_norm += np.sum(np.power(v, 2))
global_norm += np.sum(np.square(v))
global_norm = np.sqrt(global_norm)
scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
res = []
@@ -160,7 +162,8 @@ def check_clip_result(self, out, out_clip):
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-5, atol=1e-8),
"gradient clip by global norm has wrong results!")
"gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}".
format(u, v, u - v))

# test whether the output is right when using 'set_gradient_clip'
def test_old_gradient_clip(self):
@@ -210,12 +213,16 @@ def test_none_grad(self):
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 2,
len(params_grads) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
self.assertTrue(
params_grads[0][1].name != 'y',
"ClipByGlobalNorm: param_grad (x, y) should be clipped!")

ops = [op.type for op in x.block.ops]
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
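For readers tracing the new assertion, one reading of how that op sequence maps onto the global-norm pipeline (annotation added here, not text from the PR):

# squared_l2_norm, squared_l2_norm  -> one fused norm op per surviving gradient
# sum, sqrt                         -> combine the squared norms into the global norm
# fill_constant, elementwise_max    -> build max(global_norm, clip_norm)
# elementwise_div                   -> scale = clip_norm / max(global_norm, clip_norm)
# elementwise_mul, elementwise_mul  -> apply the scale to each gradient, now in place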

# raise typeError
def test_tpyeError(self):