
Commit 43f43c6

support ce with quant (PaddlePaddle#932)
1 parent: 203ae65

File tree: 3 files changed, +65 / -12 lines
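The commit threads a new `ce_test` flag through three quantization demos. When the flag is set, each script fixes the random seeds and switches its data loaders to single-process, unshuffled loading so that repeated runs produce identical metrics, presumably for PaddlePaddle's continuous evaluation (CE) jobs, which expect reproducible results. As a minimal sketch of the shared pattern (the `setup_ce_mode` helper is illustrative only, not part of the commit):

import random

import numpy as np
import paddle


def setup_ce_mode(ce_test, default_workers=4):
    # Return (num_workers, shuffle) for the data loaders.
    # In CE mode every RNG is seeded with the fixed value hard-coded in the
    # three demo scripts, and loading is made single-process and unshuffled.
    if not ce_test:
        return default_workers, True
    seed = 111              # fixed seed used by all three scripts
    paddle.seed(seed)       # framework RNG: parameter init, dropout, ...
    np.random.seed(seed)    # NumPy-based augmentation / sampling
    random.seed(seed)       # Python's builtin random module
    return 0, False         # num_workers=0, shuffle=False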

demo/dygraph/quant/train.py

Lines changed: 19 additions & 3 deletions
@@ -24,6 +24,7 @@
 import functools
 import math
 import time
+import random
 import numpy as np
 from paddle.distributed import ParallelEnv
 from paddle.static import load_program_state
@@ -53,6 +54,7 @@
 add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
 add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
 add_arg('use_pact', bool, False, "Whether to use PACT method.")
+add_arg('ce_test', bool, False, "Whether to CE test.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 1, "The number of total epochs.")
 add_arg('total_images', int, 1281167, "The number of total training images.")
@@ -88,6 +90,17 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):


 def compress(args):
+    num_workers = 4
+    shuffle = True
+    if args.ce_test:
+        # set seed
+        seed = 111
+        paddle.seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+        num_workers = 0
+        shuffle = False
+
     if args.data == "cifar10":
         transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
         train_dataset = paddle.vision.datasets.Cifar10(
@@ -172,13 +185,16 @@ def compress(args):
         net = paddle.DataParallel(net)

     train_batch_sampler = paddle.io.DistributedBatchSampler(
-        train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=shuffle,
+        drop_last=True)
     train_loader = paddle.io.DataLoader(
         train_dataset,
         batch_sampler=train_batch_sampler,
         places=place,
         return_list=True,
-        num_workers=4)
+        num_workers=num_workers)

     valid_loader = paddle.io.DataLoader(
         val_dataset,
@@ -187,7 +203,7 @@ def compress(args):
         shuffle=False,
         drop_last=False,
         return_list=True,
-        num_workers=4)
+        num_workers=num_workers)

     @paddle.no_grad()
     def test(epoch, net):
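Beyond seeding, the loader changes matter because shuffling and multiprocess workers (num_workers > 0) both introduce run-to-run variation in batch order and RNG state; CE mode removes both sources. A sketch of how the computed values feed the dygraph loaders (the `build_train_loader` helper is hypothetical; the calls mirror the diff above):

import paddle


def build_train_loader(train_dataset, batch_size, place, ce_test=False):
    # In CE mode the sampler does not shuffle and the loader runs in the
    # main process, so every run sees the batches in the same order.
    shuffle = not ce_test
    num_workers = 0 if ce_test else 4
    sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True)
    return paddle.io.DataLoader(
        train_dataset,
        batch_sampler=sampler,
        places=place,
        return_list=True,
        num_workers=num_workers)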

demo/quant/pact_quant_aware/train.py

Lines changed: 14 additions & 2 deletions
@@ -7,6 +7,7 @@
 import math
 import time
 import numpy as np
+import random
 from collections import defaultdict

 sys.path.append(os.path.dirname("__file__"))
@@ -64,6 +65,7 @@
         "Whether to use PACT or not.")
 add_arg('analysis', bool, False,
         "Whether analysis variables distribution.")
+add_arg('ce_test', bool, False, "Whether to CE test.")

 # yapf: enable

@@ -108,6 +110,16 @@ def create_optimizer(args):


 def compress(args):
+    num_workers = 4
+    shuffle = True
+    if args.ce_test:
+        # set seed
+        seed = 111
+        paddle.seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+        num_workers = 0
+        shuffle = False

     if args.data == "mnist":
         train_dataset = paddle.vision.datasets.MNIST(mode='train')
@@ -160,8 +172,8 @@ def compress(args):
         return_list=False,
         batch_size=args.batch_size,
         use_shared_memory=True,
-        shuffle=True,
-        num_workers=4)
+        shuffle=shuffle,
+        num_workers=num_workers)

     valid_loader = paddle.io.DataLoader(
         val_dataset,

demo/quant/quant_aware/train.py

Lines changed: 32 additions & 7 deletions
@@ -6,13 +6,15 @@
 import functools
 import math
 import time
+import random
 import numpy as np
 import paddle.fluid as fluid
 sys.path[0] = os.path.join(
     os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
 from paddleslim.common import get_logger
 from paddleslim.analysis import flops
 from paddleslim.quant import quant_aware, convert
+import paddle.vision.transforms as T
 import models
 from utility import add_arguments, print_arguments

@@ -35,9 +37,10 @@
 add_arg('total_images', int, 1281167, "The number of total training images.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 add_arg('config_file', str, None, "The config file for compression with yaml format.")
-add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
+add_arg('data', str, "imagenet", "Which data to use. 'mnist', 'cifar10' or 'imagenet'")
 add_arg('log_period', int, 10, "Log period in batches.")
 add_arg('checkpoint_dir', str, "output", "checkpoint save dir")
+add_arg('ce_test', bool, False, "Whether to CE test.")
 # yapf: enable

 model_list = [m for m in dir(models) if "__" not in m]
@@ -81,6 +84,17 @@ def create_optimizer(args):


 def compress(args):
+    num_workers = 4
+    shuffle = True
+    if args.ce_test:
+        # set seed
+        seed = 111
+        paddle.seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+        num_workers = 0
+        shuffle = False
+
     ############################################################################################################
     # 1. quantization configs
     ############################################################################################################
@@ -105,11 +119,21 @@ def compress(args):
         'moving_rate': 0.9,
     }

+    pretrain = True
     if args.data == "mnist":
         train_dataset = paddle.vision.datasets.MNIST(mode='train')
         val_dataset = paddle.vision.datasets.MNIST(mode='test')
         class_dim = 10
         image_shape = "1,28,28"
+    elif args.data == "cifar10":
+        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+        train_dataset = paddle.vision.datasets.Cifar10(
+            mode="train", backend="cv2", transform=transform)
+        val_dataset = paddle.vision.datasets.Cifar10(
+            mode="test", backend="cv2", transform=transform)
+        class_dim = 10
+        image_shape = "3, 32, 32"
+        pretrain = False
     elif args.data == "imagenet":
         import imagenet_reader as reader
         train_dataset = reader.ImageNetDataset(mode='train')
@@ -153,11 +177,12 @@ def compress(args):
     exe = paddle.static.Executor(place)
     exe.run(paddle.static.default_startup_program())

-    assert os.path.exists(
-        args.pretrained_model), "pretrained_model doesn't exist"
+    if pretrain:
+        assert os.path.exists(
+            args.pretrained_model), "pretrained_model doesn't exist"

-    if args.pretrained_model:
-        paddle.static.load(train_prog, args.pretrained_model, exe)
+        if args.pretrained_model:
+            paddle.static.load(train_prog, args.pretrained_model, exe)

     places = paddle.static.cuda_places(
     ) if args.use_gpu else paddle.static.cpu_places()
@@ -170,8 +195,8 @@ def compress(args):
         batch_size=args.batch_size,
         return_list=False,
         use_shared_memory=True,
-        shuffle=True,
-        num_workers=4)
+        shuffle=shuffle,
+        num_workers=num_workers)
     valid_loader = paddle.io.DataLoader(
         val_dataset,
         places=place,
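This script also gains a CIFAR-10 branch; it sets `pretrain = False`, so the `pretrained_model` assert and load are skipped, presumably because no pretrained weights are provided for that configuration. A short standalone sketch of the added data pipeline (variable names are illustrative; the transform values come from the diff above):

import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

# Transpose converts the HWC uint8 image to CHW; Normalize with mean=127.5
# and std=127.5 maps pixel values from [0, 255] to roughly [-1, 1].
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = Cifar10(mode="train", backend="cv2", transform=transform)
val_dataset = Cifar10(mode="test", backend="cv2", transform=transform)
print(len(train_dataset), len(val_dataset))  # 50000 10000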
