Fix save load inference model and remove pickle #7712
Conversation
Would you please add an example showing how to use load_inference_model to do inference via the Python API? You can append the usage code in https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py.
     void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
    -  std::string model_filename = dirname + "/__model__.dat";
    +  std::string model_filename = dirname + "/__model__";
Remove lines 22-24. (I cannot comment on line 22.)
Done
Why do we need to use insert on line 48?
This is to make the ordering of feed_var_names_ (also known as feed_target_names) consistent with the Python side when the save_inference_model function is used.
Say you call save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program) in Python.
Because our code prepends the feed operators one at a time (first "x", then "y"), the order of feed ops in block.AllOps() ends up as [feed("y"), feed("x")].
So by doing the insert on line 48, we make the order of the vector feed_var_names_ ["x", "y"], which is consistent with the Python-side order. That way, when you provide a vector<Tensor> of feeds, you know the order is the same as in the Python save_inference_model call.
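To illustrate the reversal-and-restore, here is a minimal pure-Python sketch of the ordering logic (the list-of-tuples representation is a hypothetical stand-in; the real code is C++):

```python
feed_target_names = ["x", "y"]

# Prepending puts the last-prepended op first: [feed("y"), feed("x")].
block_ops = []
for name in feed_target_names:
    block_ops.insert(0, ("feed", name))

# Collecting names with insert at the front undoes the reversal.
feed_var_names = []
for op_type, name in block_ops:
    if op_type == "feed":
        feed_var_names.insert(0, name)

assert feed_var_names == ["x", "y"]  # matches the Python-side order
```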
I see. Thanks~
python/paddle/v2/fluid/executor.py (outdated)
    for op in global_block.ops:
        if op.desc.type() == 'feed':
            feed_count += 1
            assert op.desc.input('X')[0] == feed_var_name
Do we need to keep the input variables of all the feed ops under the same name? I'm not sure...
I assume there will be only one feed variable in the ProgramDesc, which holds a vector of Tensors.
Each feed operator takes one of the Tensors from this input feed variable and copies the data to its own output variable. So I think different feed operators should have different output variable names (x, y, etc.) but the same input variable, which is our feed_var_name here.
Could you explain a little more about your point? Thanks!
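Here is a minimal pure-Python sketch of the layout I mean (plain dicts and lists as hypothetical stand-ins for Variables and Tensors, not the Fluid API):

```python
tensor_x = [1.0, 2.0]  # placeholder data
tensor_y = [3.0, 4.0]

# One shared feed holder variable stores a list of Tensors.
feed_holder = {"feed": [tensor_x, tensor_y]}

# Each feed op reads the Tensor at its 'col' index from the same holder
# and writes it to its own, distinct output variable.
feed_ops = [
    {"type": "feed", "X": "feed", "Out": "x", "col": 0},
    {"type": "feed", "X": "feed", "Out": "y", "col": 1},
]

scope = {}
for op in feed_ops:
    scope[op["Out"]] = feed_holder[op["X"]][op["col"]]

assert scope == {"x": tensor_x, "y": tensor_y}
```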
Well, I am not sure whether the input data can be fed correctly if we set the feed_holder_name to a different string... and do we need to support this...
I prefer to make the feed_holder_name (and likewise for fetch) consistent, so that the code is simpler and less error prone.
Even if this feature later turns out to be desirable, we can add it in a future PR. What do you think?
python/paddle/v2/fluid/executor.py (outdated)
            core.set_feed_variable(scope, cur_feed, feed_var.name, idx)
        elif op.desc.type() == 'fetch':
            fetch_count += 1
            assert op.desc.output('Out')[0] == fetch_var_name
Do we need to keep the output variables of all the fetch ops under the same name? I'm not sure...
        type='fetch',
        inputs={'X': [var]},
        outputs={'Out': [fetch_var]},
        attrs={'col': i})
This function is too long. Can you simplify it and avoid some repeated code?
Done. I extracted the feed/fetch operator checking code into separate functions and simplified the code a bit.
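A rough sketch of the resulting structure (the function name is hypothetical and the helper signatures follow the fragments discussed below; the actual PR code may differ):

```python
def run_with_auto_feed_fetch(program, feed, fetch_list,
                             feed_var_name='feed', fetch_var_name='fetch'):
    # Hypothetical outline of the refactored Executor.run logic.
    global_block = program.global_block()
    # Only add feed/fetch ops when the program does not already contain
    # them (e.g. a program restored by load_inference_model).
    if not has_feed_operators(global_block, feed, feed_var_name):
        prepend_feed_ops(program, list(feed.keys()), feed_var_name)
    if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
        append_fetch_ops(program, fetch_list, fetch_var_name)
    # ... then set the feed variables, run the ops, and pull fetch results.
```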
kexinzhao left a comment
Thanks @Xreki for the suggestions.
I have added some example code using load_inference_model in test_recognize_digits_mlp.py
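For reference, a hedged sketch of such usage (mirroring the snippet reviewed below; MODEL_DIR is a hypothetical save directory, and the exact return values may differ):

```python
import numpy as np
import paddle.v2.fluid as fluid

MODEL_DIR = "recognize_digits_mlp.inference.model"  # hypothetical path

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Restore the inference program plus the names of its feed targets and
# the variables it fetches, as saved by save_inference_model.
[infer_prog, feed_var_names, fetch_vars] = fluid.io.load_inference_model(
    MODEL_DIR, exe)

# Feed one MNIST-shaped input batch and run the loaded program.
tensor_x = np.random.rand(1, 784).astype("float32")
results = exe.run(infer_prog,
                  feed={feed_var_names[0]: tensor_x},
                  fetch_list=fetch_vars)
print(results[0])
```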
        return ans

    def has_feed_operators(block, feed_targets, feed_holder_name):
Could you add some comments for this function?
Done
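A minimal sketch of what the commented helper might look like (an assumption pieced together from the fragments visible in this review, not the exact PR code; has_fetch_operators below is symmetric):

```python
def has_feed_operators(block, feed_targets, feed_holder_name):
    """Check whether the block already contains feed operators.

    Every feed op must read from the single feed holder variable named
    feed_holder_name and write to one of the given feed targets. Returns
    True if feed ops are present, False otherwise.
    """
    feed_count = 0
    for op in block.ops:
        if op.desc.type() == 'feed':
            feed_count += 1
            assert op.desc.input('X')[0] == feed_holder_name
            assert op.desc.output('Out')[0] in feed_targets
    if feed_count > 0 and feed_count != len(feed_targets):
        raise Exception(
            "Feed operators in the program do not match the feed targets")
    return feed_count > 0
```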
        return feed_count > 0

    def has_fetch_operators(block, fetch_targets, fetch_holder_name):
Could you add some comments for this function?
Done
python/paddle/v2/fluid/io.py (outdated)
    -def prepend_feed_ops(inference_program, feeded_var_names):
    +def prepend_feed_ops(inference_program,
    +                     feeded_var_names,
feeded_var_names -> feed_target_names?
Done
python/paddle/v2/fluid/io.py (outdated)
    -def append_fetch_ops(inference_program, fetch_var_names):
    +def append_fetch_ops(inference_program,
    +                     fetch_var_names,
fetch_var_names -> fetch_target_names?
Done
python/paddle/v2/fluid/io.py (outdated)
                      predicate=_is_presistable_and_exist_)

    def get_feed_targets(program):
get_feed_targets -> get_feed_target_names?
Done
python/paddle/v2/fluid/io.py (outdated)
        return feed_targets

    def get_fetch_targets(program):
get_fetch_targets -> get_fetch_target_names?
Done
    results = exe.run(infer_prog,
                      feed={feed_var_names[0]: tensor_x},
                      fetch_list=fetch_vars)
    print(results[0])
Can you move the added lines to the end of the file (before exit) and add some comments?
Done
Xreki left a comment
LGTM
    for (size_t i = 0; i < program_->Size(); ++i) {
      const framework::BlockDesc& block = program_->Block(i);
      for (auto* op : block.AllOps()) {
        if (op->Type() == "feed") {
Just a quick clarification, how do we handle vars whose name is "fetch"?
We don't need to handle the fetch case. The "fetch" var, although persistable, will only ever appear as an output of ops, so it is never picked up by the logic of this code:

    for (auto input_argument_name : op->InputArgumentNames()) {
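In other words (a hypothetical Python rendering of that selection logic, for illustration only; the actual code is C++):

```python
def persistable_input_var_names(block):
    # Sketch: collect persistable variables that appear as *inputs* to
    # some op. The "fetch" variable is persistable but only ever an op
    # output, so it never shows up here and needs no special handling
    # (the "feed" holder, which is an op input, is presumably treated
    # separately by the surrounding code).
    names = set()
    for op in block.ops:
        for name in op.input_arg_names:
            if block.var(name).persistable:
                names.add(name)
    return names
```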
     exe.run(fluid.default_startup_program())

    -PASS_NUM = 100
    +PASS_NUM = 1
I think we should change PASS_NUM to 100 :)
Done
sidgoyal78 left a comment
LGTM
Fix #7221
This PR addresses comments by @Xreki on #7636.
It fixes issues with the feed and fetch ops and with the feed and fetch_list input arguments.

TODO: In a future PR, we will also create a new executor.run() function in the C++ executor class to mimic how executor.py handles ProgramDesc in this PR.