@@ -1162,7 +1162,10 @@ def setUp(self):
11621162 def test_check_grad (self ):
11631163 if self .dtype == np .float16 :
11641164 return
1165- self .check_grad (['X' ], 'Out' )
1165+ self .check_grad (['X' ], 'Out' , check_pir = True )
1166+
1167+ def test_check_output (self ):
1168+ self .check_output (check_pir = True )
11661169
11671170
11681171class TestTanhshrink_ZeroDim (TestTanhshrink ):
@@ -1181,6 +1184,7 @@ def setUp(self):
11811184 else paddle .CPUPlace ()
11821185 )
11831186
1187+ @test_with_pir_api
11841188 def test_static_api (self ):
11851189 with static_guard ():
11861190 with paddle .static .program_guard (paddle .static .Program ()):
@@ -4317,7 +4321,10 @@ def init_shape(self):
43174321 def test_check_grad (self ):
43184322 if self .dtype == np .float16 :
43194323 return
4320- self .check_grad (['X' ], 'Out' )
4324+ self .check_grad (['X' ], 'Out' , check_pir = True )
4325+
4326+ def test_check_output (self ):
4327+ self .check_output (check_pir = True )
43214328
43224329
43234330class TestThresholdedRelu_ZeroDim (TestThresholdedRelu ):
@@ -4338,6 +4345,7 @@ def setUp(self):
43384345 else paddle .CPUPlace ()
43394346 )
43404347
4348+ @test_with_pir_api
43414349 def test_static_api (self ):
43424350 with static_guard ():
43434351 with paddle .static .program_guard (paddle .static .Program ()):
@@ -4805,7 +4813,7 @@ def test_check_grad(self):
48054813create_test_act_fp16_class (
48064814 TestTanh , check_prim = True , check_prim_pir = True , enable_cinn = True
48074815)
4808- create_test_act_fp16_class (TestTanhshrink )
4816+ create_test_act_fp16_class (TestTanhshrink , check_pir = True )
48094817create_test_act_fp16_class (TestHardShrink , check_pir = True )
48104818create_test_act_fp16_class (TestSoftshrink , check_pir = True )
48114819create_test_act_fp16_class (
@@ -4980,7 +4988,7 @@ def test_check_grad(self):
49804988create_test_act_bf16_class (TestSilu , check_prim = True , check_prim_pir = True )
49814989create_test_act_bf16_class (TestLogSigmoid )
49824990create_test_act_bf16_class (TestTanh , check_prim = True , check_prim_pir = True )
4983- create_test_act_bf16_class (TestTanhshrink )
4991+ create_test_act_bf16_class (TestTanhshrink , check_pir = True )
49844992create_test_act_bf16_class (TestHardShrink , check_pir = True )
49854993create_test_act_bf16_class (TestSoftshrink , check_pir = True )
49864994create_test_act_bf16_class (
0 commit comments