1818__all__ = ["flops" ]
1919
2020
21- def flops (program , detail = False ):
21+ def flops (program , only_conv = True , detail = False ):
2222 """
2323 Get FLOPS of target graph.
2424 Args:
2525 program(Program): The program used to calculate FLOPS.
26+ only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
27+ default: True.
28+ detail(bool): Whether to return detail of each convolution layer.
29+
30+ Return:
31+ If `detail` is true, then return a tuple in format `(FLOPs, details)`, otherwise it will just return `FlOPs`
32+ FLOPs(int): The FLOPs of target network.
33+ details(dict): The key is the parameter name of convlution layer and the value is the FLOPs of each convolution layer.
2634 """
2735 graph = GraphWrapper (program )
28- return _graph_flops (graph , detail = detail )
36+ return _graph_flops (graph , only_conv = only_conv , detail = detail )
2937
3038
3139def _graph_flops (graph , only_conv = False , detail = False ):
@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False):
4452 with_bias = 1
4553 else :
4654 with_bias = 0
47- op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias )
55+ op_flops = h_out * w_out * c_out * (kernel_ops + with_bias )
4856 flops += op_flops
4957 params2flops [op .inputs ("Filter" )[0 ].name ()] = op_flops
5058 elif op .type () == 'pool2d' and not only_conv :
@@ -53,14 +61,17 @@ def _graph_flops(graph, only_conv=False, detail=False):
5361 k_size = op .attr ("ksize" )
5462 flops += h_out * w_out * c_out * (k_size [0 ]** 2 )
5563
56- elif op .type () == 'mul' and not only_conv :
64+ elif op .type () == 'mul' :
5765 x_shape = list (op .inputs ("X" )[0 ].shape ())
5866 y_shape = op .inputs ("Y" )[0 ].shape ()
5967 if x_shape [0 ] == - 1 :
6068 x_shape [0 ] = 1
61- flops += 2 * x_shape [0 ] * x_shape [1 ] * y_shape [1 ]
6269
63- elif op .type () in ['relu' , 'sigmoid' , 'batch_norm' ] and not only_conv :
70+ op_flops = x_shape [0 ] * x_shape [1 ] * y_shape [1 ]
71+ flops += op_flops
72+ params2flops [op .inputs ("Y" )[0 ].name ()] = op_flops
73+
74+ elif op .type () in ['relu' , 'sigmoid' , 'batch_norm' , 'relu6' ] and not only_conv :
6475 input_shape = list (op .inputs ("X" )[0 ].shape ())
6576 if input_shape [0 ] == - 1 :
6677 input_shape [0 ] = 1
0 commit comments