34 changes: 5 additions & 29 deletions paddle/inference/inference.cc
@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle {

void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
std::string model_filename = dirname + "/__model__.dat";
std::string model_filename = dirname + "/__model__";
Contributor: Remove lines 22-24 (I cannot comment on line 22.)

Contributor (Author): Done

LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
@@ -52,39 +52,15 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
}
}

Contributor: Why do we need to use insert on line 48?

Contributor (Author): This is to make the ordering of feed_var_names_ (also known as feed_target_names) consistent with the Python side when using the save_inference_model function.

Say you call save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program) in Python.

Then, because our code prepends feed operators one at a time (first "x", then "y"), the order of the feed ops in block.AllOps will be [feed("y"), feed("x")].

So by doing the insert on line 48, the vector feed_var_names_ ends up in the order ["x", "y"], which is consistent with the Python-side order. That way, when you provide a vector<Tensor> of feeds, you know the order is the same as in the Python save_inference_model call.

Contributor: I see. Thanks~
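A minimal Python sketch of the ordering point above, using plain lists in place of the real program structure (illustrative only): because feed operators are prepended one at a time, they sit in the program in reverse order, and inserting each recovered name at position 0, as the C++ feed_var_names_ code and the new get_feed_targets helper below both do, restores the original order.

```python
# save_inference_model was called with feed targets ["x", "y"]; prepending
# one feed op per target leaves the ops in the program in reverse order.
feed_ops_in_program_order = ["y", "x"]  # output names of the feed ops

feed_target_names = []
for name in feed_ops_in_program_order:
    feed_target_names.insert(0, name)   # insert at the front, not append

print(feed_target_names)  # ['x', 'y'] -- matches the Python-side order
```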

void InferenceEngine::LoadInferenceModel(
const std::string& dirname,
const std::vector<std::string>& feed_var_names,
const std::vector<std::string>& fetch_var_names) {
std::string model_filename = dirname + "/__model__.dat";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
inputfs.seekg(0, std::ios::end);
program_desc_str.resize(inputfs.tellg());
inputfs.seekg(0, std::ios::beg);
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close();

program_ = new framework::ProgramDesc(program_desc_str);
GenerateLoadProgram(dirname);

if (feed_var_names.empty() || fetch_var_names.empty()) {
LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names.";
}
feed_var_names_ = feed_var_names;
fetch_var_names_ = fetch_var_names;
PrependFeedOp();
AppendFetchOp();
}

bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
if (var->Persistable()) {
// There are many unreachable variables in the program
for (size_t i = 0; i < program_->Size(); ++i) {
const framework::BlockDesc& block = program_->Block(i);
for (auto* op : block.AllOps()) {
if (op->Type() == "feed") {
Contributor: Just a quick clarification: how do we handle vars whose name is "fetch"?

Contributor (Author): We don't need to handle the "fetch" case. The "fetch" var, although persistable, is only ever the output of ops, so it is not a problem for the logic of this code:

for (auto input_argument_name : op->InputArgumentNames()) {

continue;
}
for (auto input_argument_name : op->InputArgumentNames()) {
if (input_argument_name == var->Name()) {
return true;
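For readers following the IsParameter discussion above, here is a rough Python sketch of the same check over a simplified, hypothetical op representation (not Paddle's actual API): a persistable variable counts as a parameter only if some op other than "feed" reads it as an input, which is why a persistable "fetch" variable, appearing only as an op output, is never matched.

```python
def is_parameter(var_name, persistable_vars, ops):
    """ops is a list of dicts like {"type": ..., "inputs": [...], "outputs": [...]}."""
    if var_name not in persistable_vars:
        return False
    for op in ops:
        if op["type"] == "feed":
            # feed ops only write variables; skip them
            continue
        if var_name in op["inputs"]:
            return True
    return False


# Example: "fetch" is persistable but only appears as an output,
# so it is not reported as a parameter, while the weight "w" is.
ops = [
    {"type": "feed", "inputs": ["feed"], "outputs": ["x"]},
    {"type": "mul", "inputs": ["x", "w"], "outputs": ["y"]},
    {"type": "fetch", "inputs": ["y"], "outputs": ["fetch"]},
]
persistable = {"feed", "fetch", "w"}
print(is_parameter("w", persistable, ops))      # True
print(is_parameter("fetch", persistable, ops))  # False
```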
3 changes: 0 additions & 3 deletions paddle/inference/inference.h
@@ -29,9 +29,6 @@ class InferenceEngine {
}

void LoadInferenceModel(const std::string& dirname);
void LoadInferenceModel(const std::string& dirname,
const std::vector<std::string>& feed_var_names,
const std::vector<std::string>& fetch_var_names);
void Execute(const std::vector<framework::LoDTensor>& feeds,
std::vector<framework::LoDTensor>& fetchs);

97 changes: 70 additions & 27 deletions python/paddle/v2/fluid/executor.py
@@ -146,33 +146,76 @@ def run(self,

program = program.clone()
global_block = program.global_block()
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)

for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
'feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
cur_feed = feed[name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
core.set_feed_variable(scope, cur_feed, feed_var.name, i)

fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
for i, var in enumerate(fetch_list):
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})

if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)

if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)

feed_count = 0
fetch_count = 0
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_count += 1
assert op.desc.input('X')[0] == feed_var_name
Contributor: Do we need to keep the input variable of every feed_op the same name? I'm not sure...

Contributor (Author): I assume there will be only one feed variable in the ProgramDesc, which holds a vector of Tensors.
Each feed operator takes one of the Tensors from this input feed variable and copies its data to the operator's output variable. So I think different feed operators should have different output variable names (x, y, etc.), but should share the same input variable, which is our feed_var_name here.

Could you explain a little more about your point? Thanks!

Contributor: Well, I am not sure whether the input data can be fed correctly if we set the feed_holder_name to a different string... and whether we need to support this...

Contributor (Author): I prefer to keep the feed_holder_name (and likewise for fetch) consistent so that the code is simpler and less error-prone.

Even if this feature is later found to be desirable, we can add it in a future PR. What do you think?

name = op.desc.output('Out')[0]
if name not in feed:
raise Exception("feed does not have {} variable".format(
name))
cur_feed = feed[name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var.name, idx)
elif op.desc.type() == 'fetch':
fetch_count += 1
assert op.desc.output('Out')[0] == fetch_var_name
Contributor: Do we need to keep the output variable of every fetch_op the same name? I'm not sure...

name = op.desc.input('X')[0]
if name not in [var.desc.name() for var in fetch_list]:
raise Exception(
"fetch_list does not have {} variable".format(name))
idx = op.desc.attr('col')
assert name == fetch_list[idx].desc.name()

if feed_count > 0 and feed_count != len(feed):
raise Exception(
"Feed operators in program desc does not match 'feed'")

if fetch_count > 0 and fetch_count != len(fetch_list):
raise Exception(
"Fetch operators in program desc does not match 'fetch_list'")

if feed_count == 0:
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
cur_feed = feed[name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
core.set_feed_variable(scope, cur_feed, feed_var.name, i)

if fetch_count == 0:
for i, var in enumerate(fetch_list):
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
Contributor: This function is too long. Can you simplify it and avoid some repeated code?

Contributor (Author): Done. I extracted the feed/fetch operator checking code into separate functions and simplified the code a little.


self.executor.run(program.desc, scope, 0, True, True)
outs = [
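The author mentions extracting the feed/fetch operator checks into separate functions. A minimal sketch of what that factoring could look like follows; the helper names and the dict-based op representation are hypothetical, not the actual follow-up code.

```python
def has_feed_operators(ops, feed, feed_holder_name):
    """Return True if the program already contains feed ops for every feed
    target; raise if the existing ops are inconsistent with `feed`."""
    feed_count = 0
    for op in ops:
        if op["type"] == "feed":
            feed_count += 1
            assert op["input"] == feed_holder_name
            if op["output"] not in feed:
                raise Exception(
                    "feed does not have {} variable".format(op["output"]))
    if feed_count > 0 and feed_count != len(feed):
        raise Exception("Feed operators in program desc do not match 'feed'")
    return feed_count > 0


def has_fetch_operators(ops, fetch_names, fetch_holder_name):
    """Same check for fetch ops against the requested fetch_list names."""
    fetch_count = 0
    for op in ops:
        if op["type"] == "fetch":
            fetch_count += 1
            assert op["output"] == fetch_holder_name
            if op["input"] not in fetch_names:
                raise Exception(
                    "fetch_list does not have {} variable".format(op["input"]))
    if fetch_count > 0 and fetch_count != len(fetch_names):
        raise Exception(
            "Fetch operators in program desc do not match 'fetch_list'")
    return fetch_count > 0
```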
59 changes: 39 additions & 20 deletions python/paddle/v2/fluid/io.py
@@ -192,10 +192,14 @@ def get_inference_program(target_vars, main_program=None):
return inference_program


def prepend_feed_ops(inference_program, feeded_var_names):
def prepend_feed_ops(inference_program,
feeded_var_names,
Contributor: feeded_var_names -> feed_target_names?

Contributor (Author): Done
feed_holder_name='feed'):
global_block = inference_program.global_block()
feed_var = global_block.create_var(
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
name=feed_holder_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)

for i, name in enumerate(feeded_var_names):
out = global_block.var(name)
Expand All @@ -206,10 +210,14 @@ def prepend_feed_ops(inference_program, feeded_var_names):
attrs={'col': i})


def append_fetch_ops(inference_program, fetch_var_names):
def append_fetch_ops(inference_program,
fetch_var_names,
Contributor: fetch_var_names -> fetch_target_names?

Contributor (Author): Done
fetch_holder_name='fetch'):
global_block = inference_program.global_block()
fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
name=fetch_holder_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)

for i, name in enumerate(fetch_var_names):
global_block.append_op(
@@ -261,21 +269,12 @@ def save_inference_model(dirname,
inference_program = pruned_program.inference_optimize()
fetch_var_names = [v.name for v in target_vars]

model_file_name = dirname + "/__model__"
with open(model_file_name, "w") as f:
pickle.dump({
"program_desc_str": inference_program.desc.serialize_to_string(),
"feed_var_names": feeded_var_names,
"fetch_var_names": fetch_var_names
}, f, -1)

prepend_feed_ops(inference_program, feeded_var_names)
append_fetch_ops(inference_program, fetch_var_names)

# Save only programDesc of inference_program in binary format
# in another file: __model__.dat
with open(model_file_name + ".dat", "wb") as fp:
fp.write(inference_program.desc.serialize_to_string())
model_file_name = dirname + "/__model__"
with open(model_file_name, "wb") as f:
f.write(inference_program.desc.serialize_to_string())

save_params(executor, dirname, main_program)

@@ -298,6 +297,24 @@ def _is_presistable_and_exist_(var):
predicate=_is_presistable_and_exist_)


def get_feed_targets(program):
Contributor: get_feed_targets -> get_feed_target_names?

Contributor (Author): Done
feed_targets = []
global_block = program.global_block()
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_targets.insert(0, op.desc.output('Out')[0])
return feed_targets


def get_fetch_targets(program):
Contributor: get_fetch_targets -> get_fetch_target_names?

Contributor (Author): Done
fetch_targets = []
global_block = program.global_block()
for op in global_block.ops:
if op.desc.type() == 'fetch':
fetch_targets.append(op.desc.input('X')[0])
return fetch_targets


def load_inference_model(dirname, executor):
"""
Load inference model from a directory
@@ -314,12 +331,14 @@ def load_inference_model(dirname, executor):
raise ValueError("There is no directory named '%s'", dirname)

model_file_name = dirname + "/__model__"
model = pickle.load(open(model_file_name, "r"))
program_desc_str = model["program_desc_str"]
feed_var_names = model["feed_var_names"]
fetch_var_names = model["fetch_var_names"]
with open(model_file_name, "rb") as f:
program_desc_str = f.read()

program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)

feed_var_names = get_feed_targets(program)
fetch_var_names = get_fetch_targets(program)
fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

return [program, feed_var_names, fetch_vars]
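Putting the pieces together, a rough usage sketch of the new save/load round trip. MODEL_DIR, exe, program, avg_cost, and the input tensors x_data / y_data are placeholders for an elided training setup; only the save/load calls and the returned values are taken from this PR's diff and discussion.

```python
from paddle.v2.fluid.io import save_inference_model, load_inference_model

# Writes a single __model__ file (the serialized ProgramDesc with feed and
# fetch ops already prepended/appended) plus the persistable parameters.
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)

# Later: recover the program together with the feed target names and fetch
# variables, reconstructed from the feed/fetch ops inside the program
# instead of from a pickled dict.
[inference_program, feed_target_names, fetch_targets] = \
    load_inference_model(MODEL_DIR, exe)

# feed_target_names comes back as ["x", "y"], in the same order as passed
# to save_inference_model, so feeds can be matched up positionally.
results = exe.run(inference_program,
                  feed={feed_target_names[0]: x_data,
                        feed_target_names[1]: y_data},
                  fetch_list=fetch_targets)
```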