Skip to content

Commit 5199527

Browse files
authored
fused_weight_only_linear_pass support weight_only_int4 (#63212)
* support weight_only_int4 * update * fix * fix
1 parent 0f8cd1f commit 5199527

File tree

6 files changed

+158
-63
lines changed

6 files changed

+158
-63
lines changed

paddle/fluid/pir/drr/src/rewrite_pattern.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -474,16 +474,16 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
474474
}
475475
}
476476

477-
std::vector<std::vector<pir::Operation*>> temp_program;
478-
std::unordered_map<pir::Operation*, size_t> op_2_temp_program_index;
479-
for (auto& op : *rewriter.block()) {
480-
op_2_temp_program_index[&op] = temp_program.size();
481-
temp_program.push_back({&op});
482-
}
483-
484477
// topo order visit result_pattern_graph
485478
GraphTopo graph_topo_visit(&result_pattern_graph);
486479
graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
480+
std::vector<std::vector<pir::Operation*>> temp_program;
481+
std::unordered_map<pir::Operation*, size_t> op_2_temp_program_index;
482+
for (auto& op : *rewriter.block()) {
483+
op_2_temp_program_index[&op] = temp_program.size();
484+
temp_program.push_back({&op});
485+
}
486+
487487
// set insert point
488488
size_t max_input_op_index = 0UL;
489489
pir::Operation* max_index_op = nullptr;
@@ -530,11 +530,13 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
530530

531531
pir::Operation* new_op =
532532
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
533-
op_2_temp_program_index[new_op] = max_input_op_index + 1;
534-
if (max_input_op_index + 1 >= temp_program.size()) {
533+
534+
size_t new_max_input_op_index = max_input_op_index + 1;
535+
op_2_temp_program_index[new_op] = new_max_input_op_index;
536+
if (new_max_input_op_index >= temp_program.size()) {
535537
temp_program.push_back({});
536538
}
537-
temp_program[max_input_op_index + 1].push_back(new_op);
539+
temp_program[new_max_input_op_index].push_back(new_op);
538540
});
539541

540542
return res_match_ctx;

paddle/fluid/pir/transforms/general/constant_folding_pass.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ class ConstantFoldingPattern : public pir::RewritePattern {
238238
const std::vector<std::pair<pir::Operation*, int32_t>>& use_ops) const {
239239
for (auto [use_op, idx] : use_ops) {
240240
if (use_op->isa<pir::CombineOp>()) {
241-
if (!ReplaceResultByParameterOp(use_op)) return false;
241+
if (!ReplaceResultByParameterOp(use_op)) {
242+
return false;
243+
}
244+
} else if (use_op->isa<paddle::dialect::MemcpyH2dOp>()) {
245+
return false;
242246
} else if (use_op->HasInterface<paddle::dialect::OpYamlInfoInterface>()) {
243247
auto [input_infos, _1, _2, _3, _4] =
244248
use_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>()
@@ -255,6 +259,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
255259
}
256260

257261
bool ReplaceResultByParameterOp(pir::Operation* op) const {
262+
if (op->isa<paddle::dialect::MemcpyD2hOp>()) {
263+
return false;
264+
}
258265
for (uint32_t i = 0; i < op->num_results(); i++) {
259266
auto use_ops = pir::GetUseOpsForOutput(op, i);
260267
if (!CheckUseOps(use_ops)) return false;

paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc

Lines changed: 117 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ int getSMVersion() {
3131
sm_version = paddle::platform::GetGPUComputeCapability(
3232
paddle::platform::GetCurrentDeviceId());
3333
#else
34-
PADDLE_THROW(paddle::platform::errors::Unavailable(
34+
PADDLE_THROW(common::errors::Unavailable(
3535
"fused_weight_only_linear_pass needs paddle compiled with CUDA."));
3636
#endif
3737
return sm_version;
@@ -41,10 +41,14 @@ class FusedWeightOnlyLinearWithBiasPattern
4141
: public paddle::drr::DrrPatternBase {
4242
private:
4343
bool reverse_add_;
44+
std::string algo_;
45+
int sm_version_;
4446

4547
public:
46-
explicit FusedWeightOnlyLinearWithBiasPattern(bool reverse_add)
47-
: reverse_add_(reverse_add) {}
48+
FusedWeightOnlyLinearWithBiasPattern(bool reverse_add,
49+
const std::string &algo,
50+
int sm_version)
51+
: reverse_add_(reverse_add), algo_(algo), sm_version_(sm_version) {}
4852

4953
std::string name() const override {
5054
return "FusedWeightOnlyLinearWithBiasPattern";
@@ -104,19 +108,49 @@ class FusedWeightOnlyLinearWithBiasPattern
104108
//
105109
paddle::drr::ResultPattern res = src.ResultPattern();
106110

107-
const auto &weight_quantize =
108-
res.Op(paddle::dialect::WeightQuantizeOp::name(),
109-
{{"algo", res.StrAttr("weight_only_int8")},
110-
{"arch", res.Int32Attr(getSMVersion())},
111-
{"group_size", res.Int32Attr(-1)}});
112-
weight_quantize({&res.Tensor("w")},
113-
{&res.Tensor("quanted_weight_tensor"),
114-
&res.Tensor("weight_scale_tensor")});
111+
if (algo_ == "weight_only_int4") {
112+
// TODO(liuyuanle): When the operator weight_quantize supports
113+
// weight_only_int4 on gpu version, delete the memory copy.
114+
const auto &memcpy_d2h =
115+
res.Op(paddle::dialect::MemcpyD2hOp::name(),
116+
{{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
117+
res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
118+
const auto &weight_quantize =
119+
res.Op(paddle::dialect::WeightQuantizeOp::name(),
120+
{{"algo", res.StrAttr(algo_)},
121+
{"arch", res.Int32Attr(sm_version_)},
122+
{"group_size", res.Int32Attr(-1)}});
123+
weight_quantize({&res.Tensor("w_cpu")},
124+
{&res.Tensor("quanted_weight_tensor_cpu"),
125+
&res.Tensor("weight_scale_tensor_cpu")});
126+
127+
const auto &memcpy_h2d_1 =
128+
res.Op(paddle::dialect::MemcpyH2dOp::name(),
129+
{{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
130+
res.Tensor("quanted_weight_tensor") =
131+
memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
132+
const auto &memcpy_h2d_2 =
133+
res.Op(paddle::dialect::MemcpyH2dOp::name(),
134+
{{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
135+
res.Tensor("weight_scale_tensor") =
136+
memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
137+
} else {
138+
const auto &weight_quantize =
139+
res.Op(paddle::dialect::WeightQuantizeOp::name(),
140+
{{"algo", res.StrAttr(algo_)},
141+
{"arch", res.Int32Attr(sm_version_)},
142+
{"group_size", res.Int32Attr(-1)}});
143+
144+
weight_quantize({&res.Tensor("w")},
145+
{&res.Tensor("quanted_weight_tensor"),
146+
&res.Tensor("weight_scale_tensor")});
147+
}
115148

116149
const auto &weight_only_linear =
117150
res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
118-
{{"weight_dtype", res.StrAttr("int8")},
119-
{"arch", res.Int32Attr(getSMVersion())},
151+
{{"weight_dtype",
152+
res.StrAttr(algo_ == "weight_only_int8" ? "int8" : "int4")},
153+
{"arch", res.Int32Attr(sm_version_)},
120154
{"group_size", res.Int32Attr(-1)}});
121155
weight_only_linear({&res.Tensor("x"),
122156
&res.Tensor("quanted_weight_tensor"),
@@ -127,6 +161,14 @@ class FusedWeightOnlyLinearWithBiasPattern
127161
};
128162

129163
class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
164+
private:
165+
std::string algo_;
166+
int sm_version_;
167+
168+
public:
169+
FusedWeightOnlyLinearNoBiasPattern(const std::string &algo, int sm_version)
170+
: algo_(algo), sm_version_(sm_version) {}
171+
130172
public:
131173
std::string name() const override {
132174
return "FusedWeightOnlyLinearNoBiasPattern";
@@ -179,19 +221,48 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
179221
//
180222
paddle::drr::ResultPattern res = src.ResultPattern();
181223

182-
const auto &weight_quantize =
183-
res.Op(paddle::dialect::WeightQuantizeOp::name(),
184-
{{"algo", res.StrAttr("weight_only_int8")},
185-
{"arch", res.Int32Attr(getSMVersion())},
186-
{"group_size", res.Int32Attr(-1)}});
187-
weight_quantize({&res.Tensor("w")},
188-
{&res.Tensor("quanted_weight_tensor"),
189-
&res.Tensor("weight_scale_tensor")});
190-
224+
if (algo_ == "weight_only_int4") {
225+
// TODO(liuyuanle): When the operator weight_quantize supports
226+
// weight_only_int4 on gpu version, delete the memory copy.
227+
const auto &memcpy_d2h =
228+
res.Op(paddle::dialect::MemcpyD2hOp::name(),
229+
{{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
230+
res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
231+
const auto &weight_quantize =
232+
res.Op(paddle::dialect::WeightQuantizeOp::name(),
233+
{{"algo", res.StrAttr(algo_)},
234+
{"arch", res.Int32Attr(sm_version_)},
235+
{"group_size", res.Int32Attr(-1)}});
236+
weight_quantize({&res.Tensor("w_cpu")},
237+
{&res.Tensor("quanted_weight_tensor_cpu"),
238+
&res.Tensor("weight_scale_tensor_cpu")});
239+
240+
const auto &memcpy_h2d_1 =
241+
res.Op(paddle::dialect::MemcpyH2dOp::name(),
242+
{{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
243+
res.Tensor("quanted_weight_tensor") =
244+
memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
245+
const auto &memcpy_h2d_2 =
246+
res.Op(paddle::dialect::MemcpyH2dOp::name(),
247+
{{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
248+
res.Tensor("weight_scale_tensor") =
249+
memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
250+
} else {
251+
const auto &weight_quantize =
252+
res.Op(paddle::dialect::WeightQuantizeOp::name(),
253+
{{"algo", res.StrAttr(algo_)},
254+
{"arch", res.Int32Attr(sm_version_)},
255+
{"group_size", res.Int32Attr(-1)}});
256+
257+
weight_quantize({&res.Tensor("w")},
258+
{&res.Tensor("quanted_weight_tensor"),
259+
&res.Tensor("weight_scale_tensor")});
260+
}
191261
const auto &weight_only_linear =
192262
res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
193-
{{"weight_dtype", res.StrAttr("int8")},
194-
{"arch", res.Int32Attr(getSMVersion())},
263+
{{"weight_dtype",
264+
res.StrAttr(algo_ == "weight_only_int8" ? "int8" : "int4")},
265+
{"arch", res.Int32Attr(sm_version_)},
195266
{"group_size", res.Int32Attr(-1)}});
196267
weight_only_linear({&res.Tensor("x"),
197268
&res.Tensor("quanted_weight_tensor"),
@@ -204,15 +275,28 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
204275
class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
205276
public:
206277
FusedWeightOnlyLinearPass()
207-
: pir::PatternRewritePass("fused_weight_only_linear_pass", 4) {}
278+
: pir::PatternRewritePass("fused_weight_only_linear_pass", 4),
279+
sm_version_(getSMVersion()) {}
208280

209281
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
282+
std::string algo = "weight_only_int4";
283+
if (Has("weight_only_algo")) {
284+
algo = Get<std::string>("weight_only_algo");
285+
}
286+
PADDLE_ENFORCE_EQ(algo == "weight_only_int8" || algo == "weight_only_int4",
287+
true,
288+
common::errors::InvalidArgument(
289+
"fused_weight_only_linear_pass only support "
290+
"weight_only_int8 or weight_only_int4, but get %s.",
291+
algo));
292+
210293
pir::RewritePatternSet ps(context);
211-
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(context,
212-
true));
213-
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(context,
214-
false));
215-
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearNoBiasPattern>(context));
294+
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(
295+
context, true, algo, sm_version_));
296+
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(
297+
context, false, algo, sm_version_));
298+
ps.Add(paddle::drr::Create<FusedWeightOnlyLinearNoBiasPattern>(
299+
context, algo, sm_version_));
216300
return ps;
217301
}
218302

@@ -228,15 +312,15 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
228312
}
229313

230314
bool CanApplyOn(pir::Operation *op) const override {
231-
int sm_version = getSMVersion();
232-
if (sm_version != 70 && sm_version != 75 && sm_version != 80 &&
233-
sm_version != 86) {
315+
if (sm_version_ != 70 && sm_version_ != 75 && sm_version_ != 80 &&
316+
sm_version_ != 86) {
234317
return false;
235318
}
236319
return op->num_regions() > 0;
237320
}
238321

239322
private:
323+
int sm_version_;
240324
pir::FrozenRewritePatternSet patterns_;
241325
};
242326

paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ static phi::Backend GetKernelBackendByYaml(
754754
auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend;
755755
phi::Backend kernel_backend = phi::Backend::UNDEFINED;
756756

757-
for (auto slot_name : backend_info) {
757+
for (const auto& slot_name : backend_info) {
758758
auto& input_map = op_info_parser->InputName2Id();
759759

760760
if (input_map.count(slot_name)) {

paddle/phi/api/yaml/ops.yaml

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,7 +2559,7 @@
25592559
kernel :
25602560
func : shape {dense -> dense},
25612561
shape_sr {selected_rows -> dense}
2562-
data_transform:
2562+
data_transform :
25632563
skip_transform : input
25642564
interfaces : paddle::dialect::InferSymbolicShapeInterface
25652565

@@ -2620,7 +2620,7 @@
26202620
spmd_rule : ElementwiseUnaryInferSpmd
26212621
kernel :
26222622
func : sin
2623-
inplace: (x -> out)
2623+
inplace : (x -> out)
26242624
backward : sin_grad
26252625
interfaces : paddle::dialect::InferSymbolicShapeInterface
26262626

@@ -2781,10 +2781,10 @@
27812781
- op : swiglu
27822782
args : (Tensor x, Tensor y)
27832783
output : Tensor(out)
2784-
infer_meta:
2785-
func: SwiGLUInferMeta
2786-
spmd_rule: SwiGLUInferSpmd
2787-
kernel:
2784+
infer_meta :
2785+
func : SwiGLUInferMeta
2786+
spmd_rule : SwiGLUInferSpmd
2787+
kernel :
27882788
func : swiglu
27892789
optional : y
27902790
backward: swiglu_grad
@@ -2808,7 +2808,7 @@
28082808
func : UnchangedInferMeta
28092809
kernel :
28102810
func : tan
2811-
inplace: (x -> out)
2811+
inplace : (x -> out)
28122812
backward : tan_grad
28132813
interfaces : paddle::dialect::InferSymbolicShapeInterface
28142814

@@ -3057,9 +3057,9 @@
30573057
func : WarpctcInferMeta
30583058
kernel :
30593059
func : warpctc
3060-
data_type: logits
3061-
optional: logits_length, labels_length
3062-
intermediate: warpctcgrad
3060+
data_type : logits
3061+
optional : logits_length, labels_length
3062+
intermediate : warpctcgrad
30633063
backward : warpctc_grad
30643064

30653065
- op : warprnnt
@@ -3069,8 +3069,8 @@
30693069
func : WarprnntInferMeta
30703070
kernel :
30713071
func : warprnnt
3072-
data_type: input
3073-
intermediate: warprnntgrad
3072+
data_type : input
3073+
intermediate : warprnntgrad
30743074
backward : warprnnt_grad
30753075

30763076
- op : weight_dequantize
@@ -3090,8 +3090,8 @@
30903090
kernel :
30913091
func : weight_only_linear
30923092
data_type : x
3093-
optional: bias
3094-
backward: weight_only_linear_grad
3093+
optional : bias
3094+
backward : weight_only_linear_grad
30953095

30963096
- op : weight_quantize
30973097
args : (Tensor x, str algo = "weight_only_int8", int arch = 80, int group_size = -1)
@@ -3100,7 +3100,8 @@
31003100
func : WeightQuantizeInferMeta
31013101
kernel :
31023102
func : weight_quantize
3103-
data_type: x
3103+
data_type : x
3104+
backend : x
31043105

31053106
- op : weighted_sample_neighbors
31063107
args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
@@ -3119,7 +3120,7 @@
31193120
spmd_rule: WhereInferSpmd
31203121
kernel :
31213122
func : where
3122-
inplace: (x -> out)
3123+
inplace : (x -> out)
31233124
backward : where_grad
31243125
interfaces : paddle::dialect::InferSymbolicShapeInterface
31253126

0 commit comments

Comments
 (0)