@@ -31,7 +31,7 @@ int getSMVersion() {
   sm_version = paddle::platform::GetGPUComputeCapability(
       paddle::platform::GetCurrentDeviceId());
 #else
-  PADDLE_THROW(paddle::platform::errors::Unavailable(
+  PADDLE_THROW(common::errors::Unavailable(
       "fused_weight_only_linear_pass needs paddle compiled with CUDA."));
 #endif
   return sm_version;
@@ -41,10 +41,14 @@ class FusedWeightOnlyLinearWithBiasPattern
     : public paddle::drr::DrrPatternBase {
  private:
   bool reverse_add_;
+  std::string algo_;
+  int sm_version_;
 
  public:
-  explicit FusedWeightOnlyLinearWithBiasPattern(bool reverse_add)
-      : reverse_add_(reverse_add) {}
+  FusedWeightOnlyLinearWithBiasPattern(bool reverse_add,
+                                       const std::string &algo,
+                                       int sm_version)
+      : reverse_add_(reverse_add), algo_(algo), sm_version_(sm_version) {}
 
   std::string name() const override {
     return "FusedWeightOnlyLinearWithBiasPattern";
@@ -104,19 +108,49 @@ class FusedWeightOnlyLinearWithBiasPattern
     //
     paddle::drr::ResultPattern res = src.ResultPattern();
 
-    const auto &weight_quantize =
-        res.Op(paddle::dialect::WeightQuantizeOp::name(),
-               {{"algo", res.StrAttr("weight_only_int8")},
-                {"arch", res.Int32Attr(getSMVersion())},
-                {"group_size", res.Int32Attr(-1)}});
-    weight_quantize({&res.Tensor("w")},
-                    {&res.Tensor("quanted_weight_tensor"),
-                     &res.Tensor("weight_scale_tensor")});
+    if (algo_ == "weight_only_int4") {
+      // TODO(liuyuanle): When the weight_quantize operator supports
+      // weight_only_int4 on GPU, delete the memory copy.
+      const auto &memcpy_d2h =
+          res.Op(paddle::dialect::MemcpyD2hOp::name(),
+                 {{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
+      res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+      weight_quantize({&res.Tensor("w_cpu")},
+                      {&res.Tensor("quanted_weight_tensor_cpu"),
+                       &res.Tensor("weight_scale_tensor_cpu")});
+
+      const auto &memcpy_h2d_1 =
+          res.Op(paddle::dialect::MemcpyH2dOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("quanted_weight_tensor") =
+          memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
+      const auto &memcpy_h2d_2 =
+          res.Op(paddle::dialect::MemcpyH2dOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("weight_scale_tensor") =
+          memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
+    } else {
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+
+      weight_quantize({&res.Tensor("w")},
+                      {&res.Tensor("quanted_weight_tensor"),
+                       &res.Tensor("weight_scale_tensor")});
+    }
 
     const auto &weight_only_linear =
         res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
-               {{"weight_dtype", res.StrAttr("int8")},
-                {"arch", res.Int32Attr(getSMVersion())},
+               {{"weight_dtype",
+                 res.StrAttr(algo_ == "weight_only_int8" ? "int8" : "int4")},
+                {"arch", res.Int32Attr(sm_version_)},
                 {"group_size", res.Int32Attr(-1)}});
     weight_only_linear({&res.Tensor("x"),
                         &res.Tensor("quanted_weight_tensor"),
@@ -127,6 +161,14 @@ class FusedWeightOnlyLinearWithBiasPattern
 };
 
 class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
+ private:
+  std::string algo_;
+  int sm_version_;
+
+ public:
+  FusedWeightOnlyLinearNoBiasPattern(const std::string &algo, int sm_version)
+      : algo_(algo), sm_version_(sm_version) {}
+
  public:
   std::string name() const override {
     return "FusedWeightOnlyLinearNoBiasPattern";
@@ -179,19 +221,48 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
     //
     paddle::drr::ResultPattern res = src.ResultPattern();
 
-    const auto &weight_quantize =
-        res.Op(paddle::dialect::WeightQuantizeOp::name(),
-               {{"algo", res.StrAttr("weight_only_int8")},
-                {"arch", res.Int32Attr(getSMVersion())},
-                {"group_size", res.Int32Attr(-1)}});
-    weight_quantize({&res.Tensor("w")},
-                    {&res.Tensor("quanted_weight_tensor"),
-                     &res.Tensor("weight_scale_tensor")});
-
+    if (algo_ == "weight_only_int4") {
+      // TODO(liuyuanle): When the weight_quantize operator supports
+      // weight_only_int4 on GPU, delete the memory copy.
+      const auto &memcpy_d2h =
+          res.Op(paddle::dialect::MemcpyD2hOp::name(),
+                 {{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
+      res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+      weight_quantize({&res.Tensor("w_cpu")},
+                      {&res.Tensor("quanted_weight_tensor_cpu"),
+                       &res.Tensor("weight_scale_tensor_cpu")});
+
+      const auto &memcpy_h2d_1 =
+          res.Op(paddle::dialect::MemcpyH2dOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("quanted_weight_tensor") =
+          memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
+      const auto &memcpy_h2d_2 =
+          res.Op(paddle::dialect::MemcpyH2dOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("weight_scale_tensor") =
+          memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
+    } else {
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+
+      weight_quantize({&res.Tensor("w")},
+                      {&res.Tensor("quanted_weight_tensor"),
+                       &res.Tensor("weight_scale_tensor")});
+    }
     const auto &weight_only_linear =
         res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
-               {{"weight_dtype", res.StrAttr("int8")},
-                {"arch", res.Int32Attr(getSMVersion())},
+               {{"weight_dtype",
+                 res.StrAttr(algo_ == "weight_only_int8" ? "int8" : "int4")},
+                {"arch", res.Int32Attr(sm_version_)},
                 {"group_size", res.Int32Attr(-1)}});
     weight_only_linear({&res.Tensor("x"),
                         &res.Tensor("quanted_weight_tensor"),
@@ -204,15 +275,28 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
 class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
  public:
   FusedWeightOnlyLinearPass()
-      : pir::PatternRewritePass("fused_weight_only_linear_pass", 4) {}
+      : pir::PatternRewritePass("fused_weight_only_linear_pass", 4),
+        sm_version_(getSMVersion()) {}
 
   pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
+    std::string algo = "weight_only_int4";
+    if (Has("weight_only_algo")) {
+      algo = Get<std::string>("weight_only_algo");
+    }
+    PADDLE_ENFORCE_EQ(algo == "weight_only_int8" || algo == "weight_only_int4",
+                      true,
+                      common::errors::InvalidArgument(
+                          "fused_weight_only_linear_pass only supports "
+                          "weight_only_int8 or weight_only_int4, but got %s.",
+                          algo));
+
     pir::RewritePatternSet ps(context);
-    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(context,
-                                                                     true));
-    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(context,
-                                                                     false));
-    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearNoBiasPattern>(context));
+    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(
+        context, true, algo, sm_version_));
+    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearWithBiasPattern>(
+        context, false, algo, sm_version_));
+    ps.Add(paddle::drr::Create<FusedWeightOnlyLinearNoBiasPattern>(
+        context, algo, sm_version_));
     return ps;
   }
@@ -228,15 +312,15 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
   }
 
   bool CanApplyOn(pir::Operation *op) const override {
-    int sm_version = getSMVersion();
-    if (sm_version != 70 && sm_version != 75 && sm_version != 80 &&
-        sm_version != 86) {
+    if (sm_version_ != 70 && sm_version_ != 75 && sm_version_ != 80 &&
+        sm_version_ != 86) {
       return false;
     }
     return op->num_regions() > 0;
   }
 
  private:
+  int sm_version_;
   pir::FrozenRewritePatternSet patterns_;
 };
 
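For context, here is a minimal caller-side sketch of how the new "weight_only_algo" attribute could be supplied when scheduling this pass. It is not part of the diff: the CreateFusedWeightOnlyLinearPass() factory, the include paths, and the pointer-owning pir::Pass::Set() call are assumptions inferred from the Has()/Get<std::string>() usage above and Paddle's usual PIR pass conventions.

#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"  // assumed path
#include "paddle/pir/pass/pass_manager.h"                                       // assumed path

void RunFusedWeightOnlyLinear(pir::Program *program) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  pir::PassManager pm(ctx);

  // Hypothetical factory for the pass registered as
  // "fused_weight_only_linear_pass".
  std::unique_ptr<pir::Pass> pass = pir::CreateFusedWeightOnlyLinearPass();

  // Without this attribute the pass defaults to "weight_only_int4"; any
  // value other than weight_only_int8/weight_only_int4 trips the
  // PADDLE_ENFORCE_EQ check in InitializePatterns.
  pass->Set("weight_only_algo", new std::string("weight_only_int8"));

  pm.AddPass(std::move(pass));
  pm.Run(program);  // No-op unless sm_version_ is 70/75/80/86 (CanApplyOn).
}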