
Commit 2360fae

[PIR & Inference] Refactor fused_weight_only_linear_pass (#59792)
* refactor: refactor fused_weight_only_linear_pass
* refactor: add else case for PADDLE_WITH_CUDA
* fix: fix typo
* refactor: refactor pass
* refactor: support sm 70, 75, 80 and 86 in pass
* refactor: refactor pass and test
* fix: fix typo
* refactor: use xxOp::name() instead of pd_op.xx in pass
* refactor: refactor error msg and fix typo
* refactor: refactor pass and test
* fix: fix typo
* refactor: refactor IrDtype
1 parent d9407a5 commit 2360fae

5 files changed: 105 additions & 28 deletions


paddle/fluid/pir/drr/api/tensor_interface.cc

Lines changed: 2 additions & 0 deletions
@@ -30,5 +30,7 @@ bool DtypeInterface::operator==(const DtypeInterface& other) const {
   return *dtype_ == *other.dtype_;
 }
 
+IrDtype DtypeInterface::dtype() const { return *(this->dtype_); }
+
 }  // namespace drr
 }  // namespace pir

paddle/fluid/pir/drr/api/tensor_interface.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class ShapeInterface final {
 class DtypeInterface final {
  public:
   bool operator==(const DtypeInterface& other) const;
+  IrDtype dtype() const;
 
  private:
   explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {}

paddle/fluid/pir/drr/ir_value.h

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,11 @@ class IrDtype {
 
   bool operator==(IrDtype other) const { return dtype_ == other.dtype_; }
 
+  template <typename T>
+  bool isa() const {
+    return dtype_.isa<T>();
+  }
+
  private:
   const pir::Type dtype_;
 };
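Taken together, the additions to DtypeInterface and IrDtype exist so a DRR constraint can inspect a matched tensor's element type: DtypeInterface::dtype() unwraps to an IrDtype, whose isa<T>() forwards to the wrapped pir::Type. A minimal sketch of the resulting call chain, assuming a pir::drr::MatchContext named match_ctx as used in the pass below:

    // w's dtype flows DtypeInterface -> IrDtype -> pir::Type::isa<T>().
    auto w_dtype = match_ctx.Tensor("w").Dtype();
    bool is_half = w_dtype.dtype().isa<pir::Float16Type>() ||
                   w_dtype.dtype().isa<pir::BFloat16Type>();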

paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc

Lines changed: 32 additions & 18 deletions
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
@@ -22,11 +23,14 @@
 
 namespace {
 
-inline int getSMVersion() {
+int getSMVersion() {
   int sm_version = 80;
 #if defined(PADDLE_WITH_CUDA)
   sm_version = paddle::platform::GetGPUComputeCapability(
       paddle::platform::GetCurrentDeviceId());
+#else
+  PADDLE_THROW(paddle::platform::errors::Unavailable(
+      "fused_weight_only_linear_pass needs paddle compiled with CUDA."));
 #endif
   return sm_version;
 }
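Behavioral note: getSMVersion() previously defaulted to SM 80 on non-CUDA builds; it now throws. If the function could ever be reached from a CPU-only binary, the caller would need to guard the query itself; a hypothetical guard, not part of this commit:

    // Hypothetical caller-side guard (not in the commit): query the SM
    // version only when CUDA support is compiled in.
    int sm_version = -1;  // sentinel: pass cannot apply
    #if defined(PADDLE_WITH_CUDA)
      sm_version = getSMVersion();
    #endif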
@@ -40,12 +44,14 @@ class FusedWeightOnlyLinearPattern
     //
     pir::drr::SourcePattern src = ctx->SourcePattern();
     const auto &matmul =
-        src.Op("pd_op.matmul",
+        src.Op(paddle::dialect::MatmulOp::name(),
                {{"transpose_x", src.Attr("matmul_transpose_x")},
                 {"transpose_y", src.Attr("matmul_transpose_y")}});
+    const auto &parameter = src.Op(
+        pir::ParameterOp::name(), {{"parameter_name", src.Attr("param_name")}});
+    src.Tensor("w") = parameter();
     src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));
-
-    const auto &add = src.Op("pd_op.add");
+    const auto &add = src.Op(paddle::dialect::AddOp::name());
     src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));
 
     //
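The source pattern now also matches the producer of "w": a ParameterOp carrying a parameter_name attribute. The rewrite therefore fires only when the matmul's right operand is a parameter (a real weight), not an arbitrary feed. As a dataflow sketch, the matched subgraph is:

    //  parameter() --> w ----+
    //                        +--> matmul --> matmul_out --+
    //  x --------------------+                            +--> add --> add_out
    //  bias -----------------------------------------------+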
@@ -62,6 +68,17 @@ class FusedWeightOnlyLinearPattern
         return false;
       }
 
+      auto w_dims = match_ctx.Tensor("w").Shape();
+      if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false;
+
+      auto w_dtype = match_ctx.Tensor("w").Dtype();
+      if (!w_dtype.dtype().isa<pir::Float16Type>() &&
+          !w_dtype.dtype().isa<pir::BFloat16Type>())
+        return false;
+
+      auto x_dims = match_ctx.Tensor("x").Shape();
+      if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false;
+
       return true;
     });
     //
@@ -74,15 +91,15 @@ class FusedWeightOnlyLinearPattern
         res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
           return "weight_only_int8";
         });
-    // int arch = getSMVersion();
-    const auto &weight_quantize_arch_attr =
-        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
-          return 80;
+
+    const auto &arch_attr =
+        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int {
+          return getSMVersion();
         });
 
-    const auto &weight_quantize = res.Op(
-        "pd_op.weight_quantize",
-        {{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
+    const auto &weight_quantize =
+        res.Op(paddle::dialect::WeightQuantizeOp::name(),
+               {{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
     weight_quantize({&res.Tensor("w")},
                     {&res.Tensor("quanted_weight_tensor"),
                      &res.Tensor("weight_scale_tensor")});
@@ -92,12 +109,9 @@ class FusedWeightOnlyLinearPattern
           return "int8";
         });
 
-    const auto &weight_only_linear_arch_attr = res.Attr(
-        [&](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
     const auto &weight_only_linear =
-        res.Op("pd_op.weight_only_linear",
-               {{"weight_dtype", weight_dtype_attr},
-                {"arch", weight_only_linear_arch_attr}});
+        res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
+               {{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
     weight_only_linear({&res.Tensor("x"),
                         &res.Tensor("quanted_weight_tensor"),
                         &res.Tensor("bias"),
@@ -119,8 +133,8 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 
   bool CanApplyOn(pir::Operation *op) const override {
     int sm_vesion = getSMVersion();
-    if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
-        sm_vesion != 75) {
+    if (sm_vesion != 70 && sm_vesion != 75 && sm_vesion != 80 &&
+        sm_vesion != 86) {
       return false;
     }
     return op->num_regions() > 0;
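The reordered comparison leaves the whitelist unchanged: the pass still applies only on SM 70, 75, 80, and 86 (Volta, Turing, and Ampere). If the list grows, a small predicate would keep it in one place; an illustrative refactor, not part of the commit (note the source spells the variable sm_vesion):

    // Illustrative helper (not in the commit): centralize the SM whitelist.
    bool IsSupportedSMVersion(int sm_version) {
      return sm_version == 70 || sm_version == 75 ||
             sm_version == 80 || sm_version == 86;
    }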

test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py

Lines changed: 65 additions & 10 deletions
@@ -19,6 +19,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.pir.core import create_parameter
 
 np.random.seed(2013)
 

@@ -56,34 +57,38 @@ def build_ir_progam(self):
                 x = paddle.static.data(
                     name='x', shape=[3, 64, 64], dtype=self.dtype
                 )
-                w = paddle.static.data(
-                    name="w", shape=[64, 64], dtype=self.dtype
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 64], dtype=self.dtype, initializer=initializer
                 )
                 bias_ = paddle.static.data(
-                    name="bias", shape=[64], dtype=self.dtype
+                    name="bias",
+                    shape=[64],
+                    dtype=self.dtype,
                 )
                 bias = paddle.assign(bias_)
                 res1 = paddle.matmul(x=x, y=w)
                 out = paddle.add(res1, bias)
-
             self.pass_list = ['fused_weight_only_linear_pass']
             self.feeds = {
                 "x": np.random.random((3, 64, 64)).astype(self.dtype),
                 "w": np.random.random((64, 64)).astype(self.dtype),
                 "bias": np.random.random(64).astype(self.dtype),
             }
             self.fetch_list = [out]
-            self.valid_op_map = {
-                "pd_op.weight_only_linear": 1,
-                "pd_op.weight_quantize": 1,
-                "pd_op.matmul": 0,
-                "pd_op.add": 0,
-            }
+
         return pir_program
 
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float32'
+        # weight_quantize need weight's dtype to be fp16 or bf16
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
     def sample_program(self):
         yield self.build_ir_progam(), False
@@ -96,6 +101,56 @@ class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 1,
+            "pd_op.weight_quantize": 1,
+            "pd_op.matmul": 0,
+            "pd_op.add": 0,
+        }
+
+
+class TestFusedWeightOnlyLinearPass_wdim_divisible_by_16(
+    TestFusedWeightOnlyLinearPass_Fp32
+):
+    def build_ir_progam(self):
+        pir_program = None
+        with paddle.pir_utils.IrGuard():
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
+                x = paddle.static.data(
+                    name='x', shape=[3, 64, 64], dtype=self.dtype
+                )
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 15], dtype=self.dtype, initializer=initializer
+                )
+                bias_ = paddle.static.data(
+                    name="bias",
+                    shape=[15],
+                    dtype=self.dtype,
+                )
+                bias = paddle.assign(bias_)
+                res1 = paddle.matmul(x=x, y=w)
+                out = paddle.add(res1, bias)
+            self.pass_list = ['fused_weight_only_linear_pass']
+            self.feeds = {
+                "x": np.random.random((3, 64, 64)).astype(self.dtype),
+                "w": np.random.random((64, 15)).astype(self.dtype),
+                "bias": np.random.random(15).astype(self.dtype),
+            }
+            self.fetch_list = [out]
+
+        return pir_program
+
+    def setUp(self):
+        self.place_runtime = "gpu"
+        self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
 
 if __name__ == "__main__":
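The new TestFusedWeightOnlyLinearPass_wdim_divisible_by_16 case builds a weight of shape [64, 15]; since 15 % 16 != 0, the shape constraint added in fused_weight_only_linear_pass.cc rejects the match, so valid_op_map expects the original matmul/add to survive and no weight_quantize/weight_only_linear ops to appear. The rejecting check, restated from the pass above:

    // From the constraint above: both weight dims must be tile-aligned.
    if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false;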
