Skip to content

Commit b28cc73

Browse files
authored
fix static error in summary (#35303)
1 parent 25871e0 commit b28cc73

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

python/paddle/hapi/model_summary.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def forward(self, inputs):
147147
input_size = []
148148
for key in input.keys():
149149
input_size.append(tuple(input[key].shape))
150+
elif isinstance(input, paddle.fluid.framework.Variable):
151+
input_size = tuple(input.shape)
150152
else:
151153
raise ValueError(
152154
"Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."

python/paddle/tests/test_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,12 @@ def _get_param_from_state_dict(state_dict):
662662
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
663663

664664
def test_summary_input(self):
665+
paddle.enable_static()
666+
mymodel = MyModel()
667+
input_data = paddle.rand([1, 20])
668+
paddle.summary(mymodel, input=input_data)
669+
paddle.disable_static()
670+
665671
rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
666672
input_data = paddle.rand([4, 23, 16])
667673
paddle.summary(rnn, input=input_data)

0 commit comments

Comments
 (0)