@@ -40,12 +40,14 @@ def test_output(self):
4040
4141 def test_errors (self ):
4242 paddle .enable_static ()
43- with paddle .static .program_guard (
44- paddle .static .Program (), paddle .static .Program ()
45- ):
46- x = paddle .static .data (name = 'x' , shape = [- 1 , 2 ], dtype = 'int32' )
47- y = paddle .static .data (name = 'y' , shape = [- 1 , 2 ], dtype = 'int32' )
48- a = paddle .static .data (name = 'a' , shape = [- 1 , 2 ], dtype = 'int16' )
43+ with program_guard (Program (), Program ()):
44+ x = paddle .static .data (name = 'x' , shape = [- 1 , 2 ], dtype = typename )
45+ y = paddle .static .data (name = 'y' , shape = [- 1 , 2 ], dtype = typename )
46+ error_dtype = 'int16' if typename != 'int16' else 'int32'
47+ a = paddle .static .data (
48+ name = 'a' , shape = [- 1 , 2 ], dtype = error_dtype
49+ )
50+
4951 op = eval ("paddle.%s" % self .op_type )
5052 self .assertRaises (TypeError , op , x = x , y = a )
5153 self .assertRaises (TypeError , op , x = a , y = y )
@@ -70,6 +72,8 @@ def test_errors(self):
7072 create_test_class ('equal' , _type_name , lambda _a , _b : _a == _b , True )
7173 create_test_class ('not_equal' , _type_name , lambda _a , _b : _a != _b , True )
7274
75+ create_test_class ('less_than' , 'int8' , lambda _a , _b : _a < _b )
76+
7377
7478def create_paddle_case (op_type , callback ):
7579 class PaddleCls (unittest .TestCase ):
@@ -512,6 +516,31 @@ def test_check_output(self):
512516create_bf16_case ('not_equal' , lambda _a , _b : _a != _b , True )
513517
514518
519+ # add int8 tests
520+ def create_int8_case (op_type , callback , check_pir = False ):
521+ class TestCompareOpInt8Op (op_test .OpTest ):
522+ def setUp (self ):
523+ self .op_type = op_type
524+ self .dtype = np .int8
525+ self .python_api = eval ("paddle." + op_type )
526+
527+ x = np .random .randint (- 128 , 127 , size = [5 , 5 ]).astype (np .int8 )
528+ y = np .random .randint (- 128 , 127 , size = [5 , 5 ]).astype (np .int8 )
529+ real_result = callback (x , y )
530+ self .inputs = {'X' : x , 'Y' : y }
531+ self .outputs = {'Out' : real_result }
532+
533+ def test_check_output (self ):
534+ self .check_output (check_cinn = True , check_pir = check_pir )
535+
536+ cls_name = f"Int8TestCase_{ op_type } "
537+ TestCompareOpInt8Op .__name__ = cls_name
538+ globals ()[cls_name ] = TestCompareOpInt8Op
539+
540+
541+ create_int8_case ('less_than' , lambda _a , _b : _a < _b )
542+
543+
515544class TestCompareOpError (unittest .TestCase ):
516545 def test_errors (self ):
517546 paddle .enable_static ()
0 commit comments