Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger
from . import reader
from .reader import *
Expand Down Expand Up @@ -177,6 +178,7 @@ def name_has_fc(var):
# saved in the same file named 'var_file' in the path "./my_paddle_vars".
"""
save_dirname = os.path.normpath(dirname)

if vars is None:
if main_program is None:
main_program = default_main_program()
Expand Down Expand Up @@ -425,7 +427,7 @@ def is_valid(var):
return is_valid

if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
raise TypeError("'main_program' should be an instance of Program.")

if not main_program._is_distributed:
raise ValueError(
Expand Down Expand Up @@ -594,6 +596,7 @@ def name_has_fc(var):
# been saved in the same file named 'var_file' in the path "./my_paddle_vars".
"""
load_dirname = os.path.normpath(dirname)

if vars is None:
if main_program is None:
main_program = default_main_program()
Expand All @@ -612,6 +615,7 @@ def name_has_fc(var):

if main_program is None:
main_program = default_main_program()

if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")

Expand Down Expand Up @@ -845,7 +849,7 @@ def __load_persistable_vars(executor, dirname, need_load_vars):
executor.run(load_prog)

if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
raise TypeError("'main_program' should be an instance of Program.")

if not main_program._is_distributed:
raise ValueError(
Expand Down Expand Up @@ -1009,6 +1013,9 @@ def save_inference_model(dirname,
we save the original program as inference model.",
RuntimeWarning)

elif not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")

# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
Expand Down
32 changes: 32 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inference_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import paddle.fluid.executor as executor
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.io import save_inference_model, load_inference_model
from paddle.fluid.transpiler import memory_optimize
Expand Down Expand Up @@ -114,5 +115,36 @@ def test_save_inference_model(self):
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)


class TestInstance(unittest.TestCase):
def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model3"
init_program = Program()
program = Program()

# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')

y_predict = layers.fc(input=x, size=1, act=None)

cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)

place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])

# will print warning message

cp_prog = CompiledProgram(program).with_data_parallel(
loss_name=avg_cost.name)

self.assertRaises(TypeError, save_inference_model,
[MODEL_DIR, ["x", "y"], [avg_cost], exe, cp_prog])
self.assertRaises(TypeError, save_inference_model,
[MODEL_DIR, ["x", "y"], [avg_cost], [], cp_prog])


if __name__ == '__main__':
unittest.main()