Skip to content

Commit d55120d

Browse files
authored
[NPU] Support testing grad of NPU ops in OpTest (#31697)
1 parent e424712 commit d55120d

File tree

4 files changed

+38
-33
lines changed

4 files changed

+38
-33
lines changed

python/paddle/fluid/tests/unittests/npu/test_elementwise_add_op_npu.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,28 @@ def init_axis(self):
6464
def test_check_output(self):
6565
self.check_output_with_place(self.place, check_dygraph=False)
6666

67-
# TODO(ascendrc): Test grad op after it is implemented.
68-
# def test_check_grad_normal(self):
69-
# self.check_grad_with_place(
70-
# self.place, ['X', 'Y'],
71-
# 'Out',
72-
# max_relative_error=0.006,
73-
# check_dygraph=False)
74-
#
75-
# def test_check_grad_ingore_x(self):
76-
# self.check_grad_with_place(
77-
# self.place, ['Y'],
78-
# 'Out',
79-
# no_grad_set=set("X"),
80-
# max_relative_error=0.006,
81-
# check_dygraph=False)
82-
#
83-
# def test_check_grad_ingore_y(self):
84-
# self.check_grad_with_place(
85-
# self.place, ['X'],
86-
# 'Out',
87-
# no_grad_set=set("Y"),
88-
# max_relative_error=0.006,check_dygraph=False)
67+
def test_check_grad_normal(self):
68+
self.check_grad_with_place(
69+
self.place, ['X', 'Y'],
70+
'Out',
71+
max_relative_error=0.006,
72+
check_dygraph=False)
73+
74+
def test_check_grad_ingore_x(self):
75+
self.check_grad_with_place(
76+
self.place, ['Y'],
77+
'Out',
78+
no_grad_set=set("X"),
79+
max_relative_error=0.006,
80+
check_dygraph=False)
81+
82+
def test_check_grad_ingore_y(self):
83+
self.check_grad_with_place(
84+
self.place, ['X'],
85+
'Out',
86+
no_grad_set=set("Y"),
87+
max_relative_error=0.006,
88+
check_dygraph=False)
8989

9090

9191
@unittest.skipIf(not paddle.is_compiled_with_npu(),
@@ -133,10 +133,6 @@ def test_static(self):
133133
True,
134134
msg="z_value = {}, but expected {}".format(z_value, z_expected))
135135

136-
def test_backward(self):
137-
# TODO(ascendrc): Test backward after add grad npu op implemented.
138-
pass
139-
140136

141137
@unittest.skipIf(not paddle.is_compiled_with_npu(),
142138
"core is not compiled with NPU")

python/paddle/fluid/tests/unittests/npu/test_pow_op_npu.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,9 @@ def init_dtype(self):
5252
def test_check_output(self):
5353
self.check_output_with_place(self.place, check_dygraph=False)
5454

55-
# TODO(ascendrc): Add grad test
56-
# def test_check_grad(self):
57-
# if self.dtype == np.float16:
58-
# return
59-
# self.check_grad(['X'], 'Out')
60-
#
55+
def test_check_grad(self):
56+
self.check_grad_with_place(
57+
self.place, ['X'], 'Out', check_dygraph=False)
6158

6259

6360
@unittest.skipIf(not paddle.is_compiled_with_npu(),

python/paddle/fluid/tests/unittests/npu/test_slice_op_npu.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def set_npu(self):
6262
def test_check_output(self):
6363
self.check_output_with_place(self.place, check_dygraph=False)
6464

65+
def test_check_grad_normal(self):
66+
self.check_grad_with_place(
67+
self.place, ['Input'], 'Out', check_dygraph=False)
68+
6569

6670
@unittest.skipIf(not paddle.is_compiled_with_npu(),
6771
"core is not compiled with NPU")

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,9 +1416,17 @@ def check_grad_with_place(self,
14161416
if not type(output_names) is list:
14171417
output_names = [output_names]
14181418

1419+
# FIXME: Replace numeric_place with place to calculate numeric_grads.
1420+
# NOTE(liym27): There is an unknown error when call op.run() on NPUPlace, which
1421+
# needs to be fixed.
1422+
if self.__class__.use_npu == True:
1423+
numeric_place = paddle.CPUPlace()
1424+
else:
1425+
numeric_place = place
1426+
14191427
numeric_grads = user_defined_grads or [
14201428
get_numeric_gradient(
1421-
place,
1429+
numeric_place,
14221430
self.scope,
14231431
self.op,
14241432
self.inputs,

0 commit comments

Comments
 (0)