-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix save load inference model and remove pickle #7712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
0e8bf52
52f07d4
888e48b
0580519
55fdc82
8fb74d0
73cdd63
834d8ad
2458893
3cce026
0bda221
459d9cb
90889d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ limitations under the License. */ | |
| namespace paddle { | ||
|
|
||
| void InferenceEngine::LoadInferenceModel(const std::string& dirname) { | ||
| std::string model_filename = dirname + "/__model__.dat"; | ||
| std::string model_filename = dirname + "/__model__"; | ||
| LOG(INFO) << "loading model from " << model_filename; | ||
| std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
| std::string program_desc_str; | ||
|
|
@@ -52,39 +52,15 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) { | |
| } | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to use `insert` here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is to make the ordering of the returned feed targets match the order in which the user specified them. Say you use code like `feed={"x": ..., "y": ...}`. Then, based on our code, when prepending feed operators (because we first prepend "x", then "y"), the order of feed ops in block.AllOps will be like [feed("y"), feed("x")]. So by doing the insert at index 0 on line 48, we can make the order of the feed-target vector consistent with the user-specified feed order.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Thanks~ |
||
| void InferenceEngine::LoadInferenceModel( | ||
| const std::string& dirname, | ||
| const std::vector<std::string>& feed_var_names, | ||
| const std::vector<std::string>& fetch_var_names) { | ||
| std::string model_filename = dirname + "/__model__.dat"; | ||
| LOG(INFO) << "loading model from " << model_filename; | ||
| std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); | ||
| std::string program_desc_str; | ||
| inputfs.seekg(0, std::ios::end); | ||
| program_desc_str.resize(inputfs.tellg()); | ||
| inputfs.seekg(0, std::ios::beg); | ||
| LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); | ||
| inputfs.read(&program_desc_str[0], program_desc_str.size()); | ||
| inputfs.close(); | ||
|
|
||
| program_ = new framework::ProgramDesc(program_desc_str); | ||
| GenerateLoadProgram(dirname); | ||
|
|
||
| if (feed_var_names.empty() || fetch_var_names.empty()) { | ||
| LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names."; | ||
| } | ||
| feed_var_names_ = feed_var_names; | ||
| fetch_var_names_ = fetch_var_names; | ||
| PrependFeedOp(); | ||
| AppendFetchOp(); | ||
| } | ||
|
|
||
| bool InferenceEngine::IsParameter(const framework::VarDesc* var) { | ||
| if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") { | ||
| if (var->Persistable()) { | ||
| // There are many unreachable variables in the program | ||
| for (size_t i = 0; i < program_->Size(); ++i) { | ||
| const framework::BlockDesc& block = program_->Block(i); | ||
| for (auto* op : block.AllOps()) { | ||
| if (op->Type() == "feed") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a quick clarification: how do we handle vars whose name is "fetch"?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't need to handle the "fetch" case. The "fetch" var, although persistable, is only ever going to be the output of ops, so it will never match an op's input arguments and is not a problem for the logic of this code. |
||
| continue; | ||
| } | ||
| for (auto input_argument_name : op->InputArgumentNames()) { | ||
| if (input_argument_name == var->Name()) { | ||
| return true; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -146,33 +146,76 @@ def run(self, | |
|
|
||
| program = program.clone() | ||
| global_block = program.global_block() | ||
| feed_var = global_block.create_var( | ||
| name=feed_var_name, | ||
| type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
| persistable=True) | ||
|
|
||
| for i, name in enumerate(feed): | ||
| out = global_block.var(name) | ||
| global_block.prepend_op( | ||
| 'feed', | ||
| inputs={'X': [feed_var]}, | ||
| outputs={'Out': [out]}, | ||
| attrs={'col': i}) | ||
| cur_feed = feed[name] | ||
| if not isinstance(cur_feed, core.LoDTensor): | ||
| cur_feed = self.aslodtensor(cur_feed) | ||
| core.set_feed_variable(scope, cur_feed, feed_var.name, i) | ||
|
|
||
| fetch_var = global_block.create_var( | ||
| name=fetch_var_name, | ||
| type=core.VarDesc.VarType.FETCH_LIST, | ||
| persistable=True) | ||
| for i, var in enumerate(fetch_list): | ||
| global_block.append_op( | ||
| type='fetch', | ||
| inputs={'X': [var]}, | ||
| outputs={'Out': [fetch_var]}, | ||
| attrs={'col': i}) | ||
|
|
||
| if feed_var_name in global_block.vars: | ||
| feed_var = global_block.var(feed_var_name) | ||
| else: | ||
| feed_var = global_block.create_var( | ||
| name=feed_var_name, | ||
| type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
| persistable=True) | ||
|
|
||
| if fetch_var_name in global_block.vars: | ||
| fetch_var = global_block.var(fetch_var_name) | ||
| else: | ||
| fetch_var = global_block.create_var( | ||
| name=fetch_var_name, | ||
| type=core.VarDesc.VarType.FETCH_LIST, | ||
| persistable=True) | ||
|
|
||
| feed_count = 0 | ||
| fetch_count = 0 | ||
| for op in global_block.ops: | ||
| if op.desc.type() == 'feed': | ||
| feed_count += 1 | ||
| assert op.desc.input('X')[0] == feed_var_name | ||
|
||
| name = op.desc.output('Out')[0] | ||
| if name not in feed: | ||
| raise Exception("feed does not have {} variable".format( | ||
| name)) | ||
| cur_feed = feed[name] | ||
| if not isinstance(cur_feed, core.LoDTensor): | ||
| cur_feed = self.aslodtensor(cur_feed) | ||
| idx = op.desc.attr('col') | ||
| core.set_feed_variable(scope, cur_feed, feed_var.name, idx) | ||
| elif op.desc.type() == 'fetch': | ||
| fetch_count += 1 | ||
| assert op.desc.output('Out')[0] == fetch_var_name | ||
|
||
| name = op.desc.input('X')[0] | ||
| if name not in [var.desc.name() for var in fetch_list]: | ||
| raise Exception( | ||
| "fetch_list does not have {} variable".format(name)) | ||
| idx = op.desc.attr('col') | ||
| assert name == fetch_list[idx].desc.name() | ||
|
|
||
| if feed_count > 0 and feed_count != len(feed): | ||
| raise Exception( | ||
| "Feed operators in program desc does not match 'feed'") | ||
|
|
||
| if fetch_count > 0 and fetch_count != len(fetch_list): | ||
| raise Exception( | ||
| "Fetch operators in program desc does not match 'fetch_list'") | ||
|
|
||
| if feed_count == 0: | ||
| for i, name in enumerate(feed): | ||
| out = global_block.var(name) | ||
| global_block.prepend_op( | ||
| type='feed', | ||
| inputs={'X': [feed_var]}, | ||
| outputs={'Out': [out]}, | ||
| attrs={'col': i}) | ||
| cur_feed = feed[name] | ||
| if not isinstance(cur_feed, core.LoDTensor): | ||
| cur_feed = self.aslodtensor(cur_feed) | ||
| core.set_feed_variable(scope, cur_feed, feed_var.name, i) | ||
|
|
||
| if fetch_count == 0: | ||
| for i, var in enumerate(fetch_list): | ||
| global_block.append_op( | ||
| type='fetch', | ||
| inputs={'X': [var]}, | ||
| outputs={'Out': [fetch_var]}, | ||
| attrs={'col': i}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function is too long. Can you simplify it and avoid some repeated code?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. I extracted the feed/fetch operator checking code into separate functions and simplified the code a little bit. |
||
|
|
||
| self.executor.run(program.desc, scope, 0, True, True) | ||
| outs = [ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -192,10 +192,14 @@ def get_inference_program(target_vars, main_program=None): | |
| return inference_program | ||
|
|
||
|
|
||
| def prepend_feed_ops(inference_program, feeded_var_names): | ||
| def prepend_feed_ops(inference_program, | ||
| feeded_var_names, | ||
|
||
| feed_holder_name='feed'): | ||
| global_block = inference_program.global_block() | ||
| feed_var = global_block.create_var( | ||
| name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) | ||
| name=feed_holder_name, | ||
| type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
| persistable=True) | ||
|
|
||
| for i, name in enumerate(feeded_var_names): | ||
| out = global_block.var(name) | ||
|
|
@@ -206,10 +210,14 @@ def prepend_feed_ops(inference_program, feeded_var_names): | |
| attrs={'col': i}) | ||
|
|
||
|
|
||
| def append_fetch_ops(inference_program, fetch_var_names): | ||
| def append_fetch_ops(inference_program, | ||
| fetch_var_names, | ||
|
||
| fetch_holder_name='fetch'): | ||
| global_block = inference_program.global_block() | ||
| fetch_var = global_block.create_var( | ||
| name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) | ||
| name=fetch_holder_name, | ||
| type=core.VarDesc.VarType.FETCH_LIST, | ||
| persistable=True) | ||
|
|
||
| for i, name in enumerate(fetch_var_names): | ||
| global_block.append_op( | ||
|
|
@@ -261,21 +269,12 @@ def save_inference_model(dirname, | |
| inference_program = pruned_program.inference_optimize() | ||
| fetch_var_names = [v.name for v in target_vars] | ||
|
|
||
| model_file_name = dirname + "/__model__" | ||
| with open(model_file_name, "w") as f: | ||
| pickle.dump({ | ||
| "program_desc_str": inference_program.desc.serialize_to_string(), | ||
| "feed_var_names": feeded_var_names, | ||
| "fetch_var_names": fetch_var_names | ||
| }, f, -1) | ||
|
|
||
| prepend_feed_ops(inference_program, feeded_var_names) | ||
| append_fetch_ops(inference_program, fetch_var_names) | ||
|
|
||
| # Save only programDesc of inference_program in binary format | ||
| # in another file: __model__.dat | ||
| with open(model_file_name + ".dat", "wb") as fp: | ||
| fp.write(inference_program.desc.serialize_to_string()) | ||
| model_file_name = dirname + "/__model__" | ||
| with open(model_file_name, "wb") as f: | ||
| f.write(inference_program.desc.serialize_to_string()) | ||
|
|
||
| save_params(executor, dirname, main_program) | ||
|
|
||
|
|
@@ -298,6 +297,24 @@ def _is_presistable_and_exist_(var): | |
| predicate=_is_presistable_and_exist_) | ||
|
|
||
|
|
||
| def get_feed_targets(program): | ||
|
||
| feed_targets = [] | ||
| global_block = program.global_block() | ||
| for op in global_block.ops: | ||
| if op.desc.type() == 'feed': | ||
| feed_targets.insert(0, op.desc.output('Out')[0]) | ||
| return feed_targets | ||
|
|
||
|
|
||
| def get_fetch_targets(program): | ||
|
||
| fetch_targets = [] | ||
| global_block = program.global_block() | ||
| for op in global_block.ops: | ||
| if op.desc.type() == 'fetch': | ||
| fetch_targets.append(op.desc.input('X')[0]) | ||
| return fetch_targets | ||
|
|
||
|
|
||
| def load_inference_model(dirname, executor): | ||
| """ | ||
| Load inference model from a directory | ||
|
|
@@ -314,12 +331,14 @@ def load_inference_model(dirname, executor): | |
| raise ValueError("There is no directory named '%s'", dirname) | ||
|
|
||
| model_file_name = dirname + "/__model__" | ||
| model = pickle.load(open(model_file_name, "r")) | ||
| program_desc_str = model["program_desc_str"] | ||
| feed_var_names = model["feed_var_names"] | ||
| fetch_var_names = model["fetch_var_names"] | ||
| with open(model_file_name, "rb") as f: | ||
| program_desc_str = f.read() | ||
|
|
||
| program = Program.parse_from_string(program_desc_str) | ||
| load_persistables_if_exist(executor, dirname, program) | ||
|
|
||
| feed_var_names = get_feed_targets(program) | ||
| fetch_var_names = get_fetch_targets(program) | ||
| fetch_vars = [program.global_block().var(name) for name in fetch_var_names] | ||
|
|
||
| return [program, feed_var_names, fetch_vars] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove lines 22–24. (I cannot comment on line 22.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done