-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add a simple C++ inference example for fluid #7097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
9b3f2c3
cd7d0f8
42a0603
5b3cf4e
c7bd777
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| set(FLUID_CORE_MODULES | ||
| backward proto_desc paddle_memory executor prune init ${GLOB_OP_LIB}) | ||
|
|
||
| cc_library(paddle_fluid_api | ||
| SRCS inference.cc | ||
| DEPS ${FLUID_CORE_MODULES}) | ||
|
|
||
| # Merge all modules into a simgle static library | ||
| cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES}) | ||
|
|
||
| # ptools | ||
| # just for testing, we may need to change the storing format for inference_model | ||
| # and move the dependent of pickle. | ||
| # download from http://www.picklingtools.com/ | ||
| # build in the C++ sub-directory, using command | ||
| # make -f Makefile.Linux libptools.so | ||
| set(PTOOLS_LIB) | ||
| set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools") | ||
| find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++) | ||
| find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++) | ||
| if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB) | ||
| add_definitions(-DPADDLE_USE_PTOOLS) | ||
| set(PTOOLS_LIB ptools) | ||
| message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}") | ||
| add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL) | ||
| set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB}) | ||
| include_directories(${PTOOLS_ROOT}/C++) | ||
| include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include) | ||
| add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools | ||
| endif() | ||
|
|
||
| add_executable(example example.cc) | ||
| target_link_libraries(example | ||
| -Wl,--start-group -Wl,--whole-archive paddle_fluid | ||
| -Wl,--no-whole-archive -Wl,--end-group | ||
| ${PTOOLS_LIB}) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <time.h> | ||
| #include <iostream> | ||
| #include "paddle/inference/inference.h" | ||
|
|
||
| int main(int argc, char* argv[]) { | ||
| std::string dirname = | ||
| "/home/work/liuyiqun/PaddlePaddle/Paddle/paddle/inference/" | ||
| "recognize_digits_mlp.inference.model"; | ||
| std::vector<std::string> feed_var_names = {"x"}; | ||
| std::vector<std::string> fetch_var_names = {"fc_2.tmp_2"}; | ||
| paddle::InferenceEngine* desc = new paddle::InferenceEngine(); | ||
| desc->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); | ||
|
|
||
| paddle::framework::LoDTensor input; | ||
| srand(time(0)); | ||
| float* input_ptr = | ||
| input.mutable_data<float>({1, 784}, paddle::platform::CPUPlace()); | ||
| for (int i = 0; i < 784; ++i) { | ||
| input_ptr[i] = rand() / (static_cast<float>(RAND_MAX)); | ||
| } | ||
|
|
||
| std::vector<paddle::framework::LoDTensor> feeds; | ||
| feeds.push_back(input); | ||
| std::vector<paddle::framework::LoDTensor> fetchs; | ||
| desc->Execute(feeds, fetchs); | ||
|
|
||
| for (size_t i = 0; i < fetchs.size(); ++i) { | ||
| auto dims_i = fetchs[i].dims(); | ||
| std::cout << "dims_i:"; | ||
| for (int j = 0; j < dims_i.size(); ++j) { | ||
| std::cout << " " << dims_i[j]; | ||
| } | ||
| std::cout << std::endl; | ||
| std::cout << "result:"; | ||
| float* output_ptr = fetchs[i].data<float>(); | ||
| for (int j = 0; j < paddle::framework::product(dims_i); ++j) { | ||
| std::cout << " " << output_ptr[j]; | ||
| } | ||
| std::cout << std::endl; | ||
| } | ||
| return 0; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "inference.h" | ||
| #include <fstream> | ||
| #include "paddle/framework/executor.h" | ||
| #include "paddle/framework/feed_fetch_method.h" | ||
| #include "paddle/framework/init.h" | ||
| #include "paddle/framework/scope.h" | ||
|
|
||
| #ifdef PADDLE_USE_PTOOLS | ||
| #include "chooseser.h" | ||
| #endif | ||
|
|
||
| namespace paddle { | ||
|
|
||
| void InferenceEngine::LoadInferenceModel( | ||
| const std::string& dirname, | ||
| const std::vector<std::string>& feed_var_names, | ||
| const std::vector<std::string>& fetch_var_names) { | ||
| #ifdef PADDLE_USE_PTOOLS | ||
| std::string model_filename = dirname + "/__model__"; | ||
| LOG(INFO) << "Using PicklingTools, loading model from " << model_filename; | ||
|
||
| Val v; | ||
| LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0); | ||
| std::string program_desc_str = v["program_desc_str"]; | ||
| LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); | ||
| // PicklingTools cannot parse the vector of strings correctly. | ||
| #else | ||
| // program_desc_str | ||
| // the inference.model is stored by following python codes: | ||
| // inference_program = fluid.io.get_inference_program(predict) | ||
| // model_filename = "recognize_digits_mlp.inference.model/inference.model" | ||
| // with open(model_filename, "w") as f: | ||
| // program_str = inference_program.desc.serialize_to_string() | ||
| // f.write(struct.pack('q', len(program_str))) | ||
| // f.write(program_str) | ||
| std::string model_filename = dirname + "/inference.model"; | ||
| LOG(INFO) << "loading model from " << model_filename; | ||
| std::ifstream fs(model_filename, std::ios_base::binary); | ||
| int64_t size = 0; | ||
| fs.read(reinterpret_cast<char*>(&size), sizeof(int64_t)); | ||
| LOG(INFO) << "program_desc_str's size: " << size; | ||
| std::string program_desc_str; | ||
| program_desc_str.resize(size); | ||
| fs.read(&program_desc_str[0], size); | ||
| #endif | ||
| program_ = new framework::ProgramDesc(program_desc_str); | ||
| GenerateLoadProgram(dirname); | ||
|
|
||
| if (feed_var_names.empty() || fetch_var_names.empty()) { | ||
| LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names."; | ||
| } | ||
| feed_var_names_ = feed_var_names; | ||
| fetch_var_names_ = fetch_var_names; | ||
| PrependFeedOp(); | ||
| AppendFetchOp(); | ||
| } | ||
|
|
||
| bool InferenceEngine::IsParameter(const framework::VarDesc* var) { | ||
| if (var->Persistable()) { | ||
|
||
| // There are many unreachable variables in the program | ||
| for (size_t i = 0; i < program_->Size(); ++i) { | ||
| const framework::BlockDesc& block = program_->Block(i); | ||
| for (auto* op : block.AllOps()) { | ||
| for (auto input_argument_name : op->InputArgumentNames()) { | ||
| if (input_argument_name == var->Name()) { | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| void InferenceEngine::GenerateLoadProgram(const std::string& dirname) { | ||
| framework::BlockDesc* global_block = program_->MutableBlock(0); | ||
|
|
||
| load_program_ = new framework::ProgramDesc(); | ||
| framework::BlockDesc* load_block = load_program_->MutableBlock(0); | ||
| for (auto* var : global_block->AllVars()) { | ||
| if (IsParameter(var)) { | ||
| LOG(INFO) << "parameter's name: " << var->Name(); | ||
|
|
||
| // framework::VarDesc new_var = *var; | ||
|
||
| framework::VarDesc* new_var = load_block->Var(var->Name()); | ||
| new_var->SetShape(var->Shape()); | ||
| new_var->SetDataType(var->GetDataType()); | ||
| new_var->SetType(var->GetType()); | ||
| new_var->SetLoDLevel(var->GetLoDLevel()); | ||
| new_var->SetPersistable(true); | ||
|
|
||
| // append_op | ||
| framework::OpDesc* op = load_block->AppendOp(); | ||
| op->SetType("load"); | ||
| op->SetOutput("Out", {new_var->Name()}); | ||
| op->SetAttr("file_path", {dirname + "/" + new_var->Name()}); | ||
| op->CheckAttrs(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void InferenceEngine::PrependFeedOp() { | ||
| if (!program_) { | ||
| LOG(FATAL) << "Please initialize the program_ first."; | ||
| } | ||
|
|
||
| framework::BlockDesc* global_block = program_->MutableBlock(0); | ||
|
|
||
| // create_var | ||
| framework::VarDesc* feed_var = global_block->Var("feed"); | ||
| feed_var->SetType(framework::proto::VarDesc::FEED_MINIBATCH); | ||
| feed_var->SetPersistable(true); | ||
|
|
||
| // prepend feed_op | ||
| for (size_t i = 0; i < feed_var_names_.size(); ++i) { | ||
| std::string var_name = feed_var_names_[i]; | ||
| LOG(INFO) << "feed var's name: " << var_name; | ||
|
|
||
| // prepend_op | ||
| framework::OpDesc* op = global_block->PrependOp(); | ||
| op->SetType("feed"); | ||
| op->SetInput("X", {"feed"}); | ||
| op->SetOutput("Out", {var_name}); | ||
| op->SetAttr("col", {static_cast<int>(i)}); | ||
| op->CheckAttrs(); | ||
| } | ||
| } | ||
|
|
||
| void InferenceEngine::AppendFetchOp() { | ||
| if (!program_) { | ||
| LOG(FATAL) << "Please initialize the program_ first."; | ||
| } | ||
|
|
||
| framework::BlockDesc* global_block = program_->MutableBlock(0); | ||
|
|
||
| // create_var | ||
| framework::VarDesc* fetch_var = global_block->Var("fetch"); | ||
| fetch_var->SetType(framework::proto::VarDesc::FETCH_LIST); | ||
| fetch_var->SetPersistable(true); | ||
|
|
||
| // append fetch_op | ||
| for (size_t i = 0; i < fetch_var_names_.size(); ++i) { | ||
| std::string var_name = fetch_var_names_[i]; | ||
| LOG(INFO) << "fetch var's name: " << var_name; | ||
|
|
||
| // append_op | ||
| framework::OpDesc* op = global_block->AppendOp(); | ||
| op->SetType("fetch"); | ||
| op->SetInput("X", {var_name}); | ||
| op->SetOutput("Out", {"fetch"}); | ||
| op->SetAttr("col", {static_cast<int>(i)}); | ||
| op->CheckAttrs(); | ||
| } | ||
| } | ||
|
|
||
| void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds, | ||
| std::vector<framework::LoDTensor>& fetchs) { | ||
| if (!program_ || !load_program_) { | ||
| LOG(FATAL) << "Please initialize the program_ and load_program_ first."; | ||
| } | ||
|
|
||
| if (feeds.size() < feed_var_names_.size()) { | ||
| LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors."; | ||
| } | ||
|
|
||
| auto* place = new platform::CPUPlace(); | ||
| framework::InitDevices({"CPU"}); | ||
| framework::Executor* executor = new framework::Executor(*place); | ||
| framework::Scope* scope = new framework::Scope(); | ||
|
|
||
| executor->Run(*load_program_, scope, 0, true, true); | ||
|
|
||
| // set_feed_variable | ||
| for (size_t i = 0; i < feed_var_names_.size(); ++i) { | ||
| framework::SetFeedVariable(scope, feeds[i], "feed", i); | ||
| } | ||
|
|
||
| executor->Run(*program_, scope, 0, true, true); | ||
|
|
||
| // get_fetch_variable | ||
| fetchs.resize(fetch_var_names_.size()); | ||
| for (size_t i = 0; i < fetch_var_names_.size(); ++i) { | ||
| fetchs[i] = framework::GetFetchVariable(*scope, "fetch", i); | ||
| } | ||
|
|
||
| delete place; | ||
| delete scope; | ||
| delete executor; | ||
| } | ||
| } // namespace paddle | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use absolute/personal path
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe you can upload this model, and download it by verifying MD5.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
I think the current version of model is not stable and is not suitable to upload for testing.