2525__all__ = []
2626
2727
28- def summary (net , input_size , dtypes = None ):
28+ def summary (net , input_size = None , dtypes = None , input = None ):
2929 """Prints a string summary of the network.
3030
3131 Args:
@@ -34,8 +34,10 @@ def summary(net, input_size, dtypes=None):
3434 have one input, input_size can be tuple or InputSpec. if model
3535 have multiple input, input_size must be a list which contain
3636 every input's shape. Note that input_size only dim of
37- batch_size can be None or -1.
37+ batch_size can be None or -1. Default: None. Note that
38+ input_size and input cannot be None at the same time.
3839 dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
40+ input: the input tensor. if input is given, input_size and dtype will be ignored, Default: None.
3941
4042 Returns:
4143 Dict: a summary of the network including total params and total trainable params.
@@ -94,10 +96,62 @@ def forward(self, inputs, y):
9496 lenet_multi_input = LeNetMultiInput()
9597
9698 params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)],
97- ['float32', 'float32'])
99+ dtypes=['float32', 'float32'])
100+ print(params_info)
101+
102+ # list input demo
103+ class LeNetListInput(LeNet):
104+
105+ def forward(self, inputs):
106+ x = self.features(inputs[0])
107+
108+ if self.num_classes > 0:
109+ x = paddle.flatten(x, 1)
110+ x = self.fc(x + inputs[1])
111+ return x
112+
113+ lenet_list_input = LeNetListInput()
114+ input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
115+ params_info = paddle.summary(lenet_list_input, input=input_data)
116+ print(params_info)
117+
118+ # dict input demo
119+ class LeNetDictInput(LeNet):
120+
121+ def forward(self, inputs):
122+ x = self.features(inputs['x1'])
123+
124+ if self.num_classes > 0:
125+ x = paddle.flatten(x, 1)
126+ x = self.fc(x + inputs['x2'])
127+ return x
128+
129+ lenet_dict_input = LeNetDictInput()
130+ input_data = {'x1': paddle.rand([1, 1, 28, 28]),
131+ 'x2': paddle.rand([1, 400])}
132+ params_info = paddle.summary(lenet_dict_input, input=input_data)
98133 print(params_info)
99134
100135 """
136+ if input_size is None and input is None :
137+ raise ValueError ("input_size and input cannot be None at the same time" )
138+
139+ if input_size is None and input is not None :
140+ if paddle .is_tensor (input ):
141+ input_size = tuple (input .shape )
142+ elif isinstance (input , (list , tuple )):
143+ input_size = []
144+ for x in input :
145+ input_size .append (tuple (x .shape ))
146+ elif isinstance (input , dict ):
147+ input_size = []
148+ for key in input .keys ():
149+ input_size .append (tuple (input [key ].shape ))
150+ else :
151+ raise ValueError (
152+ "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
153+ )
154+
101155 if isinstance (input_size , InputSpec ):
102156 _input_size = tuple (input_size .shape )
103157 elif isinstance (input_size , list ):
@@ -163,7 +217,8 @@ def _check_input(input_size):
163217 return [_check_input (i ) for i in input_size ]
164218
165219 _input_size = _check_input (_input_size )
166- result , params_info = summary_string (net , _input_size , dtypes )
220+
221+ result , params_info = summary_string (net , _input_size , dtypes , input )
167222 print (result )
168223
169224 if in_train_mode :
@@ -173,7 +228,7 @@ def _check_input(input_size):
173228
174229
175230@paddle .no_grad ()
176- def summary_string (model , input_size , dtypes = None ):
231+ def summary_string (model , input_size = None , dtypes = None , input = None ):
177232 def _all_is_numper (items ):
178233 for item in items :
179234 if not isinstance (item , numbers .Number ):
@@ -280,17 +335,18 @@ def build_input(input_size, dtypes):
280335 build_input (i , dtype ) for i , dtype in zip (input_size , dtypes )
281336 ]
282337
283- x = build_input (input_size , dtypes )
284-
285338 # create properties
286339 summary = OrderedDict ()
287340 hooks = []
288-
289341 # register hook
290342 model .apply (register_hook )
291-
292- # make a forward pass
293- model (* x )
343+ if input is not None :
344+ x = input
345+ model (x )
346+ else :
347+ x = build_input (input_size , dtypes )
348+ # make a forward pass
349+ model (* x )
294350
295351 # remove these hooks
296352 for h in hooks :
0 commit comments