Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 64 additions & 61 deletions paddle/fluid/operators/ngraph/ngraph_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,14 @@ static std::map<ngraph::element::Type, framework::proto::VarType::Type>
{ngraph::element::boolean, framework::proto::VarType::BOOL}};

std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;

std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
std::weak_ptr<ngraph::runtime::Backend> NgraphEngine::wp_backend_;

std::mutex NgraphEngine::ng_mutex_;

static std::vector<std::vector<int>> NgraphOpIntervals(
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::feed_vars.clear();
NgraphEngine::fetch_vars.clear();
std::vector<std::vector<int>> intervals;

int size = ops->size();
Expand Down Expand Up @@ -118,11 +114,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(

int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
for (auto& var_name_item : ops->at(index)->Inputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::fetch_vars.emplace_back(var_name);
}
}
++index;
}

Expand Down Expand Up @@ -167,16 +158,22 @@ static void SubstituteNgraphOp(
framework::OpRegistry::CreateOp(ng_op_desc));
}

std::string SerializedBlock(const std::vector<framework::OpDesc*>& op_descs) {
std::string SerializedBlock(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);

for (auto* op_desc : op_descs) {
for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}

auto* vars = block_desc.Proto()->mutable_vars();
for (auto& var_desc : bdesc.AllVars()) {
*vars->Add() = *var_desc->Proto();
}

return block_desc.Proto()->SerializeAsString();
}

Expand Down Expand Up @@ -213,12 +210,12 @@ std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::p_bdesc = &block_desc;
auto intervals = NgraphOpIntervals(ops);
std::string serialized_block = SerializedBlock(block_desc);
std::string engine_key =
GenerateEngineKey(feed_vars, fetch_vars, ops->size());
std::to_string(std::hash<std::string>()(serialized_block));
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(ops, engine_key, "", *it);
SubstituteNgraphOp(ops, engine_key, serialized_block, *it);
}
}

Expand All @@ -232,31 +229,32 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();

std::lock_guard<std::mutex> lock(ng_mutex_);

if (!wp_backend_.lock()) {
try {
VLOG(3) << "ngraph creating CPU backend.";
backend_ = ngraph::runtime::Backend::create("CPU");
} catch (...) {
PADDLE_THROW("Unsupported nGraph backend");
}
wp_backend_ = backend_;
} else {
backend_ = wp_backend_.lock();
}

GetNgFunction(ctx);
}

void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string serialized_graph = ctx.Attr<std::string>("graph");

auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}

framework::proto::BlockDesc block_proto;
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
framework::BlockDesc block_desc(nullptr, &block_proto);
if (!serialized_graph.empty()) {
NgraphEngine::p_bdesc = &block_desc;
}

for (auto& var : p_bdesc->AllVars()) {
for (auto& var : block_desc.AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
Expand Down Expand Up @@ -284,10 +282,9 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
}

std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) {
for (auto op_desc : block_desc.AllOps()) {
ops_desc.emplace_back(op_desc);
if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false;
}
}
Expand All @@ -298,8 +295,7 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx;
}
while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
while (idx < static_cast<int>(ops_desc.size())) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
Expand All @@ -309,9 +305,21 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
++idx;
}

auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}

auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}

if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval);
}

for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
Expand All @@ -324,6 +332,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const std::vector<int>& interval) {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;

for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i];
for (auto& var_name_item : op->Inputs()) {
Expand Down Expand Up @@ -359,15 +368,11 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
op->Type());
for (auto& var_name : var_name_item.second) {
if (this->is_test_) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end()) {
this->var_out_.emplace_back(var_name);
}
} else {
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
Expand Down Expand Up @@ -434,10 +439,14 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
ngraph::ParameterVector func_inputs;

for (auto& vo : var_out_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vo), 0,
"Cannot find vo %s in var_node_map_", vo);
func_outputs.emplace_back(var_node_map_->at(vo));
}

for (auto& vi : var_in_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vi), 0,
"Cannot find vi %s in var_node_map_", vi);
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
Expand All @@ -454,7 +463,8 @@ void NgraphEngine::ClearNgCache() {
auto it = engine_cache.begin();
while (it != engine_cache.end()) {
auto ng_engine = it->second;
backend_->remove_compiled_function(ng_engine.ngraph_handle);
ng_engine.ngraph_backend->remove_compiled_function(ng_engine.ngraph_handle);
ng_engine.ngraph_backend.reset();
++it;
}
engine_cache.clear();
Expand Down Expand Up @@ -497,13 +507,6 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
ClearNgCache();
}
pre_var_ptr = var;
}
}

Expand All @@ -515,6 +518,7 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
for (auto& r : func->get_results()) {
r->set_needs_default_layout(true);
}
engine_cache[func_cache_key_].ngraph_backend = backend_;
engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func);
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
Expand All @@ -526,31 +530,32 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {

void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const {
VLOG(3) << "NgraphEngine Run ...";
std::shared_ptr<ngraph::runtime::Executable> ng_handle;
std::shared_ptr<ngraph::runtime::Backend> ng_backend;
const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out;
bool is_test;

auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();

PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
PADDLE_ENFORCE_GT(engine_cache.count(func_cache_key_), 0,
"Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle;
ng_backend = engine_cache[func_cache_key_].ngraph_backend;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;

std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};

auto m_parameters = ng_handle->get_parameters();
auto m_results = ng_handle->get_results();
if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
if (is_inference_ && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i);
Expand All @@ -562,14 +567,14 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
ti = backend_->create_tensor(ng_type, sp, pd_arr);
ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
(*p_t_in)[index] = ti;
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
}
} else {
if (is_test) {
if (is_inference_) {
p_t_in = &(t_in_cache_[func_cache_key_]);
} else {
p_t_in = &t_in;
Expand All @@ -584,15 +589,13 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
ti = backend_->create_tensor(ng_type, sp, pd_arr);
ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (!is_training && is_test && is_persistable) {
if (is_inference_ && is_persistable) {
ti->set_stale(false);
}
(*p_t_in).emplace_back(ti);
Expand All @@ -615,7 +618,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
auto ng_type = m_results[i]->get_element_type();
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
std::shared_ptr<ngraph::runtime::Tensor> to =
backend_->create_tensor(ng_type, sp, pd_arr);
ng_backend->create_tensor(ng_type, sp, pd_arr);
t_out.emplace_back(to);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/operators/ngraph/ngraph_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <list>
#include <memory>
#include <mutex> //NOLINT
#include <set>
#include <string>
#include <unordered_map>
Expand All @@ -34,7 +35,8 @@ namespace operators {

// cache engine repetitives
struct EngineCache {
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle = nullptr;
std::shared_ptr<ngraph::runtime::Backend> ngraph_backend = nullptr;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
Expand Down Expand Up @@ -127,9 +129,7 @@ class NgraphEngine {

void Run(const framework::Scope& scope, const platform::Place& place) const;

static bool is_training;
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static std::vector<std::string> feed_vars;

static void FuseNgraphOps(
const framework::BlockDesc& prog,
Expand All @@ -149,19 +149,24 @@ class NgraphEngine {
using main_t_in_cache =
ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;

static framework::Variable* pre_var_ptr;

const framework::Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_;
// it is test for a single run, it can be a validation during training
bool is_test_{true};
// inference only. eg. CAPI inference
bool is_inference_{false};
std::string func_cache_key_;

// use a weak pointer to keep backend_ alive
// to avoid it to be destropyed too earlier
static std::weak_ptr<ngraph::runtime::Backend> wp_backend_;
// use mutex to keep it thread safe
static std::mutex ng_mutex_;
// ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_;
std::shared_ptr<ngraph::runtime::Backend> backend_;
// var_name of inputs
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
Expand Down