From 85ded8ae5827f4600d7c5b29bf946f5455051e79 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 8 Jan 2018 15:48:47 -0800 Subject: [PATCH 1/5] initial commit --- paddle/framework/framework.proto | 6 +++++- paddle/framework/program_desc.cc | 34 ++++++++++++++++++++++++++++++++ paddle/framework/program_desc.h | 14 +++++++++++++ paddle/pybind/protobuf.cc | 18 +++++++++++++++++ python/paddle/v2/fluid/io.py | 2 ++ 5 files changed, 73 insertions(+), 1 deletion(-) diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ea69b87e2ac7dc..f17c831c2d4fed 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -143,4 +143,8 @@ message BlockDesc { // Please refer to // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md // for more details. -message ProgramDesc { repeated BlockDesc blocks = 1; } +message ProgramDesc { + repeated BlockDesc blocks = 1; + repeated string feed_var_names = 2; + repeated string fetch_var_names = 3; +} diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index b5d9e5e385c1ba..84dad1f897ccee 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -42,6 +42,8 @@ ProgramDesc::ProgramDesc() { ProgramDesc::ProgramDesc(const ProgramDesc &o) { desc_ = o.desc_; + // feed_var_names_ = o.feed_var_names_; + // fetch_var_names_ = o.fetch_var_names_; for (int i = 0; i < desc_.blocks_size(); ++i) { auto *block = desc_.mutable_blocks(i); @@ -51,6 +53,14 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { desc_ = desc; + /* + for (int i = 0; i < desc_.feed_var_names_size(); i++) { + feed_var_names_.push_back(desc_.feed_var_names(i)); + } + for (int i = 0; i < desc_.fetch_var_names_size(); i++) { + fetch_var_names_.push_back(desc_.fetch_var_names(i)); + } + */ for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -59,10 +69,34 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); + /* + for (int i = 0; i < desc_.feed_var_names_size(); i++) { + feed_var_names_.push_back(desc_.feed_var_names(i)); + } + for (int i = 0; i < desc_.fetch_var_names_size(); i++) { + fetch_var_names_.push_back(desc_.fetch_var_names(i)); + } + */ for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } } +/* +void ProgramDesc::ClearFeedVarNames() { + desc_.clear_feed_var_names(); +} + +void ProgramDesc::ClearFetchVarNames() { + desc_.clear_fetch_var_names(); +} +void AppendFeedVarName(const std::string &var_name) { + desc_.add_feed_var_names(var_name); +} + +void AppendFetchVarName(const std::string &var_name) { + desc_.add_fetch_var_names(var_name); +} +*/ } // namespace framework } // namespace paddle diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 15a962bb696d61..00f64f8aa08d3a 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -45,10 +45,24 @@ class ProgramDesc { proto::ProgramDesc *Proto(); + /* + void ClearFeedVarNames(); + + void ClearFetchVarNames(); + + void AppendFeedVarName(const std::string &var_name); + + void AppendFetchVarName(const std::string &var_name); + */ + private: proto::ProgramDesc desc_; std::vector> blocks_; + + // std::vector feed_var_names_; + + // std::vector fetch_var_names_; }; } // namespace framework } // namespace paddle diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 4f959481537d29..6d1c0353a84586 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -148,6 +148,24 @@ void BindProgramDesc(py::module &m) { PADDLE_ENFORCE(desc->ParseFromString(data), "Fail to parse ProgramDesc from string. This could " "be a bug of Paddle."); + }) + .def("assign_feed_var_names", + [](ProgramDesc &program_desc, + const std::vector &var_names) { + proto::ProgramDesc *desc = program_desc.Proto(); + desc->clear_feed_var_names(); + for (auto var_name : var_names) { + desc->add_feed_var_names(var_name); + } + }) + .def("assign_fetch_var_names", + [](ProgramDesc &program_desc, + const std::vector &var_names) { + proto::ProgramDesc *desc = program_desc.Proto(); + desc->clear_fetch_var_names(); + for (auto var_name : var_names) { + desc->add_fetch_var_names(var_name); + } }); } diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index c63567601accd8..3807b1cbfc5be6 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -214,6 +214,8 @@ def save_inference_model(dirname, # Save only programDesc of inference_program in binary format # in another file: __model__.dat + inference_program.desc.assign_feed_var_names(feeded_var_names) + inference_program.desc.assign_fetch_var_names(fetch_var_names) with open(model_file_name + ".dat", "wb") as fp: fp.write(inference_program.desc.serialize_to_string()) From edf8eaeaedba30dd2536361bff1ca588aa652c4e Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 8 Jan 2018 16:50:44 -0800 Subject: [PATCH 2/5] modify load inference model function --- paddle/framework/program_desc.cc | 15 +++++++++++++++ paddle/framework/program_desc.h | 4 ++++ paddle/inference/example.cc | 17 ++++++++++++----- paddle/inference/inference.cc | 20 ++++++++++++-------- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index 84dad1f897ccee..f46164035b8c47 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -81,6 +81,21 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } } + +void ProgramDesc::GetFeedVarNames(std::vector &var_names) { + var_names.clear(); + for (int i = 0; i < desc_.feed_var_names_size(); i++) { + feed_var_names_.push_back(desc_.feed_var_names(i)); + } +} + +void ProgramDesc::GetFetchVarNames(std::vector &var_names) { + var_names.clear(); + for (int i = 0; i < desc_.fetch_var_names_size(); i++) { + fetch_var_names_.push_back(desc_.fetch_var_names(i)); + } +} + /* void ProgramDesc::ClearFeedVarNames() { desc_.clear_feed_var_names(); diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 00f64f8aa08d3a..9d729ecf82eba9 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -45,6 +45,10 @@ class ProgramDesc { proto::ProgramDesc *Proto(); + void GetFeedVarNames(std::vector &var_names); + + void GetFetchVarNames(std::vector &var_names); + /* void ClearFeedVarNames(); diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc index 9711b20e6fb409..c794797ddc7220 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/example.cc @@ -23,6 +23,7 @@ DEFINE_string(fetch_var_names, "", "Names of fetching variables"); int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); + /* if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() || FLAGS_fetch_var_names.empty()) { // Example: @@ -34,17 +35,23 @@ int main(int argc, char** argv) { << std::endl; exit(1); } + */ + if (FLAGS_dirname.empty()) { + std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl; + } std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl; - std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl; + // std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << + // std::endl; + // std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << + // std::endl; std::string dirname = FLAGS_dirname; - std::vector feed_var_names = {FLAGS_feed_var_names}; - std::vector fetch_var_names = {FLAGS_fetch_var_names}; + // std::vector feed_var_names = {FLAGS_feed_var_names}; + // std::vector fetch_var_names = {FLAGS_fetch_var_names}; paddle::InferenceEngine* engine = new paddle::InferenceEngine(); - engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); + engine->LoadInferenceModel(dirname); paddle::framework::LoDTensor input; srand(time(0)); diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 49e39358e81bbe..f0eab5be629896 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -25,10 +25,7 @@ limitations under the License. */ namespace paddle { -void InferenceEngine::LoadInferenceModel( - const std::string& dirname, - const std::vector& feed_var_names, - const std::vector& fetch_var_names) { +void InferenceEngine::LoadInferenceModel(const std::string& dirname) { #ifdef PADDLE_USE_PTOOLS std::string model_filename = dirname + "/__model__"; LOG(INFO) << "Using PicklingTools, loading model from " << model_filename; @@ -49,14 +46,21 @@ void InferenceEngine::LoadInferenceModel( inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.close(); #endif + LOG(INFO) << "feed size: " << feed_var_names_.size(); + LOG(INFO) << "fetch size: " << fetch_var_names_.size(); program_ = new framework::ProgramDesc(program_desc_str); + program_->GetFeedVarNames(feed_var_names_); + program_->GetFetchVarNames(fetch_var_names_); + + LOG(INFO) << "feed size: " << feed_var_names_.size(); + LOG(INFO) << "fetch size: " << fetch_var_names_.size(); + GenerateLoadProgram(dirname); - if (feed_var_names.empty() || fetch_var_names.empty()) { - LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names."; + if (feed_var_names_.empty() || fetch_var_names_.empty()) { + LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names when " + "saving inference models."; } - feed_var_names_ = feed_var_names; - fetch_var_names_ = fetch_var_names; PrependFeedOp(); AppendFetchOp(); } From c456c9c149a4d385f362971e738011b35c98ba12 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 8 Jan 2018 18:00:02 -0800 Subject: [PATCH 3/5] clean up code --- paddle/framework/program_desc.cc | 41 +++----------------------------- paddle/framework/program_desc.h | 14 ----------- paddle/inference/CMakeLists.txt | 21 ---------------- paddle/inference/example.cc | 23 +++--------------- paddle/inference/inference.cc | 16 +------------ paddle/inference/inference.h | 4 +--- 6 files changed, 8 insertions(+), 111 deletions(-) diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index f46164035b8c47..2c71b5382b1aff 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -42,9 +42,6 @@ ProgramDesc::ProgramDesc() { ProgramDesc::ProgramDesc(const ProgramDesc &o) { desc_ = o.desc_; - // feed_var_names_ = o.feed_var_names_; - // fetch_var_names_ = o.fetch_var_names_; - for (int i = 0; i < desc_.blocks_size(); ++i) { auto *block = desc_.mutable_blocks(i); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); @@ -53,14 +50,6 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { desc_ = desc; - /* - for (int i = 0; i < desc_.feed_var_names_size(); i++) { - feed_var_names_.push_back(desc_.feed_var_names(i)); - } - for (int i = 0; i < desc_.fetch_var_names_size(); i++) { - fetch_var_names_.push_back(desc_.fetch_var_names(i)); - } - */ for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -69,14 +58,7 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); - /* - for (int i = 0; i < desc_.feed_var_names_size(); i++) { - feed_var_names_.push_back(desc_.feed_var_names(i)); - } - for (int i = 0; i < desc_.fetch_var_names_size(); i++) { - fetch_var_names_.push_back(desc_.fetch_var_names(i)); - } - */ + for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -85,33 +67,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { void ProgramDesc::GetFeedVarNames(std::vector &var_names) { var_names.clear(); for (int i = 0; i < desc_.feed_var_names_size(); i++) { - feed_var_names_.push_back(desc_.feed_var_names(i)); + var_names.push_back(desc_.feed_var_names(i)); } } void ProgramDesc::GetFetchVarNames(std::vector &var_names) { var_names.clear(); for (int i = 0; i < desc_.fetch_var_names_size(); i++) { - fetch_var_names_.push_back(desc_.fetch_var_names(i)); + var_names.push_back(desc_.fetch_var_names(i)); } } -/* -void ProgramDesc::ClearFeedVarNames() { - desc_.clear_feed_var_names(); -} - -void ProgramDesc::ClearFetchVarNames() { - desc_.clear_fetch_var_names(); -} - -void AppendFeedVarName(const std::string &var_name) { - desc_.add_feed_var_names(var_name); -} - -void AppendFetchVarName(const std::string &var_name) { - desc_.add_fetch_var_names(var_name); -} -*/ } // namespace framework } // namespace paddle diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 9d729ecf82eba9..5146d2da05c5a0 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -49,24 +49,10 @@ class ProgramDesc { void GetFetchVarNames(std::vector &var_names); - /* - void ClearFeedVarNames(); - - void ClearFetchVarNames(); - - void AppendFeedVarName(const std::string &var_name); - - void AppendFetchVarName(const std::string &var_name); - */ - private: proto::ProgramDesc desc_; std::vector> blocks_; - - // std::vector feed_var_names_; - - // std::vector fetch_var_names_; }; } // namespace framework } // namespace paddle diff --git a/paddle/inference/CMakeLists.txt b/paddle/inference/CMakeLists.txt index 8437b2b21942ea..02ca8a45a851d2 100644 --- a/paddle/inference/CMakeLists.txt +++ b/paddle/inference/CMakeLists.txt @@ -8,27 +8,6 @@ cc_library(paddle_fluid_api # Merge all modules into a simgle static library cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES}) -# ptools -# just for testing, we may need to change the storing format for inference_model -# and move the dependent of pickle. -# download from http://www.picklingtools.com/ -# build in the C++ sub-directory, using command -# make -f Makefile.Linux libptools.so -set(PTOOLS_LIB) -set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools") -find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++) -find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++) -if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB) - add_definitions(-DPADDLE_USE_PTOOLS) - set(PTOOLS_LIB ptools) - message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}") - add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL) - set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB}) - include_directories(${PTOOLS_ROOT}/C++) - include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include) - add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools -endif() - add_executable(example example.cc) if(APPLE) set(OPTIONAL_LINK_FLAGS) diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc index c794797ddc7220..ee875817d5fe91 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/example.cc @@ -18,37 +18,20 @@ limitations under the License. */ #include "paddle/inference/inference.h" DEFINE_string(dirname, "", "Directory of the inference model."); -DEFINE_string(feed_var_names, "", "Names of feeding variables"); -DEFINE_string(fetch_var_names, "", "Names of fetching variables"); int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); - /* - if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() || - FLAGS_fetch_var_names.empty()) { + + if (FLAGS_dirname.empty()) { // Example: // ./example --dirname=recognize_digits_mlp.inference.model - // --feed_var_names="x" - // --fetch_var_names="fc_2.tmp_2" - std::cout << "Usage: ./example --dirname=path/to/your/model " - "--feed_var_names=x --fetch_var_names=y" - << std::endl; - exit(1); - } - */ - if (FLAGS_dirname.empty()) { std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl; + exit(1); } std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - // std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << - // std::endl; - // std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << - // std::endl; std::string dirname = FLAGS_dirname; - // std::vector feed_var_names = {FLAGS_feed_var_names}; - // std::vector fetch_var_names = {FLAGS_fetch_var_names}; paddle::InferenceEngine* engine = new paddle::InferenceEngine(); engine->LoadInferenceModel(dirname); diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index f0eab5be629896..212a9178e1a210 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -26,15 +26,6 @@ limitations under the License. */ namespace paddle { void InferenceEngine::LoadInferenceModel(const std::string& dirname) { -#ifdef PADDLE_USE_PTOOLS - std::string model_filename = dirname + "/__model__"; - LOG(INFO) << "Using PicklingTools, loading model from " << model_filename; - Val v; - LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0); - std::string program_desc_str = v["program_desc_str"]; - LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); -// PicklingTools cannot parse the vector of strings correctly. -#else std::string model_filename = dirname + "/__model__.dat"; LOG(INFO) << "loading model from " << model_filename; std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); @@ -45,16 +36,11 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) { LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.close(); -#endif - LOG(INFO) << "feed size: " << feed_var_names_.size(); - LOG(INFO) << "fetch size: " << fetch_var_names_.size(); + program_ = new framework::ProgramDesc(program_desc_str); program_->GetFeedVarNames(feed_var_names_); program_->GetFetchVarNames(fetch_var_names_); - LOG(INFO) << "feed size: " << feed_var_names_.size(); - LOG(INFO) << "fetch size: " << fetch_var_names_.size(); - GenerateLoadProgram(dirname); if (feed_var_names_.empty() || fetch_var_names_.empty()) { diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h index a3f3ef4b440036..26f259824b945e 100644 --- a/paddle/inference/inference.h +++ b/paddle/inference/inference.h @@ -28,9 +28,7 @@ class InferenceEngine { delete load_program_; } - void LoadInferenceModel(const std::string& dirname, - const std::vector& feed_var_names, - const std::vector& fetch_var_names); + void LoadInferenceModel(const std::string& dirname); void Execute(const std::vector& feeds, std::vector& fetchs); From 9da3d721fd67ab4f84cc39d59563830b36ed8cbf Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 8 Jan 2018 19:59:31 -0800 Subject: [PATCH 4/5] fix io.py --- paddle/inference/inference.cc | 2 +- paddle/pybind/protobuf.cc | 19 ++++++++++++++++- python/paddle/v2/fluid/io.py | 39 +++++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 212a9178e1a210..bd1ce4dc8ab2ca 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -26,7 +26,7 @@ limitations under the License. */ namespace paddle { void InferenceEngine::LoadInferenceModel(const std::string& dirname) { - std::string model_filename = dirname + "/__model__.dat"; + std::string model_filename = dirname + "/__model__"; LOG(INFO) << "loading model from " << model_filename; std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); std::string program_desc_str; diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 6d1c0353a84586..5925b7751ed080 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -166,7 +166,24 @@ void BindProgramDesc(py::module &m) { for (auto var_name : var_names) { desc->add_fetch_var_names(var_name); } - }); + }) + .def("get_feed_var_names", + [](ProgramDesc &program_desc) { + proto::ProgramDesc *desc = program_desc.Proto(); + std::vector retv; + for (int i = 0; i < desc->feed_var_names_size(); ++i) { + retv.push_back(desc->feed_var_names(i)); + } + return retv; + }) + .def("get_fetch_var_names", [](ProgramDesc &program_desc) { + proto::ProgramDesc *desc = program_desc.Proto(); + std::vector retv; + for (int i = 0; i < desc->fetch_var_names_size(); ++i) { + retv.push_back(desc->fetch_var_names(i)); + } + return retv; + }); } void BindBlockDesc(py::module &m) { diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 3807b1cbfc5be6..57f9a701a8815c 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -204,20 +204,20 @@ def save_inference_model(dirname, inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars] - model_file_name = dirname + "/__model__" - with open(model_file_name, "w") as f: - pickle.dump({ - "program_desc_str": inference_program.desc.serialize_to_string(), - "feed_var_names": feeded_var_names, - "fetch_var_names": fetch_var_names - }, f, -1) - - # Save only programDesc of inference_program in binary format - # in another file: __model__.dat + #model_file_name = dirname + "/__model__" + #with open(model_file_name, "w") as f: + # pickle.dump({ + # "program_desc_str": inference_program.desc.serialize_to_string(), + # "feed_var_names": feeded_var_names, + # "fetch_var_names": fetch_var_names + # }, f, -1) + + # Save the ProgramDesc of inference_program in binary format inference_program.desc.assign_feed_var_names(feeded_var_names) inference_program.desc.assign_fetch_var_names(fetch_var_names) - with open(model_file_name + ".dat", "wb") as fp: - fp.write(inference_program.desc.serialize_to_string()) + model_file_name = dirname + "/__model__" + with open(model_file_name, "wb") as f: + f.write(inference_program.desc.serialize_to_string()) save_params(executor, dirname, main_program) @@ -256,11 +256,18 @@ def load_inference_model(dirname, executor): raise ValueError("There is no directory named '%s'", dirname) model_file_name = dirname + "/__model__" - model = pickle.load(open(model_file_name, "r")) - program_desc_str = model["program_desc_str"] - feed_var_names = model["feed_var_names"] - fetch_var_names = model["fetch_var_names"] + with open(model_file_name, "rb") as f: + program_desc_str = f.read() + program = Program.parse_from_string(program_desc_str) + feed_var_names = program.desc.get_feed_var_names() + fetch_var_names = program.desc.get_fetch_var_names() + + #model = pickle.load(open(model_file_name, "r")) + #program_desc_str = model["program_desc_str"] + #feed_var_names = model["feed_var_names"] + #fetch_var_names = model["fetch_var_names"] + #program = Program.parse_from_string(program_desc_str) load_persistables_if_exist(executor, dirname, program) fetch_vars = [program.global_block().var(name) for name in fetch_var_names] From 875e22bdb3a5aee784d2efb72ff76eddfb8536d6 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 8 Jan 2018 20:29:45 -0800 Subject: [PATCH 5/5] clean code --- python/paddle/v2/fluid/io.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 57f9a701a8815c..6757abd5e21b65 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -204,17 +204,9 @@ def save_inference_model(dirname, inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars] - #model_file_name = dirname + "/__model__" - #with open(model_file_name, "w") as f: - # pickle.dump({ - # "program_desc_str": inference_program.desc.serialize_to_string(), - # "feed_var_names": feeded_var_names, - # "fetch_var_names": fetch_var_names - # }, f, -1) - - # Save the ProgramDesc of inference_program in binary format inference_program.desc.assign_feed_var_names(feeded_var_names) inference_program.desc.assign_fetch_var_names(fetch_var_names) + model_file_name = dirname + "/__model__" with open(model_file_name, "wb") as f: f.write(inference_program.desc.serialize_to_string()) @@ -263,11 +255,6 @@ def load_inference_model(dirname, executor): feed_var_names = program.desc.get_feed_var_names() fetch_var_names = program.desc.get_fetch_var_names() - #model = pickle.load(open(model_file_name, "r")) - #program_desc_str = model["program_desc_str"] - #feed_var_names = model["feed_var_names"] - #fetch_var_names = model["fetch_var_names"] - #program = Program.parse_from_string(program_desc_str) load_persistables_if_exist(executor, dirname, program) fetch_vars = [program.global_block().var(name) for name in fetch_var_names]