@@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   auto* weight_tensor =
       scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
   auto w_dims = weight_tensor->dims();
+  float* quantized_weight_data =
+      weight_tensor->mutable_data<float>(platform::CPUPlace());
   // If quantized op is fc, weight scale size = 1;
   // If quantized op is conv2d, weight scale size = weight dims[0]
   // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
-  bool valid_scale_size =
-      (weight_scale.size() == 1 ||
-       weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
-       weight_scale.size() == static_cast<size_t>(w_dims[1]));
-  PADDLE_ENFORCE_EQ(
-      valid_scale_size, true,
-      platform::errors::InvalidArgument(
-          "TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
-  float* quantized_weight_data =
-      weight_tensor->mutable_data<float>(platform::CPUPlace());
-  for (int j = 0; j < weight_tensor->numel(); j++) {
-    if (weight_scale.size() == 1) {
-      quantized_weight_data[j] *= weight_scale[0];
-    } else {
-      if (quantized_op_type == "conv2d_transpose") {
-        int inner_size = w_dims[2] * w_dims[3];
-        quantized_weight_data[j] *=
-            weight_scale[(j / inner_size) % w_dims[1]];
-      } else {
-        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-        quantized_weight_data[j] *= weight_scale[j / inner_size];
+  if (quantized_op_type == "mul" || quantized_op_type == "fc") {
+    if (dequant_type == "fake_dequantize_max_abs") {
+      PADDLE_ENFORCE_EQ(
+          weight_scale.size(), 1,
+          platform::errors::InvalidArgument(
+              "mul op weight dequantized by [fake_dequantize_max_abs] "
+              "requires weight scale size = 1, but got %d.",
+              weight_scale.size()));
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        quantized_weight_data[j] *= weight_scale[0];
       }
     }
+    if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+      PADDLE_ENFORCE_EQ(
+          weight_scale.size(), static_cast<size_t>(w_dims[1]),
+          platform::errors::InvalidArgument(
+              "mul op weight dequantized by "
+              "[fake_channel_wise_dequantize_max_abs] requires weight scale "
+              "size = 2nd dim of mul's weight, which is %d, but got %d.",
+              static_cast<size_t>(w_dims[1]), weight_scale.size()));
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
+      }
+    }
+  } else if (quantized_op_type == "conv2d" ||
+             quantized_op_type == "depthwise_conv2d") {
+    PADDLE_ENFORCE_EQ(
+        dequant_type, "fake_channel_wise_dequantize_max_abs",
+        platform::errors::InvalidArgument("conv2d op must be dequantized by "
+                                          "[fake_channel_wise_dequantize_max_"
+                                          "abs], but got %s",
+                                          dequant_type));
+    PADDLE_ENFORCE_EQ(
+        weight_scale.size(), static_cast<size_t>(w_dims[0]),
+        platform::errors::InvalidArgument(
+            "conv2d op requires weight scale size = channel size of the "
+            "weight, which is %d, but got %d.",
+            static_cast<size_t>(w_dims[0]), weight_scale.size()));
+    for (int j = 0; j < weight_tensor->numel(); j++) {
+      int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+      quantized_weight_data[j] *= weight_scale[j / inner_size];
+    }
+  } else if (quantized_op_type == "conv2d_transpose") {
+    PADDLE_ENFORCE_EQ(
+        dequant_type, "fake_channel_wise_dequantize_max_abs",
+        platform::errors::InvalidArgument(
+            "conv2d_transpose must be dequantized by "
+            "[fake_channel_wise_dequantize_max_abs], but got %s",
+            dequant_type));
+    PADDLE_ENFORCE_EQ(
+        weight_scale.size(), static_cast<size_t>(w_dims[1]),
+        platform::errors::InvalidArgument(
+            "conv2d_transpose op requires weight scale size = channel size "
+            "of the weight, which is %d, but got %d.",
+            static_cast<size_t>(w_dims[1]), weight_scale.size()));
+    for (int j = 0; j < weight_tensor->numel(); j++) {
+      int inner_size = w_dims[2] * w_dims[3];
+      quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
+    }
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Unsupported quantized op type: %s", quantized_op_type));
   }
 
   // create new op_desc
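Note (annotation, not part of the patch): the per-channel branches above differ only in how element `j` of the flattened weight maps to a scale index. A `mul`/`fc` weight is laid out `[in, out]` with one scale per output column, hence `j % w_dims[1]`; a `conv2d`/`depthwise_conv2d` weight is `[out_c, in_c, h, w]`, so each scale covers a contiguous block of `in_c * h * w` elements, hence `j / inner_size`; a `conv2d_transpose` weight is `[in_c, out_c, h, w]`, so the output channel repeats every `h * w` elements and needs the extra modulo. A minimal standalone sketch of that indexing under these layout assumptions (`ScaleIndex` is a hypothetical helper, not part of the pass):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper (illustration only): returns the index into the
// per-channel scale vector for element j of a flattened 4-D weight tensor.
int ScaleIndex(const std::vector<int>& dims, bool is_transpose, int j) {
  if (is_transpose) {
    int inner_size = dims[2] * dims[3];          // h * w
    return (j / inner_size) % dims[1];           // dims[1] = output channels
  }
  int inner_size = dims[1] * dims[2] * dims[3];  // in_c * h * w
  return j / inner_size;                         // dims[0] = output channels
}

int main() {
  // conv2d weight [out_c=2, in_c=3, h=2, w=2]: elements 0..11 use scale 0,
  // elements 12..23 use scale 1.
  std::vector<int> conv_dims = {2, 3, 2, 2};
  std::printf("conv2d j=13 -> scale %d\n", ScaleIndex(conv_dims, false, 13));
  // conv2d_transpose weight [in_c=3, out_c=2, h=2, w=2]: the channel
  // alternates every h*w = 4 elements (0,1,0,1,...).
  std::vector<int> convt_dims = {3, 2, 2, 2};
  std::printf("conv2d_transpose j=13 -> scale %d\n",
              ScaleIndex(convt_dims, true, 13));
  return 0;  // both prints yield scale index 1
}
```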
@@ -285,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
 
 REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
               paddle::framework::ir::QuantDequantFusePass);
+REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
 
 REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
     .AddCombination(
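Note (annotation, not part of the patch): the bare `REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);` added here registers the pass with the compatibility checker without constraining any operator versions, while `tensorrt_subgraph_pass` above chains `.AddCombination(...)` to pin the op versions it supports. A hedged sketch of the constrained form, assuming Paddle's `op_version_registry` API (`OpVersionComparatorCombination` with `EQ`/`LE` comparators); the pass name and version numbers are hypothetical:

```cpp
// Illustrative sketch only: "my_example_fuse_pass" and the version numbers
// below are hypothetical, not taken from this patch.
#include "paddle/fluid/framework/op_version_registry.h"

REGISTER_PASS_CAPABILITY(my_example_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("conv2d", 0)   // support only version 0 of conv2d
            .LE("mul", 1));    // accept mul at version 1 or lower
```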