Commit 01f5210

Add QuantizedMatmul in QAT (#47997)

1 parent 94fe929 commit 01f5210

3 files changed: +298 −0 lines changed
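The commit adds a QuantizedMatmul layer so that dygraph quantization-aware training (QAT) can fake-quantize both operands of a matmul. A minimal usage sketch, assuming a Paddle build that includes this patch; the module path and layer name come from the diff below, while everything else (the ScaledDotProduct layer, the shapes) is illustrative:

import paddle
from paddle.nn.quant.quant_layers import QuantizedMatmul


class ScaledDotProduct(paddle.nn.Layer):
    """Illustrative layer: both matmul operands pass through fake-quant ops."""

    def __init__(self):
        super().__init__()
        # Defaults to 8-bit abs_max fake quantization of both inputs.
        self.matmul = QuantizedMatmul()

    def forward(self, q, k):
        # Same call pattern as paddle.matmul(x, y, transpose_x, transpose_y).
        return self.matmul(q, k, transpose_y=True)


q = paddle.randn([2, 4, 8])
k = paddle.randn([2, 4, 8])
out = ScaledDotProduct()(q, k)  # shape: [2, 4, 4]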

python/paddle/fluid/contrib/slim/tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -253,6 +253,8 @@ if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
   list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
   list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq)
+  list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul)
+
 endif()

 if(LINUX AND WITH_MKLDNN)
@@ -507,6 +509,7 @@ if(WIN32)
       test_imperative_qat_channelwise
       test_imperative_qat
       test_imperative_qat_lsq
+      test_imperative_qat_matmul
       test_imperative_out_scale
       test_graph)
   list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS})
@@ -547,6 +550,7 @@ set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat_lsq PROPERTIES TIMEOUT 300)
+set_tests_properties(test_imperative_qat_matmul PROPERTIES TIMEOUT 300)

 if(LINUX AND WITH_MKLDNN)
   set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
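The CMake changes above register the new test with ctest, skip it on Windows, and give it a 300-second timeout. As a hedged aside, the test can also be driven directly through the standard unittest loader, assuming the interpreter runs from the slim tests directory of a build containing this patch:

import unittest

# The module/class names come from the test file added in this commit.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_imperative_qat_matmul.TestImperativeQatMatmul"
)
unittest.TextTestRunner(verbosity=2).run(suite)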
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_matmul.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# copyright (c) 2022 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
#     http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import (
    SGDOptimizer,
    AdamOptimizer,
    MomentumOptimizer,
)
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
    QuantizedConv2D,
    QuantizedMatmul,
)
from paddle.fluid.framework import _test_eager_guard
from imperative_test_utils import fix_model_dict

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class ImperativeLenet(fluid.dygraph.Layer):
    def __init__(self, num_classes=10):
        super().__init__()
        conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
        conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
        fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
        fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
        fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
        conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
        fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
        fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
        fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
        self.features = Sequential(
            Conv2D(
                in_channels=1,
                out_channels=6,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=conv2d_w1_attr,
                bias_attr=False,
            ),
            BatchNorm2D(6),
            ReLU(),
            MaxPool2D(kernel_size=2, stride=2),
            Conv2D(
                in_channels=6,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=0,
                weight_attr=conv2d_w2_attr,
                bias_attr=conv2d_b2_attr,
            ),
            BatchNorm2D(16),
            PReLU(),
            MaxPool2D(kernel_size=2, stride=2),
        )
        self.matmul = QuantizedMatmul()
        self.fc = Sequential(
            Linear(
                in_features=400,
                out_features=120,
                weight_attr=fc_w1_attr,
                bias_attr=fc_b1_attr,
            ),
            LeakyReLU(),
            Linear(
                in_features=120,
                out_features=84,
                weight_attr=fc_w2_attr,
                bias_attr=fc_b2_attr,
            ),
            Sigmoid(),
            Linear(
                in_features=84,
                out_features=num_classes,
                weight_attr=fc_w3_attr,
                bias_attr=fc_b3_attr,
            ),
            Softmax(),
        )

    def forward(self, inputs):
        inputs = self.features(inputs)
        inputs = self.matmul(inputs, inputs, transpose_y=True)
        inputs = paddle.flatten(inputs, 1)
        x = self.fc(inputs)
        return x


class TestImperativeQatMatmul(unittest.TestCase):
    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = True
        self.fuse_conv_bn = False

    def func_qat(self):
        self.set_vars()

        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=self.weight_quantize_type,
            activation_quantize_type=self.activation_quantize_type,
            fuse_conv_bn=self.fuse_conv_bn,
        )

        seed = 100
        np.random.seed(seed)
        fluid.default_main_program().random_seed = seed
        fluid.default_startup_program().random_seed = seed
        paddle.disable_static()
        lenet = ImperativeLenet()
        lenet = fix_model_dict(lenet)
        imperative_qat.quantize(lenet)

        optimizer = MomentumOptimizer(
            learning_rate=0.1, parameter_list=lenet.parameters(), momentum=0.9
        )

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=64, drop_last=True
        )
        test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=32)
        epoch_num = 1
        for epoch in range(epoch_num):
            lenet.train()
            for batch_id, data in enumerate(train_reader()):
                x_data = np.array(
                    [x[0].reshape(1, 28, 28) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape(-1, 1)
                )

                img = fluid.dygraph.to_variable(x_data)
                label = fluid.dygraph.to_variable(y_data)
                out = lenet(img)
                acc = paddle.static.accuracy(out, label)
                loss = fluid.layers.cross_entropy(out, label)
                avg_loss = paddle.mean(loss)

                avg_loss.backward()
                optimizer.minimize(avg_loss)
                lenet.clear_gradients()

                if batch_id % 100 == 0:
                    _logger.info(
                        "Train | At epoch {} step {}: loss = {:}, acc= {:}".format(
                            epoch, batch_id, avg_loss.numpy(), acc.numpy()
                        )
                    )

            lenet.eval()
            eval_acc_top1_list = []
            with paddle.no_grad():
                for batch_id, data in enumerate(test_reader()):

                    x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape(-1, 1)
                    )
                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)

                    out = lenet(img)
                    acc_top1 = paddle.static.accuracy(
                        input=out, label=label, k=1
                    )
                    acc_top5 = paddle.static.accuracy(
                        input=out, label=label, k=5
                    )

                    if batch_id % 100 == 0:
                        eval_acc_top1_list.append(float(acc_top1.numpy()))
                        _logger.info(
                            "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".format(
                                epoch,
                                batch_id,
                                acc_top1.numpy(),
                                acc_top5.numpy(),
                            )
                        )

            # check eval acc
            eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list)
            print('eval_acc_top1', eval_acc_top1)

    def test_qat(self):
        self.func_qat()


if __name__ == '__main__':
    unittest.main()
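The test trains and evaluates the fake-quantized LeNet but stops short of exporting it. A hedged follow-up sketch of how such a model is typically saved for inference with ImperativeQuantAware.save_quantized_model; here `imperative_qat` and `lenet` stand for the objects built inside func_qat above, and the output path is illustrative:

import paddle


def export_quantized_lenet(imperative_qat, lenet, path="./qat_models/lenet"):
    # Traces the dygraph model and writes an inference-format model that keeps
    # the quantization information collected during QAT.
    imperative_qat.save_quantized_model(
        layer=lenet,
        path=path,
        input_spec=[
            paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
        ],
    )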

python/paddle/nn/quant/quant_layers.py

Lines changed: 60 additions & 0 deletions

@@ -39,6 +39,7 @@
     'QuantStub',
     'QuantizedRowParallelLinear',
     'QuantizedColumnParallelLinear',
+    'QuantizedMatmul',
 ]

 _logger = get_logger(
@@ -999,6 +1000,65 @@ def forward(self, input):
         return output


+class QuantizedMatmul(Layer):
+    """
+    The computational logic of QuantizedMatmul is the same as that of Matmul.
+    The only difference is that its inputs are all fake-quantized.
+    """
+
+    def __init__(
+        self,
+        layer=None,
+        weight_bits=8,
+        activation_bits=8,
+        moving_rate=0.9,
+        weight_quantize_type='abs_max',
+        activation_quantize_type='abs_max',
+        weight_pre_layer=None,
+        act_pre_layer=None,
+        weight_quant_layer=None,
+        act_quant_layer=None,
+    ):
+        super().__init__()
+
+        # For FakeQuant
+        if act_quant_layer is not None:
+            self._fake_quant_x = act_quant_layer()
+            self._fake_quant_y = act_quant_layer()
+        else:
+            self._fake_quant_x = _get_fake_quant_type(
+                activation_quantize_type,
+                moving_rate=moving_rate,
+                quant_bits=activation_bits,
+                quant_on_weight=False,
+            )
+            self._fake_quant_y = _get_fake_quant_type(
+                activation_quantize_type,
+                moving_rate=moving_rate,
+                quant_bits=activation_bits,
+                quant_on_weight=False,
+            )
+
+        self._act_preprocess_x = (
+            act_pre_layer() if act_pre_layer is not None else None
+        )
+        self._act_preprocess_y = (
+            act_pre_layer() if act_pre_layer is not None else None
+        )
+
+    def forward(self, x, y, transpose_x=False, transpose_y=False, name=None):
+        if self._act_preprocess_x is not None:
+            x = self._act_preprocess_x(x)
+        quant_x = self._fake_quant_x(x)
+
+        if self._act_preprocess_y is not None:
+            y = self._act_preprocess_y(y)
+        quant_y = self._fake_quant_y(y)
+
+        out = paddle.matmul(quant_x, quant_y, transpose_x, transpose_y, name)
+        return out
+
+
 class MAOutputScaleLayer(Layer):
     """
     Add MovingAverageMaxScale layer to the behind of the input layer.
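The docstring above says that QuantizedMatmul only differs from a plain matmul in that both inputs are fake-quantized first. A minimal numpy sketch of the abs_max fake-quantization round trip applied by the default configuration; this is an illustration of the idea only, since the real layer delegates to the FakeQuant layers returned by _get_fake_quant_type:

import numpy as np


def fake_quant_abs_max(x, bits=8):
    # Symmetric quantization: scale by max |x|, round to integers in
    # [-(2**(bits-1) - 1), 2**(bits-1) - 1], then dequantize back to float.
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max()
    if scale == 0:
        return x
    q = np.clip(np.round(x / scale * qmax), -qmax, qmax)
    return (q * scale / qmax).astype(x.dtype)


x = np.random.randn(2, 4, 8).astype('float32')
y = np.random.randn(2, 8, 4).astype('float32')
out = np.matmul(fake_quant_abs_max(x), fake_quant_abs_max(y))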
