Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,7 +2145,7 @@ def summary(self, input_size=None, dtype=None):
_input_size = input_size
else:
_input_size = self._inputs
return summary(self.network, _input_size, dtype)
return summary(self.network, _input_size, dtypes=dtype)

def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):
out_specs = []
Expand Down
78 changes: 67 additions & 11 deletions python/paddle/hapi/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
__all__ = []


def summary(net, input_size, dtypes=None):
def summary(net, input_size=None, dtypes=None, input=None):
"""Prints a string summary of the network.

Args:
Expand All @@ -34,8 +34,10 @@ def summary(net, input_size, dtypes=None):
have one input, input_size can be tuple or InputSpec. if model
have multiple input, input_size must be a list which contain
every input's shape. Note that input_size only dim of
batch_size can be None or -1.
batch_size can be None or -1. Default: None. Note that
input_size and input cannot be None at the same time.
dtypes (str, optional): If dtypes is None, 'float32' will be used. Default: None.
input (tensor, optional): If input is given, input_size and dtypes will be ignored. Default: None.

Returns:
Dict: a summary of the network including total params and total trainable params.
Expand Down Expand Up @@ -94,10 +96,62 @@ def forward(self, inputs, y):
lenet_multi_input = LeNetMultiInput()

params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)],
['float32', 'float32'])
dtypes=['float32', 'float32'])
print(params_info)

# list input demo
class LeNetListInput(LeNet):

def forward(self, inputs):
x = self.features(inputs[0])

if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x + inputs[1])
return x

lenet_list_input = LeNetListInput()
input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
params_info = paddle.summary(lenet_list_input, input=input_data)
print(params_info)

# dict input demo
class LeNetDictInput(LeNet):

def forward(self, inputs):
x = self.features(inputs['x1'])

if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x + inputs['x2'])
return x

lenet_dict_input = LeNetDictInput()
input_data = {'x1': paddle.rand([1, 1, 28, 28]),
'x2': paddle.rand([1, 400])}
params_info = paddle.summary(lenet_dict_input, input=input_data)
print(params_info)

"""
if input_size is None and input is None:
raise ValueError("input_size and input cannot be None at the same time")

if input_size is None and input is not None:
if paddle.is_tensor(input):
input_size = tuple(input.shape)
elif isinstance(input, (list, tuple)):
input_size = []
for x in input:
input_size.append(tuple(x.shape))
elif isinstance(input, dict):
input_size = []
for key in input.keys():
input_size.append(tuple(input[key].shape))
else:
raise ValueError(
"Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
)

if isinstance(input_size, InputSpec):
_input_size = tuple(input_size.shape)
elif isinstance(input_size, list):
Expand Down Expand Up @@ -163,7 +217,8 @@ def _check_input(input_size):
return [_check_input(i) for i in input_size]

_input_size = _check_input(_input_size)
result, params_info = summary_string(net, _input_size, dtypes)

result, params_info = summary_string(net, _input_size, dtypes, input)
print(result)

if in_train_mode:
Expand All @@ -173,7 +228,7 @@ def _check_input(input_size):


@paddle.no_grad()
def summary_string(model, input_size, dtypes=None):
def summary_string(model, input_size=None, dtypes=None, input=None):
def _all_is_numper(items):
for item in items:
if not isinstance(item, numbers.Number):
Expand Down Expand Up @@ -280,17 +335,18 @@ def build_input(input_size, dtypes):
build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
]

x = build_input(input_size, dtypes)

# create properties
summary = OrderedDict()
hooks = []

# register hook
model.apply(register_hook)

# make a forward pass
model(*x)
if input is not None:
x = input
model(x)
else:
x = build_input(input_size, dtypes)
# make a forward pass
model(*x)

# remove these hooks
for h in hooks:
Expand Down
37 changes: 37 additions & 0 deletions python/paddle/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ def forward(self, inputs):
return x


class LeNetListInput(LeNetDygraph):
    """LeNet variant whose ``forward`` takes a list/tuple of two tensors.

    ``inputs[0]`` is the image tensor fed through the convolutional
    features; ``inputs[1]`` is added to the flattened feature map before
    the fully-connected classifier head (only when ``num_classes > 0``).
    """

    def forward(self, inputs):
        # First element of the list is the image input.
        feat = self.features(inputs[0])
        # Without a classifier head, return raw features unchanged.
        if self.num_classes <= 0:
            return feat
        flat = paddle.flatten(feat, 1)
        # Second element is summed with the flattened features.
        return self.fc(flat + inputs[1])


class LeNetDictInput(LeNetDygraph):
    """LeNet variant whose ``forward`` takes a dict of two tensors.

    Expects keys ``'x1'`` (image tensor, run through the conv features)
    and ``'x2'`` (added to the flattened features before the classifier
    head, only when ``num_classes > 0``).
    """

    def forward(self, inputs):
        # 'x1' holds the image input.
        feat = self.features(inputs['x1'])
        # Without a classifier head, return raw features unchanged.
        if self.num_classes <= 0:
            return feat
        flat = paddle.flatten(feat, 1)
        # 'x2' is summed with the flattened features.
        return self.fc(flat + inputs['x2'])


class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode)
Expand Down Expand Up @@ -615,6 +636,22 @@ def _get_param_from_state_dict(state_dict):
gt_params = _get_param_from_state_dict(rnn.state_dict())
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

def test_summary_input(self):
    """paddle.summary should accept tensor, list and dict ``input``."""
    # Single-tensor input.
    rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
    paddle.summary(rnn, input=paddle.rand([4, 23, 16]))

    # List-of-tensors input.
    list_net = LeNetListInput()
    paddle.summary(
        list_net,
        input=[paddle.rand([1, 1, 28, 28]),
               paddle.rand([1, 400])])

    # Dict-of-tensors input.
    dict_net = LeNetDictInput()
    dict_data = {
        'x1': paddle.rand([1, 1, 28, 28]),
        'x2': paddle.rand([1, 400]),
    }
    paddle.summary(dict_net, input=dict_data)

def test_summary_dtype(self):
input_shape = (3, 1)
net = paddle.nn.Embedding(10, 3, sparse=True)
Expand Down