Skip to content

Commit a65895c — "update test_expand_v2_op_npu.py" (1 parent: 32097e8)

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def test_check_output(self):
132132
self.check_output_with_place(self.place)
133133

134134
def test_check_grad(self):
135+
if self.dtype == np.float16:
136+
return
137+
135138
self.check_grad_with_place(self.place, ['X'], 'Out')
136139

137140

@@ -177,43 +180,45 @@ def test_check_grad(self):
177180

178181

179182
# Situation 4: input x is float16
180-
# don't support grad check for float16
181-
class TestExpandV2OpInteger(OpTest):
183+
# skip grad check for float16
184+
class TestExpandV2OpFloat(OpTest):
182185
def setUp(self):
183186
self.set_npu()
184187
self.place = paddle.NPUPlace(0)
185188
self.op_type = "expand_v2"
186189
self.dtype = np.float16
187-
self.ori_shape = (2, 4, 5)
190+
self.ori_shape = (2, 4, 20)
188191
self.inputs = {'X': np.random.random(self.ori_shape).astype(self.dtype)}
189-
self.attrs = {'shape': [2, 4, 5]}
192+
self.attrs = {'shape': [2, 4, 20]}
190193
output = np.tile(self.inputs['X'], (1, 1, 1))
191194
self.outputs = {'Out': output}
192195

193196
def set_npu(self):
194197
self.__class__.use_npu = True
198+
self.__class__.no_need_check_grad = True
195199

196200
def test_check_output(self):
197201
self.check_output_with_place(self.place)
198202

199203

200204
# Situation 5: input x is int32
201-
# ReduceSumD CANN Op doesn't support grad check for int32
205+
# skip grad check for int32
202206
class TestExpandV2OpInteger(OpTest):
203207
def setUp(self):
204208
self.set_npu()
205209
self.place = paddle.NPUPlace(0)
206210
self.op_type = "expand_v2"
207211
self.inputs = {
208212
'X': np.random.randint(
209-
10, size=(2, 4, 5)).astype("int32")
213+
10, size=(2, 4, 20)).astype("int32")
210214
}
211-
self.attrs = {'shape': [2, 4, 5]}
215+
self.attrs = {'shape': [2, 4, 20]}
212216
output = np.tile(self.inputs['X'], (1, 1, 1))
213217
self.outputs = {'Out': output}
214218

215219
def set_npu(self):
216220
self.__class__.use_npu = True
221+
self.__class__.no_need_check_grad = True
217222

218223
def test_check_output(self):
219224
self.check_output_with_place(self.place)

Commit comments: 0