Commit 3ea1286

Fix FLOPs API (PaddlePaddle#11)
1 parent e34a150 commit 3ea1286

File tree

1 file changed: +17 −6 lines

paddleslim/analysis/flops.py

Lines changed: 17 additions & 6 deletions

@@ -18,14 +18,22 @@
 __all__ = ["flops"]
 
 
-def flops(program, detail=False):
+def flops(program, only_conv=True, detail=False):
     """
     Get FLOPS of target graph.
     Args:
         program(Program): The program used to calculate FLOPS.
+        only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
+                         default: True.
+        detail(bool): Whether to return detail of each convolution layer.
+
+    Return:
+        If `detail` is true, then return a tuple in format `(FLOPs, details)`, otherwise it will just return `FlOPs`
+        FLOPs(int): The FLOPs of target network.
+        details(dict): The key is the parameter name of convlution layer and the value is the FLOPs of each convolution layer.
     """
     graph = GraphWrapper(program)
-    return _graph_flops(graph, detail=detail)
+    return _graph_flops(graph, only_conv=only_conv, detail=detail)
 
 
 def _graph_flops(graph, only_conv=False, detail=False):
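
For context, a minimal sketch of how the updated signature might be called. It imports `flops` directly from the module shown in this diff and assumes a paddle.fluid Program named `train_program` has been built elsewhere (that variable name is hypothetical):

from paddleslim.analysis.flops import flops

# Default behaviour after this commit: count mul-adds of conv and FC layers.
total = flops(train_program)

# Ask for the per-layer breakdown as well; `detail=True` returns a tuple.
total, per_layer = flops(train_program, only_conv=True, detail=True)
print("total mul-adds:", total)
for param_name, op_flops in per_layer.items():
    print(param_name, op_flops)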
@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False):
                 with_bias = 1
             else:
                 with_bias = 0
-            op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
+            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
             flops += op_flops
             params2flops[op.inputs("Filter")[0].name()] = op_flops
         elif op.type() == 'pool2d' and not only_conv:
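
The conv branch now records multiply-accumulate counts instead of doubling them into raw FLOPs (the leading factor of 2 is dropped). A standalone sketch of that arithmetic with made-up layer dimensions; the variable names mirror the diff, not any PaddleSlim API:

# Hypothetical conv layer: 3x3 kernel, 64 input channels, 128 output
# channels, 56x56 output map, groups=1, bias present.
k_h, k_w, c_in, c_out, groups = 3, 3, 64, 128, 1
h_out, w_out = 56, 56
with_bias = 1

kernel_ops = k_h * k_w * (c_in / groups)            # work per output element
op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
print(op_flops)  # mul-adds; the old code reported 2x this value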
@@ -53,14 +61,17 @@ def _graph_flops(graph, only_conv=False, detail=False):
             k_size = op.attr("ksize")
             flops += h_out * w_out * c_out * (k_size[0]**2)
 
-        elif op.type() == 'mul' and not only_conv:
+        elif op.type() == 'mul':
             x_shape = list(op.inputs("X")[0].shape())
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
-            flops += 2 * x_shape[0] * x_shape[1] * y_shape[1]
 
-        elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv:
+            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
+            flops += op_flops
+            params2flops[op.inputs("Y")[0].name()] = op_flops
+
+        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
             input_shape = list(op.inputs("X")[0].shape())
             if input_shape[0] == -1:
                 input_shape[0] = 1
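
With the `and not only_conv` guard removed, `mul` (fully connected) ops are always counted, and their weight parameter name now appears in the detail dictionary alongside the conv filters. A quick sketch of the count for one hypothetical FC layer, again without the old factor of 2:

# Hypothetical FC layer: input X of shape [-1, 1024] (the batch dim -1
# is normalised to 1, as in the diff), weight Y of shape [1024, 10].
x_shape = [1, 1024]
y_shape = [1024, 10]
op_flops = x_shape[0] * x_shape[1] * y_shape[1]
print(op_flops)  # 10240 mul-adds for this layer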
