@@ -132,6 +132,9 @@ def test_check_output(self):
132132 self .check_output_with_place (self .place )
133133
134134 def test_check_grad (self ):
135+ if self .dtype == np .float16 :
136+ return
137+
135138 self .check_grad_with_place (self .place , ['X' ], 'Out' )
136139
137140
@@ -177,43 +180,45 @@ def test_check_grad(self):
177180
178181
179182# Situation 4: input x is float16
180- # don't support grad check for float16
181- class TestExpandV2OpInteger (OpTest ):
183+ # skip grad check for float16
184+ class TestExpandV2OpFloat (OpTest ):
182185 def setUp (self ):
183186 self .set_npu ()
184187 self .place = paddle .NPUPlace (0 )
185188 self .op_type = "expand_v2"
186189 self .dtype = np .float16
187- self .ori_shape = (2 , 4 , 5 )
190+ self .ori_shape = (2 , 4 , 20 )
188191 self .inputs = {'X' : np .random .random (self .ori_shape ).astype (self .dtype )}
189- self .attrs = {'shape' : [2 , 4 , 5 ]}
192+ self .attrs = {'shape' : [2 , 4 , 20 ]}
190193 output = np .tile (self .inputs ['X' ], (1 , 1 , 1 ))
191194 self .outputs = {'Out' : output }
192195
193196 def set_npu (self ):
194197 self .__class__ .use_npu = True
198+ self .__class__ .no_need_check_grad = True
195199
196200 def test_check_output (self ):
197201 self .check_output_with_place (self .place )
198202
199203
200204# Situation 5: input x is int32
201- # ReduceSumD CANN Op doesn't support grad check for int32
205+ # skip grad check for int32
202206class TestExpandV2OpInteger (OpTest ):
203207 def setUp (self ):
204208 self .set_npu ()
205209 self .place = paddle .NPUPlace (0 )
206210 self .op_type = "expand_v2"
207211 self .inputs = {
208212 'X' : np .random .randint (
209- 10 , size = (2 , 4 , 5 )).astype ("int32" )
213+ 10 , size = (2 , 4 , 20 )).astype ("int32" )
210214 }
211- self .attrs = {'shape' : [2 , 4 , 5 ]}
215+ self .attrs = {'shape' : [2 , 4 , 20 ]}
212216 output = np .tile (self .inputs ['X' ], (1 , 1 , 1 ))
213217 self .outputs = {'Out' : output }
214218
215219 def set_npu (self ):
216220 self .__class__ .use_npu = True
221+ self .__class__ .no_need_check_grad = True
217222
218223 def test_check_output (self ):
219224 self .check_output_with_place (self .place )
0 commit comments