Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/all_reduce_deps_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ namespace paddle {
namespace framework {
namespace details {

static constexpr char kAllOpDescs[] = "all_op_descs";

VarHandle* GetValidInput(const OpHandleBase* a) {
for (auto p : a->Inputs()) {
VarHandle* b = dynamic_cast<VarHandle*>(p);
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ namespace details {
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang.
// NOTE: ParallelExecutor would execute this pass on each graph, so
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParallelGraphExecutor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

// don't need to append it here.
return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) ||
strategy.enable_parallel_graph_;
strategy.num_trainers_ > 1) &&
!strategy.enable_parallel_graph_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parallel_graph do not need to add deps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PG needs to execute this pass on each graph.

}

class ParallelExecutorPassBuilder : public ir::PassBuilder {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/memory_optimize_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ namespace paddle {
namespace framework {
namespace details {

constexpr char kAllOpDescs[] = "all_op_descs";

std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);

// NOTE(dzh): A ordered set for node reuse in memory optimize.
Expand Down
36 changes: 22 additions & 14 deletions paddle/fluid/framework/details/multi_devices_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ namespace framework {
namespace details {

namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";

bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
return boost::get<int>(
Expand Down Expand Up @@ -226,7 +221,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}

Expand Down Expand Up @@ -392,19 +386,33 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,

void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
OpHandleBase *op_handle = nullptr;

auto append_allreduce_op = [&](
std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
return result->Get<GraphOps>(kGraphOps).back();
};

if (!strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(local_scopes_, places_);

for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto p = places_[i];
std::vector<Scope *> ss{local_scopes_[i]};
std::vector<platform::Place> ps{p};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this in if?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);

SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/framework/details/multi_devices_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,20 @@ namespace details {
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars;
const char kGraphVars[] = "vars";

// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase*> GraphDepVars;
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";

// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to expose these? you can get graph ops:

for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


} // namespace details
} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class OpHandleBase {
auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr;
}
const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
return dev_ctxes_;
}

void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
dev_ctxes_[place] = ctx_;
Expand Down
78 changes: 75 additions & 3 deletions paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,94 @@
// limitations under the License.

#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

std::vector<std::unique_ptr<ir::Graph>>
ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
std::unique_ptr<ir::Graph> &&graph) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places_.size());
for (size_t i = 0; i < places_.size(); ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps);
}

for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
dev_ops.emplace_back(op);
graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());

for (auto &var : op->Inputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
}
}
for (auto &var : op->Outputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
}
}
}

for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
dev_vars.emplace(name_pair.first, name_pair.second);
for (auto &version_pair : name_pair.second) {
if (graph->Nodes().count(version_pair->Node())) {
graphs[dev_id]->AddNode(
graph->RemoveNode(version_pair->Node()).release());
}
}
}
}

return graphs;
}

ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs)
const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
main_prog_(main_prog),
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());

auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
}

// set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
Expand All @@ -37,7 +109,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
}
}

Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/details/parallel_ssa_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

#pragma once

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "ThreadPool.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
Expand All @@ -29,17 +33,23 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs);
const framework::ProgramDesc &main_prog,
std::unique_ptr<ir::Graph> &&graph);
~ParallelSSAGraphExecutor() final = default;

const ir::Graph &Graph() const override { return *graphs_[0]; }

FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

private:
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
std::unique_ptr<ir::Graph> &&graph);

ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
framework::ProgramDesc main_prog_;
std::vector<std::unique_ptr<ir::Graph>> graphs_;

std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted";
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
Expand Down
26 changes: 17 additions & 9 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ limitations under the License. */

namespace paddle {
namespace framework {

namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc here that this is not recommended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

} // namespace details

namespace ir {

/*
Expand Down Expand Up @@ -168,10 +173,13 @@ class Graph {
return ret;
}

void RemoveNode(ir::Node *node) {
std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
std::unique_ptr<ir::Node> ret;
ret.reset(nodes_.at(node).release());
nodes_.erase(node);
node_set_.erase(node);
return ret;
}

// NOTE low performance, but simple and secure.
Expand All @@ -184,13 +192,6 @@ class Graph {
return nullptr;
}

void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);

private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);

// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
Expand All @@ -199,6 +200,13 @@ class Graph {
return node;
}

void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);

private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);

// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_;
std::map<std::string, boost::any> attrs_;
Expand Down
Loading