-
Notifications
You must be signed in to change notification settings - Fork 6k
Remove Input requirement in dygraph for Model #27557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
4c5dead
47e375a
5d2fad7
f7f4063
87a1b53
9048b14
6396fce
67bac48
8dcc03a
cc71c7b
7afd7d4
707a421
c941c36
7d8d791
fc0943c
3e56bfb
16139dc
5ecba2a
7ab5cb6
b151fba
9e9132d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -598,6 +598,7 @@ def __init__(self, model): | |
| 'test_batch': 0 | ||
| } | ||
|
|
||
| self._shapes = None | ||
| if self._nranks > 1: | ||
| stradegy = fluid.dygraph.parallel.ParallelStrategy() | ||
| stradegy.nranks = ParallelEnv().nranks | ||
|
|
@@ -622,6 +623,7 @@ def train_batch(self, inputs, labels=None): | |
| self.model.network.train() | ||
| self.mode = 'train' | ||
| inputs = to_list(inputs) | ||
| self._shapes = [list(input.shape) for input in inputs] | ||
LiuChiachi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| labels = labels or [] | ||
| labels = [to_variable(l) for l in to_list(labels)] | ||
|
|
||
|
|
@@ -656,6 +658,7 @@ def eval_batch(self, inputs, labels=None): | |
| self.model.network.eval() | ||
| self.mode = 'eval' | ||
| inputs = to_list(inputs) | ||
| self._shapes = [list(input.shape) for input in inputs] | ||
LiuChiachi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| labels = labels or [] | ||
| labels = [to_variable(l) for l in to_list(labels)] | ||
|
|
||
|
|
@@ -704,6 +707,7 @@ def test_batch(self, inputs): | |
| self.model.network.eval() | ||
| self.mode = 'test' | ||
| inputs = [to_variable(x) for x in to_list(inputs)] | ||
| self._shapes = [list(input.shape) for input in inputs] | ||
| outputs = self.model.network.forward(*inputs) | ||
| if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace): | ||
| outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)] | ||
|
|
@@ -778,7 +782,7 @@ def load(self, param_state_pairs, optim_state): | |
|
|
||
| if not hasattr(self.model._optimizer, 'set_state_dict'): | ||
| warnings.warn( | ||
| "paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead" | ||
| "paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead." | ||
| ) | ||
| self.model._optimizer.set_dict(converted_state) | ||
| else: | ||
|
|
@@ -792,14 +796,15 @@ class Model(object): | |
| switched by `paddle.disable_static()`. The usage is as follows. | ||
| But note, the switching between dynamic and static should be before | ||
| instantiating a Model. The input description, i.e, paddle.static.InputSpec, | ||
| must be required. | ||
| must be required for static graph. | ||
|
|
||
| Args: | ||
| network (paddle.nn.Layer): The network is an instance of | ||
| paddle.nn.Layer. | ||
| inputs (InputSpec|list|dict|None): `inputs`, entry points of network, | ||
| could be a InputSpec instance, or lits of InputSpec instances, | ||
| or dict ({name: InputSpec}), and it couldn't be None. | ||
| or dict ({name: InputSpec}), and it couldn't be None in static | ||
| graph. | ||
| labels (InputSpec|list|None): `labels`, entry points of network, | ||
| could be a InputSpec instnace or lits of InputSpec instances, | ||
| or None. For static graph, if labels is required in loss, | ||
|
|
@@ -844,14 +849,21 @@ def __init__(self, network, inputs=None, labels=None): | |
| self._loss = None | ||
| self._loss_weights = None | ||
| self._optimizer = None | ||
| self._optimizer = None | ||
| self._shapes = None | ||
| self._is_shape_inferred = False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 建议一个更达义的变量名替换。如果保留这个shape变量的话,建议修改下变量名。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 动、静态图下在 |
||
| self._test_dataloader = None | ||
|
|
||
| if not isinstance(inputs, (list, dict, Input)): | ||
| raise TypeError( | ||
| "'inputs' must be list or dict in static graph mode") | ||
|
|
||
| self._inputs = self._verify_spec(inputs, True) | ||
| if not in_dygraph_mode(): | ||
| if not isinstance(inputs, (list, dict, Input)): | ||
| raise TypeError( | ||
| "'inputs' must be list or dict, and couldn't be None.") | ||
| elif inputs: | ||
| if isinstance(inputs, list): | ||
| self._shapes = [list(input.shape) for input in inputs] | ||
| elif isinstance(inputs, dict): | ||
| self._shapes = [list(inputs[name].shape) for name in inputs] | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如上个comment,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改,感谢~ |
||
| self._inputs = self._verify_spec(inputs, is_input=True) | ||
| self._labels = self._verify_spec(labels) | ||
|
|
||
| # init backend | ||
|
|
@@ -902,7 +914,12 @@ def train_batch(self, inputs, labels=None): | |
| loss = model.train_batch([data], [label]) | ||
| print(loss) | ||
| """ | ||
| return self._adapter.train_batch(inputs, labels) | ||
| loss = self._adapter.train_batch(inputs, labels) | ||
| if fluid.in_dygraph_mode() and self._shapes is None: | ||
| self._shapes = self._adapter._shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._shapes, True) | ||
|
||
| return loss | ||
|
|
||
| def eval_batch(self, inputs, labels=None): | ||
| """ | ||
|
|
@@ -947,7 +964,12 @@ def eval_batch(self, inputs, labels=None): | |
| loss = model.eval_batch([data], [label]) | ||
| print(loss) | ||
| """ | ||
| return self._adapter.eval_batch(inputs, labels) | ||
| loss = self._adapter.eval_batch(inputs, labels) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 同上
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢~已修正 |
||
| if fluid.in_dygraph_mode() and self._shapes is None: | ||
| self._shapes = self._adapter._shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._shapes, True) | ||
| return loss | ||
|
|
||
| def test_batch(self, inputs): | ||
| """ | ||
|
|
@@ -987,7 +1009,12 @@ def test_batch(self, inputs): | |
| out = model.test_batch([data]) | ||
| print(out) | ||
| """ | ||
| return self._adapter.test_batch(inputs) | ||
| loss = self._adapter.test_batch(inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 同上
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改,感谢 |
||
| if fluid.in_dygraph_mode() and self._shapes is None: | ||
| self._shapes = self._adapter._shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._shapes, True) | ||
| return loss | ||
|
|
||
| def save(self, path, training=True): | ||
| """ | ||
|
|
@@ -1677,6 +1704,14 @@ def get_inout_spec(all_vars, return_name=False): | |
| if fluid.in_dygraph_mode(): | ||
| with fluid.framework._dygraph_guard(None): | ||
| layer = self.network | ||
| if self._shapes is None: # No provided or inferred | ||
| raise RuntimeError( | ||
| "Saving inference model needs `inputs` or running before saving." | ||
|
||
| ) | ||
| if self._is_shape_inferred: | ||
| warnings.warn( | ||
| 'Saving actual input shapes only if `inputs` is provided, otherwise variable input dimension is immutable.' | ||
|
||
| ) | ||
| layer.forward = paddle.jit.to_static( | ||
| layer.forward, input_spec=self._inputs) | ||
|
|
||
|
|
@@ -1775,6 +1810,7 @@ def _run_one_epoch(self, data_loader, callbacks, mode, logs={}): | |
| data = flatten(data) | ||
| # LoDTensor.shape is callable, where LoDTensor comes from | ||
| # DataLoader in static graph | ||
|
|
||
| batch_size = data[0].shape()[0] if callable(data[ | ||
| 0].shape) else data[0].shape[0] | ||
|
|
||
|
|
@@ -1864,10 +1900,26 @@ def summary(self, input_size=None, dtype=None): | |
| _input_size = self._inputs | ||
| return summary(self.network, _input_size, dtype) | ||
|
|
||
| def _verify_spec(self, specs, is_input=False): | ||
| def _verify_spec(self, specs, shapes=None, is_input=False): | ||
| out_specs = [] | ||
|
|
||
| if isinstance(specs, dict): | ||
| if specs is None: | ||
| # Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function | ||
| # to generate `Input`. But how can we know the actual shape of each input tensor? | ||
|
|
||
| if is_input: | ||
| arg_names = extract_args(self.network.forward)[1:] | ||
| if shapes is not None and fluid.in_dygraph_mode(): | ||
| out_specs = [ | ||
| Input( | ||
| name=n, shape=shapes[i]) | ||
| for i, n in enumerate(arg_names) | ||
| ] | ||
| else: | ||
| out_specs = [Input(name=n, shape=[None]) for n in arg_names] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| else: | ||
| out_specs = to_list(specs) | ||
| elif isinstance(specs, dict): | ||
| assert is_input == False | ||
| out_specs = [specs[n] \ | ||
| for n in extract_args(self.network.forward) if n != 'self'] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议使用一个更达义的变量名,比如 `self._input_shapes`。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢指出,已修改