Skip to content

Commit 0ec982e

Browse files
xiaoguoguo626807SecretXV
authored and committed
【pir】add fused_fused_gemm_epilogue_pass test (PaddlePaddle#58750)
* modify * modify * Update paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc * Update paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc * Update paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc * code style * delete flag * modify * add fused_linear test * modify * Update paddle/fluid/framework/new_executor/standalone_executor.cc
1 parent f816716 commit 0ec982e

1 file changed

Lines changed: 120 additions & 0 deletions

File tree

Lines changed: 120 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
19+
import paddle
20+
from paddle.autograd.ir_backward import grad as ir_grad
21+
from paddle.base import core
22+
23+
# Seed NumPy's global RNG so the np.random.normal inputs generated in the
# test below are the same on every run.
np.random.seed(2013)
24+
25+
import os
26+
import re
27+
28+
29+
def get_cuda_version():
    """Return the local nvcc release encoded as ``major * 1000 + minor * 10``.

    Runs ``nvcc --version`` through the shell and looks for a
    ``release X.Y,`` token in its standard output; for example
    ``release 11.6,`` is encoded as ``11060``.

    Returns:
        int: The encoded version, or ``-1`` when the output contains no
        ``release ...,`` token (e.g. the command printed nothing to stdout).

    Raises:
        ValueError: If the matched release string is not of the form
            ``<major>.<minor>``.
    """
    # Use the pipe as a context manager so it is closed (and the child
    # process waited for) deterministically instead of being left to the
    # garbage collector.
    with os.popen("nvcc --version") as pipe:
        result = pipe.read()

    match = re.search(r'release (\S+),', result)
    if match is None:
        return -1

    integer, decimal = str(match.group(1)).split('.')
    return int(integer) * 1000 + int(float(decimal) * 10)
39+
40+
41+
@unittest.skipIf(
    not core.is_compiled_with_cuda() or get_cuda_version() < 11060,
    "core is not complied with CUDA or nvcc version is less than11.6",
)
class TestFusedgemm_epilogueAdd(unittest.TestCase):
    """Check that ``fused_gemm_epilogue_pass`` fuses matmul + add under PIR.

    Skipped unless Paddle is compiled with CUDA and ``get_cuda_version()``
    reports at least 11060.
    """

    def test_fused_gemm_epilogue_add(self):
        """Build matmul + add with gradients, apply the pass, compare results.

        Steps:
          1. Build a static program computing ``matmul(x, y) + z`` and the
             gradients of the result w.r.t. ``x``, ``y`` and ``z``.
          2. Assert the unfused forward/backward ops are present and run the
             program on ``CUDAPlace(0)``.
          3. Run ``fused_gemm_epilogue_pass`` and assert the fused forward
             and backward ops are now present.
          4. Run the program again with the same feed and check that every
             fetched tensor agrees with the pre-pass run.
        """
        with paddle.pir_utils.IrGuard():
            x_np = np.random.normal(3, 2.5, size=(1024, 1024)).astype(
                np.float32
            )
            # The same array is fed for both matmul operands.
            y_np = x_np
            z_np = np.random.normal(3, 2.5, size=(1024)).astype(np.float32)
            main_program = paddle.static.Program()
            with paddle.static.program_guard(main_program):
                x_ = paddle.static.data(
                    name="x", shape=[1024, 1024], dtype="float32"
                )
                y_ = paddle.static.data(
                    name="y", shape=[1024, 1024], dtype="float32"
                )
                z_ = paddle.static.data(name="z", shape=[1024], dtype="float32")
                x_.stop_gradient = False
                y_.stop_gradient = False
                z_.stop_gradient = False
                x = paddle.assign(x_)
                y = paddle.assign(y_)
                z = paddle.assign(z_)
                res1 = paddle.matmul(x=x, y=y)
                res2 = paddle.add(res1, z)
                res3 = paddle.assign(res2)

                res4, res5, res6 = ir_grad(res3, [x, y, z])
                res4_ = paddle.assign(res4)
                res5_ = paddle.assign(res5)
                res6_ = paddle.assign(res6)

                # Before the pass the program must hold the unfused ops.
                op_names = [op.name() for op in main_program.global_block().ops]
                self.assertTrue(
                    'pd_op.matmul' in op_names and 'pd_op.add' in op_names
                )
                self.assertTrue(
                    'pd_op.add_grad' in op_names
                    and 'pd_op.matmul_grad' in op_names
                )

                # Reference run, in a fresh scope, before any fusion.
                with paddle.static.scope_guard(paddle.static.Scope()):
                    exe = paddle.base.Executor(paddle.base.CUDAPlace(0))
                    fetches0 = exe.run(
                        main_program,
                        feed={"x": x_np, "y": y_np, "z": z_np},
                        fetch_list=[res3, res4_, res5_, res6_],
                    )

                # Apply the pass under test; the assertion below expects it
                # to introduce the fused forward and backward ops.
                pm = paddle.pir.PassManager()
                pm.add_pass('fused_gemm_epilogue_pass')
                pm.run(main_program)
                op_names = [op.name() for op in main_program.global_block().ops]
                self.assertTrue(
                    'pd_op.fused_gemm_epilogue' in op_names
                    and 'pd_op.fused_gemm_epilogue_grad' in op_names
                )

                # Second run, again in a fresh scope, on the rewritten program.
                with paddle.static.scope_guard(paddle.static.Scope()):
                    exe = paddle.base.Executor(paddle.base.CUDAPlace(0))
                    fetches1 = exe.run(
                        main_program,
                        feed={"x": x_np, "y": y_np, "z": z_np},
                        fetch_list=[res3, res4_, res5_, res6_],
                    )

                # np.array_equal only returns a bool, so calling it without
                # asserting on the result verifies nothing. Use a real
                # assertion instead.
                # NOTE(review): the tolerance is a guess that allows for the
                # fused and unfused kernels rounding differently; tighten it
                # if bit-exact results are expected.
                labels = ("out", "x_grad", "y_grad", "z_grad")
                for label, before, after in zip(labels, fetches0, fetches1):
                    np.testing.assert_allclose(
                        before,
                        after,
                        rtol=1e-3,
                        atol=1e-3,
                        err_msg=f"{label} changed after fused_gemm_epilogue_pass",
                    )
117+
118+
119+
# Run the tests in this file only when it is executed directly as a script,
# not when it is imported.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)