1414
1515#include " paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
1616
17+ #include < algorithm>
1718#include < memory>
1819#include < string>
1920#include < unordered_set>
@@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
7576 any_op2_desc->Flush ();
7677 auto dequant_type = quant_dequant_op->Op ()->Type ();
7778 auto quantized_op_type = any_op2_desc->Type ();
79+ // get weight tensor
80+ auto * weight_tensor =
81+ scope->GetVar (quant_dequant_op_x->Name ())->GetMutable <LoDTensor>();
82+ auto w_dims = weight_tensor->dims ();
83+ float * quantized_weight_data =
84+ weight_tensor->mutable_data <float >(platform::CPUPlace ());
7885
7986 // Get weight scale
8087 if (dequant_type == " fake_channel_wise_quantize_dequantize_abs_max" ) {
@@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
9097 paddle::platform::is_cpu_place (channel_scale_tensor.place ()),
9198 platform::errors::InvalidArgument (
9299 " Channel scale tensor's place should be CPU." ));
93- const float * channel_scale_data = channel_scale_tensor.data <float >();
94- for (int i = 0 ; i < channel_scale_tensor.numel (); i++) {
95- weight_scale.push_back (range / channel_scale_data[i]);
100+ // compute the channel wise abs max of the weight tensor
101+ int quant_axis =
102+ BOOST_GET_CONST (int , quant_dequant_op->Op ()->GetAttr (" quant_axis" ));
103+
104+ PADDLE_ENFORCE_EQ (quant_axis == 0 || quant_axis == 1 , true ,
105+ platform::errors::InvalidArgument (
106+ " 'quant_axis' should be 0 or 1, but "
107+ " the received is %d" ,
108+ quant_axis));
109+
110+ const int64_t channel = w_dims[quant_axis];
111+ weight_scale.resize (channel, 0 );
112+ if (quant_axis == 0 ) {
113+ const int64_t channel_size = weight_tensor->numel () / channel;
114+ for (int64_t i = 0 ; i < channel; i++) {
115+ auto * start = quantized_weight_data + i * channel_size;
116+ for (int64_t j = 0 ; j < channel_size; j++) {
117+ weight_scale[i] = std::max (std::abs (start[j]), weight_scale[i]);
118+ }
119+ }
120+ } else if (quant_axis == 1 ) {
121+ const int64_t step_i = weight_tensor->numel () / w_dims[0 ];
122+ const int64_t step_j = weight_tensor->numel () / (w_dims[0 ] * w_dims[1 ]);
123+ for (int64_t i = 0 ; i < w_dims[0 ]; i++) {
124+ for (int64_t j = 0 ; j < w_dims[1 ]; j++) {
125+ auto * start = quantized_weight_data + i * step_i + j * step_j;
126+ float abs_max = 0 ;
127+ for (int64_t k = 0 ; k < step_j; k++) {
128+ abs_max = std::max (std::abs (start[k]), abs_max);
129+ }
130+ weight_scale[j] = std::max (weight_scale[j], abs_max);
131+ }
132+ }
133+ }
134+ for (int i = 0 ; i < channel; i++) {
135+ PADDLE_ENFORCE_NE (weight_scale[i], 0 ,
136+ platform::errors::InvalidArgument (
137+ " Weight scale should be nonzero, but get zero." ));
138+ weight_scale[i] = range / weight_scale[i];
96139 }
97140 } else {
98141 auto scale_name = quant_dequant_op_outscale->Name ();
99- const LoDTensor& scale_tensor =
100- scope->GetVar (scale_name)->Get <LoDTensor>();
101- const float * scale_data = scale_tensor.data <float >();
102- weight_scale.push_back ((range * range) / scale_data[0 ] / range);
142+ // compute the abs max of the weight tensor
143+ float abs_max_weight = 0 .;
144+ for (int j = 0 ; j < weight_tensor->numel (); j++) {
145+ abs_max_weight =
146+ std::max (abs_max_weight, std::abs (quantized_weight_data[j]));
147+ }
148+ PADDLE_ENFORCE_NE (abs_max_weight, 0 ,
149+ platform::errors::InvalidArgument (
150+ " Weight scale should be nonzero, but get zero" ));
151+ weight_scale.push_back ((range * range) / abs_max_weight / range);
103152 }
104153
105154 nodes2rm.insert (quant_dequant_op_outscale);
155+
106156 // perform quantize dequantize operations
107- auto * weight_tensor =
108- scope->GetVar (quant_dequant_op_x->Name ())->GetMutable <LoDTensor>();
109- auto w_dims = weight_tensor->dims ();
110- float * quantized_weight_data =
111- weight_tensor->mutable_data <float >(platform::CPUPlace ());
112- // If quantized op is fc, weight scale size = 1;
157+ // If quantized op is not channel wise, weight scale size = 1;
113158 // If quantized op is conv2d, weight scale size = weight dims[0]
114159 // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
115160 if (dequant_type == " fake_quantize_dequantize_abs_max" ) {
@@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
119164 " %s op weight dequantized by [fake_quantize_dequantize_max_abs] "
120165 " requires weight scale size = 1, but got %d." ,
121166 quantized_op_type, weight_scale.size ()));
122- PADDLE_ENFORCE_NE (weight_scale[0 ], 0 ,
123- platform::errors::InvalidArgument (
124- " Weight scale should be nonzero, but get zero" ));
125167 for (int j = 0 ; j < weight_tensor->numel (); j++) {
126168 // quantized
127169 quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0 ];
0 commit comments