@@ -60,7 +60,7 @@ def nearest_neighbor_interp_mkldnn_np(
6060
6161@skip_check_grad_ci (reason = "Haven not implement interpolate grad kernel." )
6262@OpTestTool .skip_if_not_cpu_bf16 ()
63- class TestNearestInterpV2MKLDNNOp (OpTest ):
63+ class TestNearestInterpV2ONEDNNOp (OpTest ):
6464 def init_test_case (self ):
6565 pass
6666
@@ -154,7 +154,7 @@ def test_check_output(self):
154154 self .check_output (check_dygraph = False , check_pir_onednn = True )
155155
156156
157- class TestNearestInterpOpV2MKLDNNNHWC ( TestNearestInterpV2MKLDNNOp ):
157+ class TestNearestInterpOpV2ONEDNNNHWC ( TestNearestInterpV2ONEDNNOp ):
158158 def init_test_case (self ):
159159 self .input_shape = [3 , 2 , 32 , 16 ]
160160 self .out_h = 27
@@ -163,22 +163,22 @@ def init_test_case(self):
163163 self .data_layout = 'NHWC'
164164
165165
166- class TestNearestNeighborInterpV2MKLDNNCase2 ( TestNearestInterpV2MKLDNNOp ):
166+ class TestNearestNeighborInterpV2ONEDNNCase2 ( TestNearestInterpV2ONEDNNOp ):
167167 def init_test_case (self ):
168168 self .input_shape = [3 , 3 , 9 , 6 ]
169169 self .out_h = 12
170170 self .out_w = 12
171171
172172
173- class TestNearestNeighborInterpV2MKLDNNCase3 ( TestNearestInterpV2MKLDNNOp ):
173+ class TestNearestNeighborInterpV2ONEDNNCase3 ( TestNearestInterpV2ONEDNNOp ):
174174 def init_test_case (self ):
175175 self .input_shape = [1 , 1 , 32 , 64 ]
176176 self .out_h = 64
177177 self .out_w = 128
178178 self .scale = [0.1 , 0.05 ]
179179
180180
181- class TestNearestNeighborInterpV2MKLDNNCase4 ( TestNearestInterpV2MKLDNNOp ):
181+ class TestNearestNeighborInterpV2ONEDNNCase4 ( TestNearestInterpV2ONEDNNOp ):
182182 def init_test_case (self ):
183183 self .input_shape = [1 , 1 , 32 , 64 ]
184184 self .out_h = 64
@@ -187,7 +187,7 @@ def init_test_case(self):
187187 self .out_size = np .array ([65 , 129 ]).astype ("int32" )
188188
189189
190- class TestNearestNeighborInterpV2MKLDNNSame ( TestNearestInterpV2MKLDNNOp ):
190+ class TestNearestNeighborInterpV2ONEDNNSame ( TestNearestInterpV2ONEDNNOp ):
191191 def init_test_case (self ):
192192 self .input_shape = [2 , 3 , 32 , 64 ]
193193 self .out_h = 32
@@ -220,12 +220,12 @@ def init_data_type(self):
220220 globals ()[TestUint8Case .__name__ ] = TestUint8Case
221221
222222
223- create_test_class (TestNearestInterpV2MKLDNNOp )
224- create_test_class (TestNearestInterpOpV2MKLDNNNHWC )
225- create_test_class (TestNearestNeighborInterpV2MKLDNNCase2 )
226- create_test_class (TestNearestNeighborInterpV2MKLDNNCase3 )
227- create_test_class (TestNearestNeighborInterpV2MKLDNNCase4 )
228- create_test_class (TestNearestNeighborInterpV2MKLDNNSame )
223+ create_test_class (TestNearestInterpV2ONEDNNOp )
224+ create_test_class (TestNearestInterpOpV2ONEDNNNHWC )
225+ create_test_class (TestNearestNeighborInterpV2ONEDNNCase2 )
226+ create_test_class (TestNearestNeighborInterpV2ONEDNNCase3 )
227+ create_test_class (TestNearestNeighborInterpV2ONEDNNCase4 )
228+ create_test_class (TestNearestNeighborInterpV2ONEDNNSame )
229229
230230if __name__ == "__main__" :
231231 from paddle import enable_static
0 commit comments