@@ -146,5 +146,42 @@ def test_check_output(self):
146146 self .check_pass_correct ()
147147
148148
149+ class TestPlacementSlicePass (PassTest ):
150+ def is_program_valid (self , program = None ):
151+ return True
152+
153+ def build_ir_program (self ):
154+ with paddle .pir_utils .IrGuard ():
155+ main_prog = paddle .static .Program ()
156+ start_prog = paddle .static .Program ()
157+ with paddle .pir .core .program_guard (main_prog , start_prog ):
158+ x = paddle .static .data (
159+ name = 'x' , shape = [2 , 5 , 5 , 5 ], dtype = 'float32'
160+ )
161+ out_1 = x [0 , :, :, :]
162+ out_2 = x [0 , :, :, :]
163+ out_1 = paddle .assign (out_1 )
164+ out_2 = paddle .assign (out_2 )
165+ self .pass_attr_list = [{'onednn_placement_pass' : {}}]
166+ self .feeds = {
167+ "x" : np .random .random ((2 , 5 , 5 , 5 )).astype ("float32" ),
168+ }
169+ self .fetch_list = [out_1 , out_2 ]
170+ self .valid_op_map = {
171+ "onednn_op.slice" : 2 ,
172+ "pd_op.slice" : 0 ,
173+ }
174+ return [main_prog , start_prog ]
175+
176+ def sample_program (self ):
177+ yield self .build_ir_program (), False
178+
179+ def setUp (self ):
180+ self .places .append (paddle .CPUPlace ())
181+
182+ def test_check_output (self ):
183+ self .check_pass_correct ()
184+
185+
149186if __name__ == "__main__" :
150187 unittest .main ()
0 commit comments