Skip to content

Conversation

@wangna11BD
Copy link
Contributor

@wangna11BD wangna11BD commented Jul 14, 2021

PR types

Function optimization

PR changes

APIs

Describe

Add parameter of input in model.summary
If forward (self, x, y, z), it can use input_shape; if forward(self, x_list) or forward(self, x_dict), the input of dict or list can only be input.
Examples:
.. code-block:: python

        import paddle
        import paddle.nn as nn

        class LeNet(nn.Layer):
            def __init__(self, num_classes=10):
                super(LeNet, self).__init__()
                self.num_classes = num_classes
                self.features = nn.Sequential(
                    nn.Conv2D(
                        1, 6, 3, stride=1, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2D(2, 2),
                    nn.Conv2D(
                        6, 16, 5, stride=1, padding=0),
                    nn.ReLU(),
                    nn.MaxPool2D(2, 2))

                if num_classes > 0:
                    self.fc = nn.Sequential(
                        nn.Linear(400, 120),
                        nn.Linear(120, 84),
                        nn.Linear(
                            84, 10))

            def forward(self, inputs):
                x = self.features(inputs)

                if self.num_classes > 0:
                    x = paddle.flatten(x, 1)
                    x = self.fc(x)
                return x

        lenet = LeNet()

        params_info = paddle.summary(lenet, (1, 1, 28, 28))
        print(params_info)

        # list input demo
        class LeNetListInput(LeNet):

            def forward(self, inputs):
                x = self.features(inputs[0])

                if self.num_classes > 0:
                    x = paddle.flatten(x, 1)
                    x = self.fc(x + inputs[1])
                return x
        
        lenet_list_input = LeNetListInput()
        input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
        params_info = paddle.summary(lenet_list_input, input=input_data)
        print(params_info)

        # dict input demo
        class LeNetDictInput(LeNet):

            def forward(self, inputs):
                x = self.features(inputs['x1'])

                if self.num_classes > 0:
                    x = paddle.flatten(x, 1)
                    x = self.fc(x + inputs['x2'])
                return x

        lenet_dict_input = LeNetDictInput()
        input_data = {'x1': paddle.rand([1, 1, 28, 28]),
                      'x2': paddle.rand([1, 400])}
        params_info = paddle.summary(lenet_dict_input, input=input_data)
        print(params_info)

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

dingjiaweiww
dingjiaweiww previously approved these changes Jul 15, 2021
have multiple input, input_size must be a list which contain
every input's shape. Note that input_size only dim of
batch_size can be None or -1.
input (dict|list|): the input tensor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input can be of any type. And need to explain, if given, input_size and dtype will be ignored.

Copy link
Contributor Author

@wangna11BD wangna11BD Jul 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

lenet = LeNet()
params_info = paddle.summary(lenet, (1, 1, 28, 28))
input_data = [paddle.ones([1, 1, 28, 28])]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not delete the original examples. Add new examples at the back, and add a example of dict and list input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

elif isinstance(input, list):
_input = input
else:
raise Exception('Input must be list or dict.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input can be any type as long as the network can run through

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

LielinJiang
LielinJiang previously approved these changes Jul 20, 2021
dingjiaweiww
dingjiaweiww previously approved these changes Jul 20, 2021
@wangna11BD wangna11BD dismissed stale reviews from dingjiaweiww and LielinJiang via 1819928 July 21, 2021 02:20
LielinJiang
LielinJiang previously approved these changes Jul 27, 2021
dingjiaweiww
dingjiaweiww previously approved these changes Jul 27, 2021
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@LielinJiang LielinJiang merged commit 40bd7a7 into PaddlePaddle:develop Jul 29, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants