Commit 8967a66

support quantization of conv2d_transpose (#34547)
1 parent 4d88cdb commit 8967a66

File tree

6 files changed: +225 -32 lines changed


python/paddle/fluid/contrib/slim/quantization/imperative/qat.py

Lines changed: 76 additions & 25 deletions
@@ -42,17 +42,18 @@ class ImperativeQuantAware(object):
     Applying quantization aware training (QAT) to the dgraph model.
     """
 
-    def __init__(self,
-                 quantizable_layer_type=['Conv2D', 'Linear'],
-                 weight_quantize_type='abs_max',
-                 activation_quantize_type='moving_average_abs_max',
-                 weight_bits=8,
-                 activation_bits=8,
-                 moving_rate=0.9,
-                 weight_preprocess_layer=None,
-                 act_preprocess_layer=None,
-                 weight_quantize_layer=None,
-                 act_quantize_layer=None):
+    def __init__(
+            self,
+            quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
+            weight_quantize_type='abs_max',
+            activation_quantize_type='moving_average_abs_max',
+            weight_bits=8,
+            activation_bits=8,
+            moving_rate=0.9,
+            weight_preprocess_layer=None,
+            act_preprocess_layer=None,
+            weight_quantize_layer=None,
+            act_quantize_layer=None):
         """
         The constructor for ImperativeQuantAware.
 
@@ -212,9 +213,44 @@ def quantize(self, model):
         the out_scale value of outputs would be calculated.
 
         Args:
-            model(fluid.dygraph.Layer): the model to be quantized.
+            model(paddle.nn.Layer): the model to be quantized.
         Returns:
             None
+
+        Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.fluid.contrib.slim.quantization \
+                import ImperativeQuantAware
+
+            class ImperativeModel(paddle.nn.Layer):
+                def __init__(self):
+                    super(ImperativeModel, self).__init__()
+                    # self.linear_0 would skip the quantization.
+                    self.linear_0 = paddle.nn.Linear(784, 400)
+                    self.linear_0.skip_quant = True
+
+                    # self.linear_1 would not skip the quantization.
+                    self.linear_1 = paddle.nn.Linear(400, 10)
+                    self.linear_1.skip_quant = False
+
+                def forward(self, inputs):
+                    x = self.linear_0(inputs)
+                    x = self.linear_1(inputs)
+                    return x
+
+            model = ImperativeModel()
+            imperative_qat = ImperativeQuantAware(
+                weight_quantize_type='abs_max',
+                activation_quantize_type='moving_average_abs_max')
+
+            # Add the fake quant logical.
+            # The original model will be rewrite.
+            #
+            # There is only one Layer(self.linear1) would be added the
+            # fake quant logical.
+            imperative_qat.quantize(model)
         """
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
@@ -232,17 +268,18 @@ class ImperativeQuantizeInputs(object):
     logic both for activation inputs and weight inputs.
     """
 
-    def __init__(self,
-                 quantizable_layer_type=['Conv2D', 'Linear'],
-                 weight_quantize_type='abs_max',
-                 activation_quantize_type='moving_average_abs_max',
-                 weight_bits=8,
-                 activation_bits=8,
-                 moving_rate=0.9,
-                 weight_preprocess_layer=None,
-                 act_preprocess_layer=None,
-                 weight_quantize_layer=None,
-                 act_quantize_layer=None):
+    def __init__(
+            self,
+            quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
+            weight_quantize_type='abs_max',
+            activation_quantize_type='moving_average_abs_max',
+            weight_bits=8,
+            activation_bits=8,
+            moving_rate=0.9,
+            weight_preprocess_layer=None,
+            act_preprocess_layer=None,
+            weight_quantize_layer=None,
+            act_quantize_layer=None):
         """
         The constructor for ImperativeQuantizeInputs.
 
@@ -303,6 +340,18 @@ def __init__(self,
         }
 
     def apply(self, model):
+        """
+        Quantize the weights and activations to calculate for specific
+        layers.
+
+        Args:
+            model(paddle.nn.Layer): The target model which would
+                calculate the input quantization scale.
+
+        Returns:
+            None
+        """
+
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
 
@@ -354,7 +403,7 @@ def apply(self, model):
         output scales for specific layers in the dygraph model.
 
         Args:
-            model(fluid.dygraph.Layer): The target model which would be
+            model(paddle.nn.Layer): The target model which would be
                 calculate the output quantization scale.
 
         Returns:
@@ -544,7 +593,9 @@ def _is_skip_quant_op(self, block, in_op):
         1. the type of input op should be conv2d, depthwise_conv2d or matmul
         2. the previous ops of the input op are not fake_quantize_dequantize ops
         """
-        target_op_types = ["conv2d", "depthwise_conv2d", "matmul"]
+        target_op_types = [
+            "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose"
+        ]
         if in_op.type not in target_op_types:
             return False
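Taken together, the qat.py changes mean a dygraph model containing Conv2DTranspose layers can now be quantized with the default settings. Below is a minimal sketch of that flow, assuming only the APIs shown in this diff; the model class and tensor shapes are illustrative, not part of the commit.

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    class TransposeModel(paddle.nn.Layer):
        def __init__(self):
            super(TransposeModel, self).__init__()
            # 'Conv2DTranspose' is now in the default quantizable_layer_type,
            # so this layer should be wrapped with QuantizedConv2DTranspose.
            self.deconv = paddle.nn.Conv2DTranspose(4, 6, (3, 3))

        def forward(self, x):
            return self.deconv(x)

    model = TransposeModel()
    imperative_qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')
    # Rewrites the model in place, inserting fake quant/dequant ops.
    imperative_qat.quantize(model)

    x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0)
    y = model(x)  # forward pass now runs through the fake-quantized layer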

python/paddle/fluid/contrib/slim/quantization/imperative/utils.py

Lines changed: 14 additions & 5 deletions
@@ -24,6 +24,7 @@
 from ..quantization_pass import _get_input_name_index
 
 layer_name_map = {
+    'Conv2DTranspose': paddle.nn.Conv2DTranspose,
     'Conv2D': paddle.nn.Conv2D,
     'Linear': paddle.nn.Linear,
     'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
@@ -46,8 +47,9 @@
 }
 
 # Apply fake quant for the inputs of these layers
-# TODO (jc): support paddle.nn.Conv2DTranspose
-fake_quant_input_layers = [paddle.nn.Conv2D, paddle.nn.Linear]
+fake_quant_input_layers = [
+    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose
+]
 
 # Apply fake quant for the output of these layers
 # TODO(jc): fix the problem of adding duplicate fake_quant ops
@@ -65,7 +67,8 @@
 ]
 
 fake_quant_wrap_layers = [
-    quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear
+    quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear,
+    quant_layers.QuantizedConv2DTranspose
 ]
 
 # The weight format of these layers is Cin * Cout * H * W
@@ -84,9 +87,9 @@
 
 
 def load_variable_data(scope, var_name):
-    '''
+    """
     Load variable value from scope
-    '''
+    """
     var_node = scope.find_var(var_name)
     assert var_node is not None, \
         "Can not find " + var_name + " in the scope."
@@ -120,6 +123,12 @@ def find_parent_layer_and_sub_name(model, name):
     the sub_name of the layer.
     For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
     'block_1/convbn_1' and the sub_name is `conv_1`.
+    Args:
+        model(paddle.nn.Layer): the model to be quantized.
+        name(string): the name of a layer
+
+    Returns:
+        parent_layer, subname
     """
     assert isinstance(model, paddle.nn.Layer), \
         "The model must be the instance of paddle.nn.Layer."

python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py

Lines changed: 8 additions & 2 deletions
@@ -28,10 +28,10 @@
 from paddle.fluid.optimizer import AdamOptimizer
 from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.dygraph.container import Sequential
-from paddle.nn import Linear, Conv2D, Softmax
+from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn.quant.quant_layers import QuantizedConv2D
+from paddle.nn.quant.quant_layers import QuantizedConv2D, QuantizedConv2DTranspose
 
 from imperative_test_utils import fix_model_dict, ImperativeLenet
 
@@ -75,6 +75,12 @@ def test_qat(self):
         data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
         quant_conv1(fluid.dygraph.to_variable(data))
 
+        conv_transpose = Conv2DTranspose(4, 6, (3, 3))
+        quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
+        x_var = paddle.uniform(
+            (2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0)
+        quant_conv_transpose(x_var)
+
         seed = 1
         np.random.seed(seed)
         fluid.default_main_program().random_seed = seed

python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,7 @@
 from paddle.fluid.dygraph import Conv2D
 from paddle.fluid.dygraph import Pool2D
 from paddle.fluid.dygraph import Linear
+from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
 from paddle.fluid.log_helper import get_logger
 
 os.environ["CPU_NUM"] = "1"
@@ -100,6 +101,19 @@ def dequantize(x, lower_bound, delta, interval):
     return x
 
 
+class ModelForConv2dT(nn.Layer):
+    def __init__(self, num_classes=10):
+        super(ModelForConv2dT, self).__init__()
+        self.features = nn.Conv2DTranspose(4, 6, (3, 3))
+        self.fc = Linear(input_dim=600, output_dim=num_classes)
+
+    def forward(self, inputs):
+        x = self.features(inputs)
+        x = paddle.flatten(x, 1)
+        x = self.fc(x)
+        return x
+
+
 class ImperativeLenet(paddle.nn.Layer):
     def __init__(self, num_classes=10, classifier_activation='softmax'):
         super(ImperativeLenet, self).__init__()
@@ -168,6 +182,11 @@ def test_quant_aware_training(self):
         imperative_qat.quantize(lenet)
         adam = Adam(learning_rate=0.001, parameters=lenet.parameters())
         dynamic_loss_rec = []
+        #for CI coverage
+        conv_transpose = ModelForConv2dT()
+        imperative_qat.quantize(conv_transpose)
+        x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
+        conv_transpose(x_var)
 
         def train(model):
             adam = Adam(learning_rate=0.001, parameters=model.parameters())

python/paddle/nn/quant/quant_layers.py

Lines changed: 107 additions & 0 deletions
@@ -31,6 +31,7 @@
     'FakeQuantMovingAverageAbsMax',
     'FakeQuantChannelWiseAbsMax',
     'QuantizedConv2D',
+    'QuantizedConv2DTranspose',
     'QuantizedLinear',
     'MovingAverageAbsMaxScale',
     'MAOutputScaleLayer',
@@ -481,6 +482,112 @@ def forward(self, input):
             data_format=self._data_format)
 
 
+class QuantizedConv2DTranspose(layers.Layer):
+    """
+    The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose.
+    The only difference is that its inputs are all fake quantized.
+
+    Examples:
+       .. code-block:: python
+          import paddle
+          import paddle.nn as nn
+          from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
+          x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
+          conv = nn.Conv2DTranspose(4, 6, (3, 3))
+          conv_quantized = QuantizedConv2DTranspose(conv)
+          y_quantized = conv_quantized(x_var)
+          y_var = conv(x_var)
+          y_quantized_np = y_quantized.numpy()
+          y_np = y_var.numpy()
+          print(y_np.shape, y_quantized_np.shape)
+          # (2, 6, 10, 10), (2, 6, 10, 10)
+    """
+
+    def __init__(self,
+                 layer,
+                 weight_bits=8,
+                 activation_bits=8,
+                 moving_rate=0.9,
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='abs_max',
+                 weight_pre_layer=None,
+                 act_pre_layer=None,
+                 weight_quant_layer=None,
+                 act_quant_layer=None):
+        r"""
+        Constructor.
+
+        The arguments are the same as ImperativeQuantAware.
+        """
+        super(QuantizedConv2DTranspose, self).__init__()
+        # For Conv2DTranspose
+        self._groups = getattr(layer, '_groups')
+        self._stride = getattr(layer, '_stride')
+        self._padding = getattr(layer, '_padding')
+        self._output_padding = getattr(layer, 'output_padding')
+        self._dilation = getattr(layer, '_dilation')
+        self._data_format = getattr(layer, '_data_format')
+        self.weight = getattr(layer, 'weight')
+        self.bias = getattr(layer, 'bias')
+        # For FakeQuant
+        self._conv2d_transpose_quant_axis = 1
+        if weight_quant_layer is not None:
+            self._fake_quant_weight = weight_quant_layer()
+        else:
+            self._fake_quant_weight = _get_fake_quant_type(
+                weight_quantize_type,
+                name=self.weight.name,
+                moving_rate=moving_rate,
+                quant_bits=weight_bits,
+                dtype=self._dtype,
+                quant_on_weight=True,
+                channel_num=self.weight.shape[
+                    self._conv2d_transpose_quant_axis],
+                quant_axis=self._conv2d_transpose_quant_axis)
+        if act_quant_layer is not None:
+            self._fake_quant_input = act_quant_layer()
+        else:
+            self._fake_quant_input = _get_fake_quant_type(
+                activation_quantize_type,
+                name=layer.full_name(),
+                moving_rate=moving_rate,
+                quant_bits=activation_bits,
+                dtype=self._dtype,
+                quant_on_weight=False)
+
+        self._act_preprocess = act_pre_layer(
+        ) if act_pre_layer is not None else None
+        self._weight_preprocess = weight_pre_layer(
+        ) if weight_pre_layer is not None else None
+
+    def forward(self, input, output_size=None):
+        if self._act_preprocess is not None:
+            input = self._act_preprocess(input)
+        quant_input = self._fake_quant_input(input)
+
+        weight = self.weight
+        if self._weight_preprocess is not None:
+            weight = self._weight_preprocess(self.weight)
+        quant_weight = self._fake_quant_weight(weight)
+
+        if output_size is None:
+            output_padding = self._output_padding
+        else:
+            output_padding = 0
+
+        return F.conv2d_transpose(
+            quant_input,
+            quant_weight,
+            bias=self.bias,
+            padding=self._padding,
+            output_padding=output_padding,
+            stride=self._stride,
+            dilation=self._dilation,
+            groups=self._groups,
+            output_size=output_size,
+            data_format=self._data_format)
+
+
 class QuantizedLinear(layers.Layer):
     """
     The computational logic of QuantizedLinear is the same with Linear.
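Besides the default fake quantizers, the new wrapper also accepts user-defined quant layers through weight_quant_layer and act_quant_layer, in the same spirit as the existing QuantizedConv2D wrapper. A hedged sketch of that path follows; the commit only requires a zero-argument callable returning a paddle.nn.Layer whose forward maps a tensor to a tensor, and PassThroughQuant here is a made-up stand-in, not an API from the commit.

    import paddle
    from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose

    class PassThroughQuant(paddle.nn.Layer):
        """Identity 'quantizer' that only demonstrates the expected interface."""

        def forward(self, x):
            return x

    conv = paddle.nn.Conv2DTranspose(4, 6, (3, 3))
    quant_conv = QuantizedConv2DTranspose(
        conv,
        weight_quant_layer=PassThroughQuant,
        act_quant_layer=PassThroughQuant)

    x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0)
    # With identity quantizers the wrapper reproduces the FP32 layer exactly.
    y = quant_conv(x)
    print(y.shape)  # [2, 6, 10, 10]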

tools/sampcd_processor.py

Lines changed: 1 addition & 0 deletions
@@ -440,6 +440,7 @@ def get_filenames(full_test=False):
     '''
     global whl_error
     import paddle
+    import paddle.fluid.contrib.slim.quantization
     whl_error = []
     if full_test:
         get_full_api_from_pr_spec()
