Skip to content

Commit 4ce62de

Browse files
bukejiyupull[bot]
authored andcommitted
[PIR] PassTest support run with executor and check accuracy (#60136)
* update pass test * update * update update update pass printer and fix fix code style * code style
1 parent b490a2f commit 4ce62de

12 files changed

Lines changed: 440 additions & 482 deletions

paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
namespace {
2525

2626
int getSMVersion() {
27-
int sm_version = 80;
28-
#if defined(PADDLE_WITH_CUDA)
27+
int sm_version = -1;
28+
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_CUTLASS)
2929
sm_version = paddle::platform::GetGPUComputeCapability(
3030
paddle::platform::GetCurrentDeviceId());
3131
#else

test/ir/pir/fused_pass/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ file(
33
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
44
"test_*.py")
55
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")
6+
if(NOT WITH_CUTLASS)
7+
set(CUTLASS_TEST_CASES test_fused_weight_only_linear_pass)
8+
list(REMOVE_ITEM TEST_INTERP_CASES ${CUTLASS_TEST_CASES})
9+
endif()
610

711
foreach(target ${TEST_INTERP_CASES})
812
py_test_modules(${target} MODULES ${target})

test/ir/pir/fused_pass/pass_test.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import abc
1616
import unittest
1717

18+
import numpy as np
19+
1820
import paddle
1921
from paddle import pir
2022

@@ -27,7 +29,8 @@ def setUpClass(self):
2729
self.valid_op_map = {}
2830
self.pass_list = []
2931
self.pir_program = None
30-
self.place_runtime = "cpu"
32+
self.places = []
33+
self.skip_accuracy_verification = False
3134

3235
def run_pir_pass(self, program):
3336
if not isinstance(self.pass_list, list):
@@ -36,7 +39,6 @@ def run_pir_pass(self, program):
3639
pm = pir.PassManager(opt_level=4)
3740
for pass_name in self.pass_list:
3841
pm.add_pass(pass_name)
39-
4042
pm.run(program)
4143
return program
4244

@@ -56,34 +58,62 @@ def check_fused_ops(self, program):
5658
),
5759
)
5860

59-
@abc.abstractmethod
60-
def is_program_valid(self, program=None):
61-
"""
62-
judge the effectiveness of the pir program
63-
"""
64-
raise NotImplementedError
65-
6661
@abc.abstractmethod
6762
def sample_program(self):
6863
"""
6964
Generate all pir grogram
7065
"""
7166
raise NotImplementedError
7267

73-
def check_pass_correct(self, atol=1e-5):
68+
def run_program(self, executor, startup_program, main_program):
69+
with paddle.pir_utils.IrGuard():
70+
with paddle.static.program_guard(startup_program, main_program):
71+
fetches = executor.run(
72+
main_program,
73+
feed=self.feeds,
74+
fetch_list=self.fetch_list,
75+
)
76+
return fetches
77+
78+
def compare_accuracy(
79+
self, baseline_data, actual_data, atol=1e-5, rtol=1e-5
80+
):
7481
self.assertTrue(
75-
self.place_runtime == "cpu" or self.place_runtime == "gpu",
76-
"The place param must be either GPU or CPU ",
82+
len(baseline_data) == len(actual_data),
83+
f"The output baseline_data are not equal, the baseline output_data is {len(baseline_data)}, but got {len(actual_data)}",
7784
)
78-
if self.place_runtime == "cpu":
79-
executor = paddle.static.Executor(paddle.base.CPUPlace())
80-
elif self.place_runtime == "gpu":
81-
executor = paddle.static.Executor(paddle.base.CUDAPlace(0))
85+
for i in range(len(baseline_data)):
86+
self.assertEqual(
87+
baseline_data[i].shape,
88+
actual_data[i].shape,
89+
f"The output shapes are not equal, the baseline shape is {baseline_data[i].shape}, but got {actual_data[i].shape}",
90+
)
91+
np.testing.assert_allclose(
92+
baseline_data[i], actual_data[i], atol=atol, rtol=rtol
93+
)
8294

83-
for program, need_translate_to_pir in self.sample_program():
84-
if need_translate_to_pir:
85-
program = pir.translate_to_pir(program.desc)
86-
if not self.is_program_valid(program):
87-
continue
88-
program = self.run_pir_pass(program)
89-
self.check_fused_ops(program)
95+
def check_pass_correct(self, atol=1e-5, rtol=1e-5):
96+
for place in self.places:
97+
for program, need_translate_to_pir in self.sample_program():
98+
main_program = program[0]
99+
startup_program = program[1]
100+
if need_translate_to_pir:
101+
main_program = pir.translate_to_pir(main_program.desc)
102+
with paddle.pir_utils.IrGuard():
103+
with paddle.static.program_guard(
104+
main_program, startup_program
105+
):
106+
executor = paddle.static.Executor(place)
107+
executor.run(startup_program)
108+
baseline_fetch = self.run_program(
109+
executor, startup_program, main_program
110+
)
111+
main_program = self.run_pir_pass(main_program)
112+
self.check_fused_ops(main_program)
113+
actual_fetch = self.run_program(
114+
executor, startup_program, main_program
115+
)
116+
if self.skip_accuracy_verification is False:
117+
self.compare_accuracy(
118+
baseline_fetch, actual_fetch, atol, rtol
119+
)

test/ir/pir/fused_pass/test_conv2d_add_act_fuse_pass.py

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
from pass_test import PassTest
1919

2020
import paddle
21+
from paddle.base import core
2122

2223
paddle.enable_static()
2324

2425

25-
@unittest.skipIf(
26-
not paddle.base.core.is_compiled_with_cuda(),
27-
"core is not complied with CUDA",
28-
)
2926
class TestConv2dAddActFusePattern(PassTest):
3027
r"""
3128
x_var f_var
@@ -47,10 +44,10 @@ def is_program_valid(self, program):
4744
return True
4845

4946
def build_ir_progam(self):
50-
pir_program = None
5147
with paddle.pir_utils.IrGuard():
52-
pir_program = paddle.static.Program()
53-
with paddle.pir.core.program_guard(pir_program):
48+
main_prog = paddle.static.Program()
49+
start_prog = paddle.static.Program()
50+
with paddle.pir.core.program_guard(main_prog, start_prog):
5451
x = paddle.static.data(
5552
name='x', shape=[3, 1, 28, 28], dtype='float32'
5653
)
@@ -67,23 +64,26 @@ def build_ir_progam(self):
6764
)
6865
act_op = paddle.nn.ReLU()
6966
out = act_op(paddle.add(conv2d(x), y))
70-
71-
self.pass_list = ['conv2d_add_act_fuse_pass']
72-
self.feeds = {
73-
"x": np.random.random((3, 32, 28, 28)).astype("float32"),
74-
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
75-
}
76-
self.fetch_list = [out]
77-
self.valid_op_map = {
78-
"pd_op.add": 0,
79-
"pd_op.relu": 0,
80-
"pd_op.conv2d": 0,
81-
"pd_op.fused_conv2d_add_act": 1,
82-
}
83-
return pir_program
67+
out = paddle.assign(out)
68+
self.pass_list = ['conv2d_add_act_fuse_pass']
69+
self.feeds = {
70+
"x": np.random.random((3, 1, 28, 28)).astype("float32"),
71+
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
72+
}
73+
self.fetch_list = [out]
74+
self.valid_op_map = {
75+
"pd_op.add": 0,
76+
"pd_op.relu": 0,
77+
"pd_op.conv2d": 0,
78+
"pd_op.fused_conv2d_add_act": 1,
79+
}
80+
return [main_prog, start_prog]
8481

8582
def setUp(self):
86-
self.place_runtime = "gpu"
83+
if core.is_compiled_with_cuda():
84+
self.places.append(paddle.CUDAPlace(0))
85+
# todo(bukejiyu): This pass will support accuracy verification in the future
86+
self.skip_accuracy_verification = True
8787

8888
def sample_program(self):
8989
yield self.build_ir_progam(), False
@@ -92,15 +92,6 @@ def test_check_output(self):
9292
self.check_pass_correct()
9393

9494

95-
class TestConv2dAddActFusePatternWithCpu(TestConv2dAddActFusePattern):
96-
def setUp(self):
97-
self.place_runtime = "cpu"
98-
99-
100-
@unittest.skipIf(
101-
not paddle.base.core.is_compiled_with_cuda(),
102-
"core is not complied with CUDA",
103-
)
10495
class TestConv2dAdd2ActFusePattern(PassTest):
10596
r"""
10697
x_var f_var(persistable)
@@ -124,10 +115,10 @@ def is_program_valid(self, program):
124115
return True
125116

126117
def build_ir_progam(self):
127-
pir_program = None
128118
with paddle.pir_utils.IrGuard():
129-
pir_program = paddle.static.Program()
130-
with paddle.pir.core.program_guard(pir_program):
119+
main_prog = paddle.static.Program()
120+
start_prog = paddle.static.Program()
121+
with paddle.pir.core.program_guard(main_prog, start_prog):
131122
x = paddle.static.data(
132123
name='x', shape=[3, 1, 28, 28], dtype='float32'
133124
)
@@ -149,22 +140,29 @@ def build_ir_progam(self):
149140
out = act_op(
150141
paddle.add(residual_data, paddle.add(conv2d(x), y))
151142
)
152-
self.pass_list = ['conv2d_add_act_fuse_pass']
153-
self.feeds = {
154-
"x": np.random.random((3, 32, 28, 28)).astype("float32"),
155-
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
156-
}
157-
self.fetch_list = [out]
158-
self.valid_op_map = {
159-
"pd_op.add": 0,
160-
"pd_op.relu": 0,
161-
"pd_op.conv2d": 0,
162-
"pd_op.fused_conv2d_add_act": 1,
163-
}
164-
return pir_program
143+
out = paddle.assign(out)
144+
self.pass_list = ['conv2d_add_act_fuse_pass']
145+
self.feeds = {
146+
"x": np.random.random((3, 1, 28, 28)).astype("float32"),
147+
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
148+
"residual_data": np.random.random((3, 32, 28, 28)).astype(
149+
"float32"
150+
),
151+
}
152+
self.fetch_list = [out]
153+
self.valid_op_map = {
154+
"pd_op.add": 0,
155+
"pd_op.relu": 0,
156+
"pd_op.conv2d": 0,
157+
"pd_op.fused_conv2d_add_act": 1,
158+
}
159+
return [main_prog, start_prog]
165160

166161
def setUp(self):
167-
self.place_runtime = "gpu"
162+
if core.is_compiled_with_cuda():
163+
self.places.append(paddle.CUDAPlace(0))
164+
# todo(bukejiyu): This pass will support accuracy verification in the future
165+
self.skip_accuracy_verification = True
168166

169167
def sample_program(self):
170168
yield self.build_ir_progam(), False
@@ -173,10 +171,5 @@ def test_check_output(self):
173171
self.check_pass_correct()
174172

175173

176-
class TestConv2dAdd2ActFusePatternWithCpu(TestConv2dAdd2ActFusePattern):
177-
def setUp(self):
178-
self.place_runtime = "cpu"
179-
180-
181174
if __name__ == "__main__":
182175
unittest.main()

test/ir/pir/fused_pass/test_conv2d_add_fuse_pass.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
from pass_test import PassTest
1919

2020
import paddle
21+
from paddle.base import core
2122

2223
paddle.enable_static()
2324

2425

25-
@unittest.skipIf(
26-
not paddle.base.core.is_compiled_with_cuda(),
27-
"core is not complied with CUDA",
28-
)
2926
class TestConv2dAddFusePass(PassTest):
3027
r"""
3128
x_var f_var
@@ -39,10 +36,10 @@ def is_program_valid(self, program=None):
3936
return True
4037

4138
def build_ir_progam(self):
42-
pir_program = None
4339
with paddle.pir_utils.IrGuard():
44-
pir_program = paddle.static.Program()
45-
with paddle.pir.core.program_guard(pir_program):
40+
main_prog = paddle.static.Program()
41+
start_prog = paddle.static.Program()
42+
with paddle.pir.core.program_guard(main_prog, start_prog):
4643
x = paddle.static.data(
4744
name='x', shape=[3, 1, 28, 28], dtype='float32'
4845
)
@@ -53,39 +50,37 @@ def build_ir_progam(self):
5350
in_channels=1,
5451
out_channels=32,
5552
kernel_size=3,
56-
padding=1,
53+
padding="SAME",
5754
data_format='NCHW',
5855
bias_attr=False,
5956
)
6057
out = paddle.add(conv2d(x), y)
61-
62-
self.pass_list = ['conv2d_add_fuse_pass']
63-
self.feeds = {
64-
"x": np.random.random((3, 1, 28, 28)).astype("float32"),
65-
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
66-
}
67-
self.fetch_list = [out]
68-
self.valid_op_map = {
69-
"pd_op.fused_conv2d_add_act": 1,
70-
"pd_op.conv2d": 0,
71-
"pd_op.add": 0,
72-
}
73-
return pir_program
58+
out = paddle.assign(out)
59+
self.pass_list = ['conv2d_add_fuse_pass']
60+
self.feeds = {
61+
"x": np.random.random((3, 1, 28, 28)).astype("float32"),
62+
"y": np.random.random((3, 32, 28, 28)).astype("float32"),
63+
}
64+
self.fetch_list = [out]
65+
self.valid_op_map = {
66+
"pd_op.fused_conv2d_add_act": 1,
67+
"pd_op.conv2d": 0,
68+
"pd_op.add": 0,
69+
}
70+
return [main_prog, start_prog]
7471

7572
def sample_program(self):
7673
yield self.build_ir_progam(), False
7774

7875
def setUp(self):
79-
self.place_runtime = "gpu"
76+
if core.is_compiled_with_cuda():
77+
self.places.append(paddle.CUDAPlace(0))
78+
# todo(bukejiyu): This pass will support accuracy verification in the future
79+
self.skip_accuracy_verification = True
8080

8181
def test_check_output(self):
8282
self.check_pass_correct()
8383

8484

85-
class TestConv2dAddFusePassWtihCpu(TestConv2dAddFusePass):
86-
def setUp(self):
87-
self.place_runtime = "cpu"
88-
89-
9085
if __name__ == "__main__":
9186
unittest.main()

0 commit comments

Comments
 (0)