-
Notifications
You must be signed in to change notification settings - Fork 6k
Remove Input requirement in dygraph for Model #27557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4c5dead
47e375a
5d2fad7
f7f4063
87a1b53
9048b14
6396fce
67bac48
8dcc03a
cc71c7b
7afd7d4
707a421
c941c36
7d8d791
fc0943c
3e56bfb
16139dc
5ecba2a
7ab5cb6
b151fba
9e9132d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,6 +200,15 @@ def _init_context(): | |
| return strategy | ||
|
|
||
|
|
||
| def _update_input_shapes(inputs): | ||
| shapes = None | ||
| if isinstance(inputs, list): | ||
| shapes = [list(input.shape) for input in inputs] | ||
| elif isinstance(inputs, dict): | ||
| shapes = [list(inputs[name].shape) for name in inputs] | ||
| return shapes | ||
|
|
||
|
|
||
| class StaticGraphAdapter(object): | ||
| """ | ||
| Model training/inference with a static graph. | ||
|
|
@@ -598,6 +607,7 @@ def __init__(self, model): | |
| 'test_batch': 0 | ||
| } | ||
|
|
||
| self._input_shapes = None | ||
| if self._nranks > 1: | ||
| stradegy = fluid.dygraph.parallel.ParallelStrategy() | ||
| stradegy.nranks = ParallelEnv().nranks | ||
|
|
@@ -622,6 +632,7 @@ def train_batch(self, inputs, labels=None): | |
| self.model.network.train() | ||
| self.mode = 'train' | ||
| inputs = to_list(inputs) | ||
| self._input_shapes = _update_input_shapes(inputs) | ||
| labels = labels or [] | ||
| labels = [to_variable(l) for l in to_list(labels)] | ||
|
|
||
|
|
@@ -656,6 +667,7 @@ def eval_batch(self, inputs, labels=None): | |
| self.model.network.eval() | ||
| self.mode = 'eval' | ||
| inputs = to_list(inputs) | ||
| self._input_shapes = _update_input_shapes(inputs) | ||
| labels = labels or [] | ||
| labels = [to_variable(l) for l in to_list(labels)] | ||
|
|
||
|
|
@@ -704,6 +716,7 @@ def test_batch(self, inputs): | |
| self.model.network.eval() | ||
| self.mode = 'test' | ||
| inputs = [to_variable(x) for x in to_list(inputs)] | ||
| self._input_shapes = _update_input_shapes(inputs) | ||
| outputs = self.model.network.forward(*inputs) | ||
| if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace): | ||
| outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)] | ||
|
|
@@ -778,7 +791,7 @@ def load(self, param_state_pairs, optim_state): | |
|
|
||
| if not hasattr(self.model._optimizer, 'set_state_dict'): | ||
| warnings.warn( | ||
| "paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead" | ||
| "paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead." | ||
| ) | ||
| self.model._optimizer.set_dict(converted_state) | ||
| else: | ||
|
|
@@ -792,14 +805,15 @@ class Model(object): | |
| switched by `paddle.disable_static()`. The usage is as follows. | ||
| But note, the switching between dynamic and static should be before | ||
| instantiating a Model. The input description, i.e, paddle.static.InputSpec, | ||
| must be required. | ||
| must be required for static graph. | ||
|
|
||
| Args: | ||
| network (paddle.nn.Layer): The network is an instance of | ||
| paddle.nn.Layer. | ||
| inputs (InputSpec|list|dict|None): `inputs`, entry points of network, | ||
| could be an InputSpec instance, or list of InputSpec instances, | ||
| or dict ({name: InputSpec}), and it couldn't be None. | ||
| or dict ({name: InputSpec}), and it couldn't be None in static | ||
| graph. | ||
| labels (InputSpec|list|None): `labels`, entry points of network, | ||
| could be an InputSpec instance or list of InputSpec instances, | ||
| or None. For static graph, if labels is required in loss, | ||
|
|
@@ -844,14 +858,18 @@ def __init__(self, network, inputs=None, labels=None): | |
| self._loss = None | ||
| self._loss_weights = None | ||
| self._optimizer = None | ||
| self._optimizer = None | ||
| self._input_shapes = None | ||
| self._is_shape_inferred = False | ||
| self._test_dataloader = None | ||
|
|
||
| if not isinstance(inputs, (list, dict, Input)): | ||
| raise TypeError( | ||
| "'inputs' must be list or dict in static graph mode") | ||
| if not in_dygraph_mode(): | ||
| if not isinstance(inputs, (list, dict, Input)): | ||
| raise TypeError( | ||
| "'inputs' must be list or dict, and couldn't be None.") | ||
| elif inputs: | ||
| self._input_shapes = _update_input_shapes(inputs) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如上个comment,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改,感谢~ |
||
| self._inputs = self._verify_spec(inputs, True) | ||
| self._inputs = self._verify_spec(inputs, is_input=True) | ||
| self._labels = self._verify_spec(labels) | ||
|
|
||
| # init backend | ||
|
|
@@ -902,7 +920,12 @@ def train_batch(self, inputs, labels=None): | |
| loss = model.train_batch([data], [label]) | ||
| print(loss) | ||
| """ | ||
| return self._adapter.train_batch(inputs, labels) | ||
| loss = self._adapter.train_batch(inputs, labels) | ||
| if fluid.in_dygraph_mode() and self._input_shapes is None: | ||
| self._input_shapes = self._adapter._input_shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._input_shapes, True) | ||
| return loss | ||
|
|
||
| def eval_batch(self, inputs, labels=None): | ||
| """ | ||
|
|
@@ -947,7 +970,12 @@ def eval_batch(self, inputs, labels=None): | |
| loss = model.eval_batch([data], [label]) | ||
| print(loss) | ||
| """ | ||
| return self._adapter.eval_batch(inputs, labels) | ||
| loss = self._adapter.eval_batch(inputs, labels) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感谢~已修正 |
||
| if fluid.in_dygraph_mode() and self._input_shapes is None: | ||
| self._input_shapes = self._adapter._input_shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._input_shapes, True) | ||
| return loss | ||
|
|
||
| def test_batch(self, inputs): | ||
| """ | ||
|
|
@@ -987,7 +1015,12 @@ def test_batch(self, inputs): | |
| out = model.test_batch([data]) | ||
| print(out) | ||
| """ | ||
| return self._adapter.test_batch(inputs) | ||
| loss = self._adapter.test_batch(inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改,感谢 |
||
| if fluid.in_dygraph_mode() and self._input_shapes is None: | ||
| self._input_shapes = self._adapter._input_shapes | ||
| self._is_shape_inferred = True | ||
| self._inputs = self._verify_spec(None, self._input_shapes, True) | ||
| return loss | ||
|
|
||
| def save(self, path, training=True): | ||
| """ | ||
|
|
@@ -1677,6 +1710,14 @@ def get_inout_spec(all_vars, return_name=False): | |
| if fluid.in_dygraph_mode(): | ||
| with fluid.framework._dygraph_guard(None): | ||
| layer = self.network | ||
| if self._input_shapes is None: # No provided or inferred | ||
| raise RuntimeError( | ||
| "Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation." | ||
| ) | ||
| if self._is_shape_inferred: | ||
| warnings.warn( | ||
| "'inputs' was not specified when Model initialization, so the input shape to be saved will be the shape derived from the user's actual inputs. The input shape to be saved is %s. For saving correct input shapes, please provide 'inputs' for Model initialization." | ||
| % self._input_shapes) | ||
| layer.forward = paddle.jit.to_static( | ||
| layer.forward, input_spec=self._inputs) | ||
|
|
||
|
|
@@ -1775,6 +1816,7 @@ def _run_one_epoch(self, data_loader, callbacks, mode, logs={}): | |
| data = flatten(data) | ||
| # LoDTensor.shape is callable, where LoDTensor comes from | ||
| # DataLoader in static graph | ||
|
|
||
| batch_size = data[0].shape()[0] if callable(data[ | ||
| 0].shape) else data[0].shape[0] | ||
|
|
||
|
|
@@ -1864,10 +1906,26 @@ def summary(self, input_size=None, dtype=None): | |
| _input_size = self._inputs | ||
| return summary(self.network, _input_size, dtype) | ||
|
|
||
| def _verify_spec(self, specs, is_input=False): | ||
| def _verify_spec(self, specs, shapes=None, is_input=False): | ||
| out_specs = [] | ||
|
|
||
| if isinstance(specs, dict): | ||
| if specs is None: | ||
| # Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function | ||
| # to generate `Input`. But how can we know the actual shape of each input tensor? | ||
|
|
||
| if is_input: | ||
| arg_names = extract_args(self.network.forward)[1:] | ||
| if shapes is not None and fluid.in_dygraph_mode(): | ||
| out_specs = [ | ||
| Input( | ||
| name=n, shape=shapes[i]) | ||
| for i, n in enumerate(arg_names) | ||
| ] | ||
| else: | ||
| out_specs = [Input(name=n, shape=[None]) for n in arg_names] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| else: | ||
| out_specs = to_list(specs) | ||
| elif isinstance(specs, dict): | ||
| assert is_input == False | ||
| out_specs = [specs[n] \ | ||
| for n in extract_args(self.network.forward) if n != 'self'] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议一个更达义的变量名替换
self._shapes。这里如果成员变量inputs初始化后是不变的话,是否不需要一个额外的shape变量。只需要一个成员函数解析一下self._inputs即可?如果保留这个shape变量的话,建议修改下变量名。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
动、静态图下在
Model初始化时都需要对self._inputs进行初始化,因为目前train_batch,eval_batch等也需要用到self._inputs。因此需要在动态图下用户没提供inputs时用self._input_shapes记录下在运行模型推导出的输入shape,以便能通过此次更新后的self._verify_spec根据shape获取一个可传递给paddle.to_static的较为合理的self._inputs