-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfull_model_eval.py
More file actions
62 lines (52 loc) · 1.78 KB
/
full_model_eval.py
File metadata and controls
62 lines (52 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import sys, os
notebook_path = os.getcwd() # Get the current working directory
parent_directory = os.path.dirname(notebook_path) # Get the parent directory
sys.path.append(parent_directory)
from models.resnet import (
resnet18, resnet18_qint8, resnet18_dc,
resnet34, resnet34_qint8, resnet34_dc,
)
from models.vggnet import (
vgg16_float, vgg16_qint8, vgg16_dc,
)
from models.utils import eager_quantize_model,fx_quantize_model,MeasureExecutionTime
from models.conv import fuse_module
# Set engine QNNPACK
torch.backends.quantized.engine = 'qnnpack'

W_bits, A_bits = 3, 3
batchsize = 16
x = torch.rand(batchsize, 3, 224, 224)

# Number of untimed forward passes executed before each measurement, so that
# any one-time first-call cost is not reported as steady-state latency.
WARMUP_RUNS = 1


def _timed_forward(measure_name, net, inputs, warmup=WARMUP_RUNS):
    """Time a single inference pass of ``net`` on ``inputs``.

    Args:
        measure_name: Label handed to ``MeasureExecutionTime``.
        net: Callable network to benchmark.
        inputs: Input batch passed to ``net``.
        warmup: Number of untimed passes to run before the measured one.

    Returns:
        The output of the measured forward pass.
    """
    # Put standard modules into inference mode so layers such as BatchNorm
    # and Dropout behave as they would at deployment time. Guarded because
    # the model factories live outside this file and may return other
    # callables.
    if isinstance(net, torch.nn.Module):
        net.eval()

    # Autograd bookkeeping is pure overhead for a latency benchmark.
    with torch.no_grad():
        for _ in range(warmup):
            net(inputs)
        with MeasureExecutionTime(measure_name=measure_name):
            out = net(inputs)
    return out


def _evaluate_latency(arch_name, build_float, build_qint8, build_dc):
    """Benchmark the float, qint8 and HIPACK variants of one architecture.

    Args:
        arch_name: Architecture name shown in the printed header.
        build_float: Zero-argument factory for the float model.
        build_qint8: Zero-argument factory for the qint8 model.
        build_dc: Factory for the HIPACK model, called with
            ``(W_bits, A_bits)``.
    """
    # Build all three variants before printing the header, matching the
    # order of the original script.
    model = build_float()
    q_model = build_qint8()
    dc_model = build_dc(W_bits, A_bits)

    print(f"Evaluate latency on {arch_name} with batchsize of {batchsize}:")
    _timed_forward("Float", model, x)
    _timed_forward("Qint8", q_model, x)
    _timed_forward("HIPACK", dc_model, x)


# Evaluate latency on VGG16
_evaluate_latency("VGG16", vgg16_float, vgg16_qint8, vgg16_dc)

# Evaluate latency on ResNet18
_evaluate_latency("ResNet18", resnet18, resnet18_qint8, resnet18_dc)

# Evaluate latency on ResNet34
_evaluate_latency("ResNet34", resnet34, resnet34_qint8, resnet34_dc)