Skip to content

Commit fb4f0ef

Browse files
authored
PIR API adaptor No.35、40】 Migrate paddle.nn.ChannelShuffle/ClipGradByNorm into pir (#60445)
* fix some bugs * fix bugs * Update clip.py * Update test_channel_shuffle.py * Update test_clip_by_norm_op.py * Update test_clip_by_norm_op.py
1 parent 33cb1be commit fb4f0ef

6 files changed

Lines changed: 279 additions & 81 deletions

File tree

python/paddle/nn/clip.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def clip_by_norm(x, max_norm, name=None):
6868
[0.50000000, 0.50000000]])
6969
"""
7070

71-
if in_dynamic_mode():
71+
if in_dynamic_or_pir_mode():
7272
return _C_ops.clip_by_norm(x, max_norm)
7373

7474
helper = LayerHelper("clip_by_norm", **locals())
@@ -528,8 +528,7 @@ def __init__(self, clip_norm):
528528
def __str__(self):
529529
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
530530

531-
@imperative_base.no_grad()
532-
def _dygraph_clip(self, params_grads):
531+
def _clip_gradients(self, params_grads):
533532
params_and_grads = []
534533
for p, g in params_grads:
535534
if g is None:
@@ -541,6 +540,13 @@ def _dygraph_clip(self, params_grads):
541540
params_and_grads.append((p, new_grad))
542541
return params_and_grads
543542

543+
@imperative_base.no_grad()
544+
def _dygraph_clip(self, params_grads):
545+
return self._clip_gradients(params_grads)
546+
547+
def _pir_clip(self, params_grads):
548+
return self._clip_gradients(params_grads)
549+
544550
def _static_clip(self, params_grads):
545551
params_and_grads = []
546552
with framework.name_scope('gradient_clip'):

python/paddle/nn/functional/vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None):
519519
f"But recevie Attr(data_format): {data_format} "
520520
)
521521

522-
if in_dygraph_mode():
522+
if in_dynamic_or_pir_mode():
523523
return _C_ops.channel_shuffle(x, groups, data_format)
524524

525525
helper = LayerHelper("channel_shuffle", **locals())

python/paddle/pir/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,15 @@ def create_parameter(
275275
**kwargs,
276276
):
277277
regularizer = None
278+
need_clip = None
278279
if 'initializer' not in kwargs:
279280
raise ValueError(
280281
"initializer is None, if you want to create parameter, please pass its initializer."
281282
)
282283
if 'regularizer' in kwargs:
283284
regularizer = kwargs['regularizer']
285+
if 'need_clip' in kwargs:
286+
need_clip = kwargs['need_clip']
284287
if dtype is not None:
285288
if not isinstance(dtype, DataType):
286289
dtype = convert_np_dtype_to_dtype_(dtype)
@@ -308,6 +311,7 @@ def create_parameter(
308311
param.persistable = True
309312

310313
param.regularizer = regularizer
314+
param.need_clip = need_clip
311315
return param
312316

313317

test/legacy_test/test_channel_shuffle.py

Lines changed: 110 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
import paddle
2121
import paddle.nn.functional as F
22-
from paddle import base
2322
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2424

2525

2626
def channel_shuffle_np(x, groups, data_format="NCHW"):
@@ -71,10 +71,10 @@ def init_data_format(self):
7171
self.format = "NCHW"
7272

7373
def test_check_output(self):
74-
self.check_output()
74+
self.check_output(check_pir=True)
7575

7676
def test_check_grad(self):
77-
self.check_grad(['X'], 'Out')
77+
self.check_grad(['X'], 'Out', check_pir=True)
7878

7979

8080
class TestChannelLast(TestChannelShuffleOp):
@@ -84,84 +84,122 @@ def init_data_format(self):
8484

8585
class TestChannelShuffleAPI(unittest.TestCase):
8686
def setUp(self):
87-
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
8887
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
89-
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)
9088
self.out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
89+
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
90+
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)
9191

92+
@test_with_pir_api
9293
def test_static_graph_functional(self):
93-
for use_cuda in (
94-
[False, True] if core.is_compiled_with_cuda() else [False]
94+
with paddle.static.program_guard(
95+
paddle.static.Program(), paddle.static.Program()
9596
):
96-
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
97-
98-
paddle.enable_static()
99-
x_1 = paddle.static.data(
100-
name="x", shape=[2, 9, 4, 4], dtype="float64"
101-
)
102-
x_2 = paddle.static.data(
103-
name="x2", shape=[2, 4, 4, 9], dtype="float64"
104-
)
105-
out_1 = F.channel_shuffle(x_1, 3)
106-
out_2 = F.channel_shuffle(x_2, 3, "NHWC")
107-
108-
exe = paddle.static.Executor(place=place)
109-
res_1 = exe.run(
110-
base.default_main_program(),
111-
feed={"x": self.x_1_np},
112-
fetch_list=out_1,
113-
use_prune=True,
114-
)
115-
116-
res_2 = exe.run(
117-
base.default_main_program(),
118-
feed={"x2": self.x_2_np},
119-
fetch_list=out_2,
120-
use_prune=True,
121-
)
97+
for use_cuda in (
98+
[False, True] if core.is_compiled_with_cuda() else [False]
99+
):
100+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
101+
102+
paddle.enable_static()
103+
x_1 = paddle.static.data(
104+
name="x", shape=[2, 9, 4, 4], dtype="float64"
105+
)
106+
out_1 = F.channel_shuffle(x_1, 3)
107+
108+
exe = paddle.static.Executor(place=place)
109+
res_1 = exe.run(
110+
paddle.static.default_main_program(),
111+
feed={"x": self.x_1_np},
112+
fetch_list=out_1,
113+
use_prune=True,
114+
)
122115

123-
np.testing.assert_allclose(res_1[0], self.out_1_np)
124-
np.testing.assert_allclose(res_2[0], self.out_2_np)
116+
np.testing.assert_allclose(res_1[0], self.out_1_np)
125117

126118
# same test between layer and functional in this op.
119+
@test_with_pir_api
127120
def test_static_graph_layer(self):
128-
for use_cuda in (
129-
[False, True] if core.is_compiled_with_cuda() else [False]
121+
with paddle.static.program_guard(
122+
paddle.static.Program(), paddle.static.Program()
130123
):
131-
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
124+
for use_cuda in (
125+
[False, True] if core.is_compiled_with_cuda() else [False]
126+
):
127+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
128+
129+
paddle.enable_static()
130+
x_1 = paddle.static.data(
131+
name="x", shape=[2, 9, 4, 4], dtype="float64"
132+
)
133+
# init instance
134+
ps_1 = paddle.nn.ChannelShuffle(3)
135+
out_1 = ps_1(x_1)
136+
out_1_np = channel_shuffle_np(self.x_1_np, 3)
137+
138+
exe = paddle.static.Executor(place=place)
139+
res_1 = exe.run(
140+
paddle.static.default_main_program(),
141+
feed={"x": self.x_1_np},
142+
fetch_list=out_1,
143+
use_prune=True,
144+
)
132145

133-
paddle.enable_static()
134-
x_1 = paddle.static.data(
135-
name="x", shape=[2, 9, 4, 4], dtype="float64"
136-
)
137-
x_2 = paddle.static.data(
138-
name="x2", shape=[2, 4, 4, 9], dtype="float64"
139-
)
140-
# init instance
141-
ps_1 = paddle.nn.ChannelShuffle(3)
142-
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
143-
out_1 = ps_1(x_1)
144-
out_2 = ps_2(x_2)
145-
out_1_np = channel_shuffle_np(self.x_1_np, 3)
146-
out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
147-
148-
exe = paddle.static.Executor(place=place)
149-
res_1 = exe.run(
150-
base.default_main_program(),
151-
feed={"x": self.x_1_np},
152-
fetch_list=out_1,
153-
use_prune=True,
154-
)
146+
np.testing.assert_allclose(res_1[0], out_1_np)
155147

156-
res_2 = exe.run(
157-
base.default_main_program(),
158-
feed={"x2": self.x_2_np},
159-
fetch_list=out_2,
160-
use_prune=True,
161-
)
148+
@test_with_pir_api
149+
def test_static_graph_functional_new(self):
150+
with paddle.static.program_guard(
151+
paddle.static.Program(), paddle.static.Program()
152+
):
153+
for use_cuda in (
154+
[False, True] if core.is_compiled_with_cuda() else [False]
155+
):
156+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
157+
158+
paddle.enable_static()
159+
x_2 = paddle.static.data(
160+
name="x2", shape=[2, 4, 4, 9], dtype="float64"
161+
)
162+
out_2 = F.channel_shuffle(x_2, 3, "NHWC")
163+
164+
exe = paddle.static.Executor(place=place)
165+
res_2 = exe.run(
166+
paddle.static.default_main_program(),
167+
feed={"x2": self.x_2_np},
168+
fetch_list=out_2,
169+
use_prune=True,
170+
)
171+
172+
np.testing.assert_allclose(res_2[0], self.out_2_np)
173+
174+
@test_with_pir_api
175+
def test_static_graph_layer_new(self):
176+
with paddle.static.program_guard(
177+
paddle.static.Program(), paddle.static.Program()
178+
):
179+
for use_cuda in (
180+
[False, True] if core.is_compiled_with_cuda() else [False]
181+
):
182+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
183+
184+
paddle.enable_static()
185+
x_2 = paddle.static.data(
186+
name="x2", shape=[2, 4, 4, 9], dtype="float64"
187+
)
188+
# init instance
189+
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
190+
out_2 = ps_2(x_2)
191+
out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
192+
193+
exe = paddle.static.Executor(place=place)
194+
195+
res_2 = exe.run(
196+
paddle.static.default_main_program(),
197+
feed={"x2": self.x_2_np},
198+
fetch_list=out_2,
199+
use_prune=True,
200+
)
162201

163-
np.testing.assert_allclose(res_1[0], out_1_np)
164-
np.testing.assert_allclose(res_2[0], out_2_np)
202+
np.testing.assert_allclose(res_2[0], out_2_np)
165203

166204
def run_dygraph(self, groups, data_format):
167205
n, c, h, w = 2, 9, 4, 4
@@ -209,6 +247,7 @@ def test_dygraph2(self):
209247

210248

211249
class TestChannelShuffleError(unittest.TestCase):
250+
@test_with_pir_api
212251
def test_error_functional(self):
213252
def error_input():
214253
with paddle.base.dygraph.guard():
@@ -240,6 +279,7 @@ def error_data_format():
240279

241280
self.assertRaises(ValueError, error_data_format)
242281

282+
@test_with_pir_api
243283
def test_error_layer(self):
244284
def error_input_layer():
245285
with paddle.base.dygraph.guard():
@@ -308,15 +348,11 @@ def init_data_format(self):
308348

309349
def test_check_output(self):
310350
place = core.CUDAPlace(0)
311-
self.check_output_with_place(place)
351+
self.check_output_with_place(place, check_pir=True)
312352

313353
def test_check_grad(self):
314354
place = core.CUDAPlace(0)
315-
self.check_grad_with_place(
316-
place,
317-
['X'],
318-
'Out',
319-
)
355+
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)
320356

321357

322358
if __name__ == '__main__':

test/legacy_test/test_clip_by_norm_op.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import paddle
2222
from paddle.base import core
2323
from paddle.nn import clip
24+
from paddle.pir_utils import test_with_pir_api
2425

2526

2627
class TestClipByNormOp(OpTest):
@@ -45,7 +46,7 @@ def setUp(self):
4546
self.outputs = {'Out': output}
4647

4748
def test_check_output(self):
48-
self.check_output()
49+
self.check_output(check_pir=True)
4950

5051
def initTestCase(self):
5152
self.shape = (100,)
@@ -81,7 +82,7 @@ def test_check_output(self):
8182
if core.is_compiled_with_cuda():
8283
place = core.CUDAPlace(0)
8384
if core.is_float16_supported(place):
84-
self.check_output_with_place(place, atol=0.001)
85+
self.check_output_with_place(place, atol=0.001, check_pir=True)
8586

8687

8788
class TestClipByNormOpFp16Case1(TestClipByNormOpFp16):
@@ -133,7 +134,7 @@ def setUp(self):
133134
self.place = core.CUDAPlace(0)
134135

135136
def test_check_output(self):
136-
self.check_output_with_place(self.place)
137+
self.check_output_with_place(self.place, check_pir=True)
137138

138139
def initTestCase(self):
139140
self.shape = (100,)
@@ -186,6 +187,7 @@ def check_with_place(self, place):
186187
equal_nan=False,
187188
)
188189

190+
@test_with_pir_api
189191
def test_clip_by_norm_with_selected_ros(self):
190192
places = [core.CPUPlace()]
191193
if core.is_compiled_with_cuda():

0 commit comments

Comments
 (0)