-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix save load inference model and remove pickle #7712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0e8bf52
52f07d4
888e48b
0580519
55fdc82
8fb74d0
73cdd63
834d8ad
2458893
3cce026
0bda221
459d9cb
90889d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,14 +19,10 @@ limitations under the License. */ | |
| #include "paddle/framework/init.h" | ||
| #include "paddle/framework/scope.h" | ||
|
|
||
| #ifdef PADDLE_USE_PTOOLS | ||
| #include "chooseser.h" | ||
| #endif | ||
|
|
||
| namespace paddle { | ||
|
|
||
| void InferenceEngine::LoadInferenceModel(const std::string& dirname) { | ||
| std::string model_filename = dirname + "/__model__.dat"; | ||
| std::string model_filename = dirname + "/__model__"; | ||
| LOG(INFO) << "loading model from " << model_filename; | ||
| std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
| std::string program_desc_str; | ||
|
|
@@ -52,39 +48,15 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) { | |
| } | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to make the ordering of the Say you use code like Then based on our code when prepending feed operators (because we first prepend "x", then "y"), the order of feed ops in block.AllOps will be like [feed("y") feed("x")]. So by doing the insert thing on line 48, we can make the order of vector
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Thanks~ |
||
| void InferenceEngine::LoadInferenceModel( | ||
| const std::string& dirname, | ||
| const std::vector<std::string>& feed_var_names, | ||
| const std::vector<std::string>& fetch_var_names) { | ||
| std::string model_filename = dirname + "/__model__.dat"; | ||
| LOG(INFO) << "loading model from " << model_filename; | ||
| std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
| std::string program_desc_str; | ||
| inputfs.seekg(0, std::ios::end); | ||
| program_desc_str.resize(inputfs.tellg()); | ||
| inputfs.seekg(0, std::ios::beg); | ||
| LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); | ||
| inputfs.read(&program_desc_str[0], program_desc_str.size()); | ||
| inputfs.close(); | ||
|
|
||
| program_ = new framework::ProgramDesc(program_desc_str); | ||
| GenerateLoadProgram(dirname); | ||
|
|
||
| if (feed_var_names.empty() || fetch_var_names.empty()) { | ||
| LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names."; | ||
| } | ||
| feed_var_names_ = feed_var_names; | ||
| fetch_var_names_ = fetch_var_names; | ||
| PrependFeedOp(); | ||
| AppendFetchOp(); | ||
| } | ||
|
|
||
| bool InferenceEngine::IsParameter(const framework::VarDesc* var) { | ||
| if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") { | ||
| if (var->Persistable()) { | ||
| // There are many unreachable variables in the program | ||
| for (size_t i = 0; i < program_->Size(); ++i) { | ||
| const framework::BlockDesc& block = program_->Block(i); | ||
| for (auto* op : block.AllOps()) { | ||
| if (op->Type() == "feed") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a quick clarification, how do we handle vars whose name is "fetch"?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to handle the fetch case. Because the "fetch" var, although it is persistable, is only going to be the output of ops, which will not be a problem for the logic of this code: |
||
| continue; | ||
| } | ||
| for (auto input_argument_name : op->InputArgumentNames()) { | ||
| if (input_argument_name == var->Name()) { | ||
| return true; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,6 +67,84 @@ def as_numpy(tensor): | |
| return ans | ||
|
|
||
|
|
||
| def has_feed_operators(block, feed_targets, feed_holder_name): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add some comment for this function?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| """ Check whether the block already has feed operators. | ||
|
|
||
| Return false if the block does not have any feed operators. | ||
| If some feed operators have been prepended to the block, check that | ||
| the info contained in these feed operators matches the feed_targets | ||
| and feed_holder_name. Raise exception when any mismatch is found. | ||
| Return true when the block has feed operators with matching info. | ||
|
|
||
| Args: | ||
| block: a block instance (typically global block of a program) | ||
| feed_targets: a dictionary of {feed_target_name: feed_target_data} | ||
| feed_holder_name: the name of the variable that holds the data of | ||
| all feed targets. The type of this feed_holder variable is | ||
| FEED_MINIBATCH, which is essentially vector<LoDTensor>. | ||
|
|
||
| Returns: | ||
| A boolean value that indicates whether a block has feed operators | ||
| that match the info contained in feed_targets and feed_holder_name. | ||
| """ | ||
|
|
||
| feed_count = 0 | ||
| for op in block.ops: | ||
| if op.desc.type() == 'feed': | ||
| feed_count += 1 | ||
| assert op.desc.input('X')[0] == feed_holder_name | ||
| feed_target_name = op.desc.output('Out')[0] | ||
| if feed_target_name not in feed_targets: | ||
| raise Exception("'feed_targets' does not have {} variable". | ||
| format(feed_target_name)) | ||
| else: | ||
| break | ||
| if feed_count > 0 and feed_count != len(feed_targets): | ||
| raise Exception( | ||
| "Feed operators in program desc do not match 'feed_targets'") | ||
| return feed_count > 0 | ||
|
|
||
|
|
||
| def has_fetch_operators(block, fetch_targets, fetch_holder_name): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add some comment for this function?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| """ Check whether the block already has fetch operators. | ||
|
|
||
| Return false if the block does not have any fetch operators. | ||
| If some fetch operators have been appended to the block, check that | ||
| the info contained in these fetch operators matches the fetch_targets | ||
| and fetch_holder_name. Raise exception when any mismatch is found. | ||
| Return true when the block has fetch operators with matching info. | ||
|
|
||
| Args: | ||
| block: a block instance (typically global block of a program) | ||
| fetch_targets: a dictionary of {fetch_target_name: fetch_target_data} | ||
| fetch_holder_name: the name of the variable that holds the data of | ||
| all fetch targets. The type of this fetch_holder variable is | ||
| FETCH_LIST, which is essentially vector<LoDTensor>. | ||
|
|
||
| Return: | ||
| A boolean value that indicates whether a block has fetch operators | ||
| that match the info contained in fetch_targets and fetch_holder_name. | ||
| """ | ||
|
|
||
| fetch_count = 0 | ||
| for op in block.ops: | ||
| if op.desc.type() == 'fetch': | ||
| fetch_count += 1 | ||
| assert op.desc.output('Out')[0] == fetch_holder_name | ||
| fetch_target_name = op.desc.input('X')[0] | ||
| if fetch_target_name not in [ | ||
| var.desc.name() for var in fetch_targets | ||
| ]: | ||
| raise Exception("'fetch_targets' does not have {} variable". | ||
| format(fetch_target_name)) | ||
| idx = op.desc.attr('col') | ||
| assert fetch_target_name == fetch_targets[idx].desc.name() | ||
| if fetch_count > 0 and fetch_count != len(fetch_targets): | ||
| raise Exception( | ||
| "Fetch operators in program desc do not match 'fetch_targets'") | ||
| return fetch_count > 0 | ||
|
|
||
|
|
||
| class Executor(object): | ||
| def __init__(self, places): | ||
| if not isinstance(places, list) and not isinstance(places, tuple): | ||
|
|
@@ -146,33 +224,50 @@ def run(self, | |
|
|
||
| program = program.clone() | ||
| global_block = program.global_block() | ||
| feed_var = global_block.create_var( | ||
| name=feed_var_name, | ||
| type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
| persistable=True) | ||
|
|
||
| for i, name in enumerate(feed): | ||
| out = global_block.var(name) | ||
| global_block.prepend_op( | ||
| 'feed', | ||
| inputs={'X': [feed_var]}, | ||
| outputs={'Out': [out]}, | ||
| attrs={'col': i}) | ||
| cur_feed = feed[name] | ||
| if not isinstance(cur_feed, core.LoDTensor): | ||
| cur_feed = self.aslodtensor(cur_feed) | ||
| core.set_feed_variable(scope, cur_feed, feed_var.name, i) | ||
|
|
||
| fetch_var = global_block.create_var( | ||
| name=fetch_var_name, | ||
| type=core.VarDesc.VarType.FETCH_LIST, | ||
| persistable=True) | ||
| for i, var in enumerate(fetch_list): | ||
| global_block.append_op( | ||
| type='fetch', | ||
| inputs={'X': [var]}, | ||
| outputs={'Out': [fetch_var]}, | ||
| attrs={'col': i}) | ||
|
|
||
| if feed_var_name in global_block.vars: | ||
| feed_var = global_block.var(feed_var_name) | ||
| else: | ||
| feed_var = global_block.create_var( | ||
| name=feed_var_name, | ||
| type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
| persistable=True) | ||
|
|
||
| if fetch_var_name in global_block.vars: | ||
| fetch_var = global_block.var(fetch_var_name) | ||
| else: | ||
| fetch_var = global_block.create_var( | ||
| name=fetch_var_name, | ||
| type=core.VarDesc.VarType.FETCH_LIST, | ||
| persistable=True) | ||
|
|
||
| if not has_feed_operators(global_block, feed, feed_var_name): | ||
| for i, name in enumerate(feed): | ||
| out = global_block.var(name) | ||
| global_block.prepend_op( | ||
| type='feed', | ||
| inputs={'X': [feed_var]}, | ||
| outputs={'Out': [out]}, | ||
| attrs={'col': i}) | ||
|
|
||
| for op in global_block.ops: | ||
| if op.desc.type() == 'feed': | ||
| feed_target_name = op.desc.output('Out')[0] | ||
| cur_feed = feed[feed_target_name] | ||
| if not isinstance(cur_feed, core.LoDTensor): | ||
| cur_feed = self.aslodtensor(cur_feed) | ||
| idx = op.desc.attr('col') | ||
| core.set_feed_variable(scope, cur_feed, feed_var_name, idx) | ||
| else: | ||
| break | ||
|
|
||
| if not has_fetch_operators(global_block, fetch_list, fetch_var_name): | ||
| for i, var in enumerate(fetch_list): | ||
| global_block.append_op( | ||
| type='fetch', | ||
| inputs={'X': [var]}, | ||
| outputs={'Out': [fetch_var]}, | ||
| attrs={'col': i}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is too long. Can you simplify it and avoid some repeated codes?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. I extract the feed/fetch operator checking code as separate functions and simplify the code a little bit. |
||
|
|
||
| self.executor.run(program.desc, scope, 0, True, True) | ||
| outs = [ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
line 22 - 24 (I cannot comment on line 22.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done