
Commit 40bd7a7

add parameter of input in model.summary (#34165)
* add input option in model.summary
1 parent: d3dae0c
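
In short: paddle.summary gains an input keyword, so a concrete tensor (or a list/dict of tensors) can be passed instead of an explicit input_size. A minimal usage sketch, using a throwaway network that is not part of the patch:

import paddle

# Illustrative network only; any paddle.nn.Layer works.
net = paddle.nn.Sequential(
    paddle.nn.Linear(16, 32),
    paddle.nn.ReLU(),
    paddle.nn.Linear(32, 10),
)

# Before this commit: shapes had to be supplied explicitly.
paddle.summary(net, input_size=(4, 16))

# After this commit: a real tensor can be passed instead; shapes
# (and dtypes) are then taken from the tensor itself.
x = paddle.rand([4, 16])
paddle.summary(net, input=x)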

File tree: 3 files changed, +105 -12 lines changed

python/paddle/hapi/model.py

Lines changed: 1 addition & 1 deletion
@@ -2145,7 +2145,7 @@ def summary(self, input_size=None, dtype=None):
             _input_size = input_size
         else:
             _input_size = self._inputs
-        return summary(self.network, _input_size, dtype)
+        return summary(self.network, _input_size, dtypes=dtype)

    def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):
        out_specs = []
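
For the high-level paddle.Model API this is only a call-site adjustment: dtype is now forwarded by keyword so it still lands on dtypes now that the functional summary has grown an extra parameter. Model.summary usage itself is unchanged; a sketch (network and shapes here are illustrative, not from the patch):

import paddle
from paddle.static import InputSpec

net = paddle.nn.Linear(16, 10)
model = paddle.Model(net, inputs=[InputSpec([None, 16], 'float32', 'x')])

model.summary()                                      # falls back to the registered InputSpec
model.summary(input_size=(4, 16), dtype='float32')   # explicit shape and dtype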

python/paddle/hapi/model_summary.py

Lines changed: 67 additions & 11 deletions
@@ -25,7 +25,7 @@
 __all__ = []


-def summary(net, input_size, dtypes=None):
+def summary(net, input_size=None, dtypes=None, input=None):
     """Prints a string summary of the network.

     Args:
@@ -34,8 +34,10 @@ def summary(net, input_size, dtypes=None):
             have one input, input_size can be tuple or InputSpec. if model
             have multiple input, input_size must be a list which contain
             every input's shape. Note that input_size only dim of
-            batch_size can be None or -1.
+            batch_size can be None or -1. Default: None. Note that
+            input_size and input cannot be None at the same time.
         dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
+        input: the input tensor. if input is given, input_size and dtype will be ignored, Default: None.

     Returns:
         Dict: a summary of the network including total params and total trainable params.
@@ -94,10 +96,62 @@ def forward(self, inputs, y):
             lenet_multi_input = LeNetMultiInput()

             params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)],
-                                        ['float32', 'float32'])
+                                        dtypes=['float32', 'float32'])
+            print(params_info)
+
+            # list input demo
+            class LeNetListInput(LeNet):
+
+                def forward(self, inputs):
+                    x = self.features(inputs[0])
+
+                    if self.num_classes > 0:
+                        x = paddle.flatten(x, 1)
+                        x = self.fc(x + inputs[1])
+                    return x
+
+            lenet_list_input = LeNetListInput()
+            input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
+            params_info = paddle.summary(lenet_list_input, input=input_data)
+            print(params_info)
+
+            # dict input demo
+            class LeNetDictInput(LeNet):
+
+                def forward(self, inputs):
+                    x = self.features(inputs['x1'])
+
+                    if self.num_classes > 0:
+                        x = paddle.flatten(x, 1)
+                        x = self.fc(x + inputs['x2'])
+                    return x
+
+            lenet_dict_input = LeNetDictInput()
+            input_data = {'x1': paddle.rand([1, 1, 28, 28]),
+                          'x2': paddle.rand([1, 400])}
+            params_info = paddle.summary(lenet_dict_input, input=input_data)
             print(params_info)

     """
+    if input_size is None and input is None:
+        raise ValueError("input_size and input cannot be None at the same time")
+
+    if input_size is None and input is not None:
+        if paddle.is_tensor(input):
+            input_size = tuple(input.shape)
+        elif isinstance(input, (list, tuple)):
+            input_size = []
+            for x in input:
+                input_size.append(tuple(x.shape))
+        elif isinstance(input, dict):
+            input_size = []
+            for key in input.keys():
+                input_size.append(tuple(input[key].shape))
+        else:
+            raise ValueError(
+                "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
+            )
+
     if isinstance(input_size, InputSpec):
         _input_size = tuple(input_size.shape)
     elif isinstance(input_size, list):
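
The block above is the heart of the change: when only input is given, input_size is reconstructed from the tensor shapes before the existing checks run. A condensed, standalone restatement of that inference (a sketch mirroring the patch, not the library code itself):

import paddle

def infer_input_size(input):
    # Mirrors the branches added above: tensor, list/tuple, or dict.
    if paddle.is_tensor(input):
        return tuple(input.shape)
    elif isinstance(input, (list, tuple)):
        return [tuple(x.shape) for x in input]
    elif isinstance(input, dict):
        return [tuple(v.shape) for v in input.values()]
    raise ValueError("unable to determine input_size; please pass input_size")

print(infer_input_size(paddle.rand([4, 23, 16])))
# (4, 23, 16)
print(infer_input_size([paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]))
# [(1, 1, 28, 28), (1, 400)]
print(infer_input_size({'x1': paddle.rand([1, 1, 28, 28]), 'x2': paddle.rand([1, 400])}))
# [(1, 1, 28, 28), (1, 400)]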
@@ -163,7 +217,8 @@ def _check_input(input_size):
            return [_check_input(i) for i in input_size]

    _input_size = _check_input(_input_size)
-    result, params_info = summary_string(net, _input_size, dtypes)
+
+    result, params_info = summary_string(net, _input_size, dtypes, input)
    print(result)

    if in_train_mode:
@@ -173,7 +228,7 @@ def _check_input(input_size):


 @paddle.no_grad()
-def summary_string(model, input_size, dtypes=None):
+def summary_string(model, input_size=None, dtypes=None, input=None):
     def _all_is_numper(items):
         for item in items:
             if not isinstance(item, numbers.Number):
@@ -280,17 +335,18 @@ def build_input(input_size, dtypes):
                build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
            ]

-    x = build_input(input_size, dtypes)
-
    # create properties
    summary = OrderedDict()
    hooks = []
-
    # register hook
    model.apply(register_hook)
-
-    # make a forward pass
-    model(*x)
+    if input is not None:
+        x = input
+        model(x)
+    else:
+        x = build_input(input_size, dtypes)
+        # make a forward pass
+        model(*x)

    # remove these hooks
    for h in hooks:
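
Note the calling convention this hunk introduces: a user-supplied input is handed to the network as one positional argument (model(x)), whether it is a tensor, list, or dict, while tensors built from input_size are still unpacked (model(*x)). A forward method that receives such an input therefore has to index into it itself, as the docstring demos above do. An illustrative layer (not from the patch):

import paddle

class TwoBranch(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc_a = paddle.nn.Linear(8, 4)
        self.fc_b = paddle.nn.Linear(6, 4)

    def forward(self, inputs):
        # Receives the whole dict as a single argument, matching model(x).
        return self.fc_a(inputs['a']) + self.fc_b(inputs['b'])

net = TwoBranch()
data = {'a': paddle.rand([2, 8]), 'b': paddle.rand([2, 6])}
paddle.summary(net, input=data)   # forwarded as net(data), not net(**data)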

python/paddle/tests/test_model.py

Lines changed: 37 additions & 0 deletions
@@ -68,6 +68,27 @@ def forward(self, inputs):
         return x


+class LeNetListInput(LeNetDygraph):
+    def forward(self, inputs):
+        x = inputs[0]
+        x = self.features(x)
+
+        if self.num_classes > 0:
+            x = paddle.flatten(x, 1)
+            x = self.fc(x + inputs[1])
+        return x
+
+
+class LeNetDictInput(LeNetDygraph):
+    def forward(self, inputs):
+        x = self.features(inputs['x1'])
+
+        if self.num_classes > 0:
+            x = paddle.flatten(x, 1)
+            x = self.fc(x + inputs['x2'])
+        return x
+
+
 class MnistDataset(MNIST):
     def __init__(self, mode, return_label=True, sample_num=None):
         super(MnistDataset, self).__init__(mode=mode)
@@ -615,6 +636,22 @@ def _get_param_from_state_dict(state_dict):
         gt_params = _get_param_from_state_dict(rnn.state_dict())
         np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

+    def test_summary_input(self):
+        rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
+        input_data = paddle.rand([4, 23, 16])
+        paddle.summary(rnn, input=input_data)
+
+        lenet_List_input = LeNetListInput()
+        input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
+        paddle.summary(lenet_List_input, input=input_data)
+
+        lenet_dict_input = LeNetDictInput()
+        input_data = {
+            'x1': paddle.rand([1, 1, 28, 28]),
+            'x2': paddle.rand([1, 400])
+        }
+        paddle.summary(lenet_dict_input, input=input_data)
+
     def test_summary_dtype(self):
         input_shape = (3, 1)
         net = paddle.nn.Embedding(10, 3, sparse=True)
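
The new test only smoke-tests the three call styles. A quick interactive check of the dict that summary returns could look like this (a sketch, not part of the patch):

import paddle

rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
info = paddle.summary(rnn, input=paddle.rand([4, 23, 16]))

# summary() also returns the parameter counts it prints.
assert info['total_params'] > 0
print(info)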
