Skip to content
12 changes: 9 additions & 3 deletions python/paddle/nn/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def clip_by_norm(x, max_norm, name=None):
[0.50000000, 0.50000000]])
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.clip_by_norm(x, max_norm)

helper = LayerHelper("clip_by_norm", **locals())
Expand Down Expand Up @@ -528,8 +528,7 @@ def __init__(self, clip_norm):
def __str__(self):
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
def _clip_gradients(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
Expand All @@ -541,6 +540,13 @@ def _dygraph_clip(self, params_grads):
params_and_grads.append((p, new_grad))
return params_and_grads

@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
return self._clip_gradients(params_grads)

def _pir_clip(self, params_grads):
return self._clip_gradients(params_grads)

def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None):
f"But recevie Attr(data_format): {data_format} "
)

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
return _C_ops.channel_shuffle(x, groups, data_format)

helper = LayerHelper("channel_shuffle", **locals())
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/pir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,15 @@ def create_parameter(
**kwargs,
):
regularizer = None
need_clip = None
if 'initializer' not in kwargs:
raise ValueError(
"initializer is None, if you want to create parameter, please pass its initializer."
)
if 'regularizer' in kwargs:
regularizer = kwargs['regularizer']
if 'need_clip' in kwargs:
need_clip = kwargs['need_clip']
if dtype is not None:
if not isinstance(dtype, DataType):
dtype = convert_np_dtype_to_dtype_(dtype)
Expand Down Expand Up @@ -308,6 +311,7 @@ def create_parameter(
param.persistable = True

param.regularizer = regularizer
param.need_clip = need_clip
return param


Expand Down
184 changes: 110 additions & 74 deletions test/legacy_test/test_channel_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import paddle
import paddle.nn.functional as F
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def channel_shuffle_np(x, groups, data_format="NCHW"):
Expand Down Expand Up @@ -71,10 +71,10 @@ def init_data_format(self):
self.format = "NCHW"

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestChannelLast(TestChannelShuffleOp):
Expand All @@ -84,84 +84,122 @@ def init_data_format(self):

class TestChannelShuffleAPI(unittest.TestCase):
def setUp(self):
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)
self.out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)

@test_with_pir_api
def test_static_graph_functional(self):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
)
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
)
out_1 = F.channel_shuffle(x_1, 3)
out_2 = F.channel_shuffle(x_2, 3, "NHWC")

exe = paddle.static.Executor(place=place)
res_1 = exe.run(
base.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True,
)

res_2 = exe.run(
base.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True,
)
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
)
out_1 = F.channel_shuffle(x_1, 3)

exe = paddle.static.Executor(place=place)
res_1 = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True,
)

np.testing.assert_allclose(res_1[0], self.out_1_np)
np.testing.assert_allclose(res_2[0], self.out_2_np)
np.testing.assert_allclose(res_1[0], self.out_1_np)

# same test between layer and functional in this op.
@test_with_pir_api
def test_static_graph_layer(self):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
)
# init instance
ps_1 = paddle.nn.ChannelShuffle(3)
out_1 = ps_1(x_1)
out_1_np = channel_shuffle_np(self.x_1_np, 3)

exe = paddle.static.Executor(place=place)
res_1 = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True,
)

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
)
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
)
# init instance
ps_1 = paddle.nn.ChannelShuffle(3)
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
out_1 = ps_1(x_1)
out_2 = ps_2(x_2)
out_1_np = channel_shuffle_np(self.x_1_np, 3)
out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")

exe = paddle.static.Executor(place=place)
res_1 = exe.run(
base.default_main_program(),
feed={"x": self.x_1_np},
fetch_list=out_1,
use_prune=True,
)
np.testing.assert_allclose(res_1[0], out_1_np)

res_2 = exe.run(
base.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True,
)
@test_with_pir_api
def test_static_graph_functional_new(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()

paddle.enable_static()
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
)
out_2 = F.channel_shuffle(x_2, 3, "NHWC")

exe = paddle.static.Executor(place=place)
res_2 = exe.run(
paddle.static.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True,
)

np.testing.assert_allclose(res_2[0], self.out_2_np)

@test_with_pir_api
def test_static_graph_layer_new(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()

paddle.enable_static()
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
)
# init instance
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
out_2 = ps_2(x_2)
out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")

exe = paddle.static.Executor(place=place)

res_2 = exe.run(
paddle.static.default_main_program(),
feed={"x2": self.x_2_np},
fetch_list=out_2,
use_prune=True,
)

np.testing.assert_allclose(res_1[0], out_1_np)
np.testing.assert_allclose(res_2[0], out_2_np)
np.testing.assert_allclose(res_2[0], out_2_np)

def run_dygraph(self, groups, data_format):
n, c, h, w = 2, 9, 4, 4
Expand Down Expand Up @@ -209,6 +247,7 @@ def test_dygraph2(self):


class TestChannelShuffleError(unittest.TestCase):
@test_with_pir_api
def test_error_functional(self):
def error_input():
with paddle.base.dygraph.guard():
Expand Down Expand Up @@ -240,6 +279,7 @@ def error_data_format():

self.assertRaises(ValueError, error_data_format)

@test_with_pir_api
def test_error_layer(self):
def error_input_layer():
with paddle.base.dygraph.guard():
Expand Down Expand Up @@ -308,15 +348,11 @@ def init_data_format(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
'Out',
)
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == '__main__':
Expand Down
8 changes: 5 additions & 3 deletions test/legacy_test/test_clip_by_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle
from paddle.base import core
from paddle.nn import clip
from paddle.pir_utils import test_with_pir_api


class TestClipByNormOp(OpTest):
Expand All @@ -45,7 +46,7 @@ def setUp(self):
self.outputs = {'Out': output}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def initTestCase(self):
self.shape = (100,)
Expand Down Expand Up @@ -81,7 +82,7 @@ def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=0.001)
self.check_output_with_place(place, atol=0.001, check_pir=True)


class TestClipByNormOpFp16Case1(TestClipByNormOpFp16):
Expand Down Expand Up @@ -133,7 +134,7 @@ def setUp(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def initTestCase(self):
self.shape = (100,)
Expand Down Expand Up @@ -186,6 +187,7 @@ def check_with_place(self, place):
equal_nan=False,
)

@test_with_pir_api
def test_clip_by_norm_with_selected_ros(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
Expand Down
Loading