Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,18 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)

if(WITH_NGRAPH)
set(NGRAPH_BS_DEPS ngraph)
else()
set(NGRAPH_BS_DEPS)
endif()

cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass)
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
${NGRAPH_BS_DEPS})
22 changes: 22 additions & 0 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"

DECLARE_bool(use_mkldnn);
DECLARE_bool(use_ngraph);

namespace paddle {
namespace framework {
Expand All @@ -53,6 +54,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");

AppendPassToUseNgraph("ngraph_subgraph_pass");

AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");

Expand Down Expand Up @@ -220,6 +223,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
#endif
}

void AppendPassToUseNgraph(const std::string &pass_name) {
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kAllReduce) {
LOG(WARNING) << "Currently ngraph_subgraph_pass works under AllReduce,"
"please set FLAGS_use_ngraph=false.";
} else {
AppendPass(pass_name);
}
}
#else
PADDLE_ENFORCE_NE(FLAGS_use_ngraph, true,
"Please compile with NGRAPH first to use NGRAPH");
#endif
}

private:
BuildStrategy strategy_;
};
Expand Down Expand Up @@ -360,3 +379,6 @@ USE_PASS(runtime_context_cache_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
#ifdef PADDLE_WITH_NGRAPH
USE_PASS(ngraph_subgraph_pass);
#endif
2 changes: 1 addition & 1 deletion paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ limitations under the License. */

#ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif

DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");

namespace paddle {
namespace framework {
Expand Down
116 changes: 72 additions & 44 deletions paddle/fluid/framework/ir/ngraph_subgraph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
return engine_key;
}

void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(graph);
void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init("ngraph_subgraph_pass", graph);

std::unordered_set<Node *> nodes2delete;
Expand All @@ -66,15 +66,13 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) {
OpDesc *op_desc = node->Op();
op_desc->SetType("ngraph_engine");
for (auto it = ANAT::Agent(node).subgraph()->begin();
it != ANAT::Agent(node).subgraph()->end(); ++it) {
}

CreateNgraphEngineOp(node, graph);

std::unordered_set<const Node *> nodes2remove(
ANAT::Agent(node).subgraph()->begin(),
ANAT::Agent(node).subgraph()->end());

GraphSafeRemoveNodes(graph, nodes2remove);
}
}
Expand All @@ -85,70 +83,100 @@ void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
nodes2remove.insert(node);
}
}

framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph);
// std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you can remove in next PR

}

void NgraphSubgraphPass::CreateNgraphEngineOp(framework::ir::Node *node,
Graph *graph) const {
auto *op_desc = node->Op();
bool IsValid(std::string name) {
return name.find(Node::kControlDepVarName) == std::string::npos;
}

void UpdateNgraphIO(Node *node, Graph *graph,
std::vector<std::string> *input_names,
std::vector<std::string> *output_names) {
bool is_test = true, has_fetch = false;
for (Node *node : graph->Nodes()) {
if (node->IsOp() && node->Name().find("_grad") != std::string::npos) {
is_test = false;
}
if (node->IsVar() && node->Var()) {
for (auto out : node->outputs) {
if (out->Name() == "fetch") has_fetch = true;
}
}
}
if (is_test && has_fetch) {
for (auto *x : node->inputs) {
(*input_names).emplace_back(x->Name());
}
for (auto *x : node->outputs) {
(*output_names).emplace_back(x->Name());
}
return;
}

auto &subgraph = *ANAT::Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto *node : subgraph) {
for (auto in : node->inputs) {
auto name = in->Name();
if (!IsValid(name)) continue;
if (!outputs.count(name) && !inputs.count(name)) {
(*input_names).emplace_back(name);
inputs.insert(name);
}
}
for (auto out : node->outputs) {
auto name = out->Name();
if (!IsValid(name)) continue;
outputs.insert(name);
(*output_names).emplace_back(name);
}
}
}

framework::ProgramDesc *program_desc =
Get<framework::ProgramDesc *>("program");
const framework::BlockDesc &main_block =
program_desc->Block(framework::kRootBlockIndex);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const {
auto &subgraph = *ANAT::Agent(node).subgraph();
PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty");

framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto *node : subgraph) {
auto *new_block_op = new_block->AppendOp();
auto *op = block_desc.AppendOp();
*new_block_op->Proto() = *node->Op()->Proto();
*op->Proto() = *node->Op()->Proto();
}

std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));

std::set<std::string> output_names;
std::set<std::string> output_names_with_id;

for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
auto *vars = block_desc.Proto()->mutable_vars();
for (framework::ir::Node *node : graph->Nodes()) {
for (Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE_NE(block_desc.Proto()->vars().empty(), true,
"the block has no var-desc");

PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc");

op_desc->SetType("ngraph_engine");
std::vector<std::string> input_names;
std::vector<std::string> output_names;
UpdateNgraphIO(node, graph, &input_names, &output_names);
auto *op_desc = node->Op();
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));

int sgs = subgraph.size();
std::string engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(sgs));
std::string subgraph_str = block_desc.Proto()->SerializeAsString();
std::string engine_key =
std::to_string(std::hash<std::string>()(subgraph_str));
std::vector<int> interval{0, sgs};
op_desc->SetType("ngraph_engine");
op_desc->SetAttr("interval", interval);
op_desc->SetAttr("graph", block_desc.Proto()->SerializeAsString());
op_desc->SetAttr("graph", subgraph_str);
op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("op_role", 0);
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"

DECLARE_bool(use_ngraph);

namespace paddle {
namespace inference {
namespace analysis {
Expand Down Expand Up @@ -398,6 +400,11 @@ void RemoveIntermediateOutputInSubgraph(const std::vector<Node *> &subgraph,
}
}

// In use for ngraph subgraph pass for parallel executor,
// this will remove all nodes, bypass this and let ngraph
// subgraph pass to process outputs
if (FLAGS_use_ngraph && valid_output.size() == 0) return;

outputs->assign(valid_output.begin(), valid_output.end());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid.tests.unittests.simple_nets import simple_fc_net
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler
import numpy as np
import unittest
import os
import sys
import math


class TestPallelExecutorNgraph(unittest.TestCase):
def check_network_convergence(self, build_strategy=None):
os.environ['CPU_NUM'] = str(2)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = simple_fc_net()
test_program = main.clone(for_test=True)

opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)

batch_size = 32
image = np.random.normal(size=(batch_size, 784)).astype('float32')
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
feed_dict = {'image': image, 'label': label}

train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(test_program).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
share_vars_from=train_cp)

for i in range(5):
_ = exe.run(train_cp, fetch_list=[loss.name], feed=feed_dict)
test_loss, = exe.run(test_cp,
fetch_list=[loss.name],
feed=feed_dict)
train_loss = exe.run(train_cp,
fetch_list=[loss.name],
feed=feed_dict)

avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)):
sys.exit("got NaN loss, testing failed.")

avg_train_loss_val = np.array(train_loss).mean()
if math.isnan(float(avg_train_loss_val)):
sys.exit("got NaN loss, training failed.")

self.assertTrue(
np.allclose(
train_loss, test_loss, atol=1e-8),
"Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss))

def test_parallel_testing(self):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
self.check_network_convergence(build_strategy=build_strategy)


if __name__ == '__main__':
unittest.main()