Skip to content

Commit d682d8e

Browse files
authored
update TraceLayer save_inference_model api doc and dygraph_to_static example code (#3406)
update TraceLayer save_inference_model api doc and the example code of dygraph_to_static guide doc
1 parent 9d09f51 commit d682d8e

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

doc/paddle/api/paddle/fluid/dygraph/jit/TracedLayer_cn.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ TracedLayer只能用于将data independent的动态图模型转换为静态图
5353
print(out_static_graph[0].shape) # (2, 10)
5454
5555
# 将静态图模型保存为预测模型
56-
static_layer.save_inference_model(dirname='./saved_infer_model')
56+
static_layer.save_inference_model(path='./saved_infer_model')
5757
5858
.. py:method:: set_strategy(build_strategy=None, exec_strategy=None)
5959
@@ -93,12 +93,14 @@ TracedLayer只能用于将data independent的动态图模型转换为静态图
9393
static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
9494
out_static_graph = static_layer([in_var])
9595
96-
.. py:method:: save_inference_model(dirname, feed=None, fetch=None)
96+
.. py:method:: save_inference_model(path, feed=None, fetch=None)
9797
9898
将TracedLayer保存为用于预测部署的模型。保存的预测模型可被C++预测接口加载。
9999

100+
``path`` 是存储目标的前缀,存储的模型结构 ``Program`` 文件的后缀为 ``.pdmodel``,存储的持久参数变量文件的后缀为 ``.pdiparams``.
101+
100102
参数:
101-
- **dirname** (str) - 预测模型的保存目录
103+
- **path** (str) - 存储模型的路径前缀。格式为 ``dirname/file_prefix`` 或者 ``file_prefix``
102104
- **feed** (list(int), 可选) - 预测模型输入变量的索引。若为None,则TracedLayer的所有输入变量均会作为预测模型的输入。默认值为None。
103105
- **fetch** (list(int), 可选) - 预测模型输出变量的索引。若为None,则TracedLayer的所有输出变量均会作为预测模型的输出。默认值为None。
104106

doc/paddle/guides/04_dygraph_to_static/basic_usage_cn.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ PaddlePaddle主要的动转静方式是基于源代码级别转换的ProgramTran
1111
.. code-block:: python
1212
1313
import paddle
14+
import numpy as np
1415
1516
@paddle.jit.to_static
1617
def func(input_var):
@@ -106,6 +107,7 @@ trace是指在模型运行时记录下其运行过哪些算子。TracedLayer就
106107

107108
.. code-block:: python
108109
110+
paddle.enable_static()
109111
place = paddle.CPUPlace()
110112
exe = paddle.Executor(place)
111113
program, feed_vars, fetch_vars = paddle.static.load_inference_model(save_dirname, exe)

0 commit comments

Comments
 (0)