From 2b16b90c6df637cc30fa30fb825881ff24ab351e Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 21 Mar 2024 19:11:22 +0800 Subject: [PATCH] support feed of list --- python/paddle/distributed/auto_parallel/api.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 3ae564b9c4d343..6a17f32e85a5eb 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -1927,7 +1927,21 @@ def __call__(self, *args): if self._mode == "eval": if self._engine._loss is None: raise ValueError("Please set loss function before evaluation.") - feeds = self._make_feeds(list(args)) + + feed_list = [] + for feed_item in list(args): + if isinstance(feed_item, (list, tuple)): + feed_list += list(feed_item) + elif isinstance(feed_item, paddle.Tensor): + feed_list += [feed_item] + elif isinstance(feed_item, core.LoDTensor): + feed_list += [feed_item] + else: + raise TypeError( + f"The inputs of DistModel should be list or tensor, but got {type(feed_item)}" + ) + + feeds = self._make_feeds(feed_list) outs = self._engine.run(feeds) if self._mode == "predict":