Skip to content

Commit 122d5d9

Browse files
committed
add test case for new attr
1 parent d041d02 commit 122d5d9

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

test/ir/pir/fused_pass/onednn/test_onednn_placement.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,42 @@ def test_check_output(self):
146146
self.check_pass_correct()
147147

148148

149+
class TestPlacementSlicePass(PassTest):
150+
def is_program_valid(self, program=None):
151+
return True
152+
153+
def build_ir_program(self):
154+
with paddle.pir_utils.IrGuard():
155+
main_prog = paddle.static.Program()
156+
start_prog = paddle.static.Program()
157+
with paddle.pir.core.program_guard(main_prog, start_prog):
158+
x = paddle.static.data(
159+
name='x', shape=[2, 5, 5, 5], dtype='float32'
160+
)
161+
out_1 = x[0, :, :, :]
162+
out_2 = x[0, :, :, :]
163+
out_1 = paddle.assign(out_1)
164+
out_2 = paddle.assign(out_2)
165+
self.pass_attr_list = [{'onednn_placement_pass': {}}]
166+
self.feeds = {
167+
"x": np.random.random((2, 5, 5, 5)).astype("float32"),
168+
}
169+
self.fetch_list = [out_1, out_2]
170+
self.valid_op_map = {
171+
"onednn_op.slice": 2,
172+
"pd_op.slice": 0,
173+
}
174+
return [main_prog, start_prog]
175+
176+
def sample_program(self):
177+
yield self.build_ir_program(), False
178+
179+
def setUp(self):
180+
self.places.append(paddle.CPUPlace())
181+
182+
def test_check_output(self):
183+
self.check_pass_correct()
184+
185+
149186
if __name__ == "__main__":
150187
unittest.main()

0 commit comments

Comments
 (0)