Merged
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/all_reduce_deps_pass.cc
@@ -30,8 +30,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-static constexpr char kAllOpDescs[] = "all_op_descs";
-
 VarHandle* GetValidInput(const OpHandleBase* a) {
   for (auto p : a->Inputs()) {
     VarHandle* b = dynamic_cast<VarHandle*>(p);
6 changes: 4 additions & 2 deletions paddle/fluid/framework/details/build_strategy.cc
@@ -34,9 +34,11 @@ namespace details {
 static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
   // Should fix the allreduce op order if scheduling
   // them in multiple threads or processes to avoid hang.
+  // NOTE: ParallelGraph would execute this pass on each graph, so we
+  // don't need to append it here.
   return (!strategy.enable_sequential_execution_ &&
-          strategy.num_trainers_ > 1) ||
-         strategy.enable_parallel_graph_;
+          strategy.num_trainers_ > 1) &&
+         !strategy.enable_parallel_graph_;
Review comment (Contributor): parallel_graph does not need to add the deps?

Reply (Contributor, Author): PG needs to execute this pass on each graph.
 }
 
 class ParallelExecutorPassBuilder : public ir::PassBuilder {
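For illustration, here is a minimal standalone sketch of the corrected predicate; the `Strategy` struct and the `main` harness are hypothetical stand-ins for `BuildStrategy`, and only the boolean logic mirrors the patch:

```cpp
#include <cassert>

// Hypothetical stand-in for BuildStrategy; field names mirror the diff.
struct Strategy {
  bool enable_sequential_execution_ = false;
  int num_trainers_ = 1;
  bool enable_parallel_graph_ = false;
};

// After the patch: append the sequential all-reduce pass only for
// multi-trainer runs that are neither already sequential nor running in
// ParallelGraph mode (which applies the pass per sub-graph by itself).
bool SeqOnlyAllReduceOps(const Strategy &s) {
  return (!s.enable_sequential_execution_ && s.num_trainers_ > 1) &&
         !s.enable_parallel_graph_;
}

int main() {
  Strategy s;
  s.num_trainers_ = 2;
  assert(SeqOnlyAllReduceOps(s));   // multi-trainer, default executor
  s.enable_parallel_graph_ = true;
  assert(!SeqOnlyAllReduceOps(s));  // ParallelGraph fixes the order itself
  return 0;
}
```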
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/memory_optimize_helper.h
@@ -29,8 +29,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-constexpr char kAllOpDescs[] = "all_op_descs";
-
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
 
 // NOTE(dzh): An ordered set for node reuse in memory optimize.
32 changes: 22 additions & 10 deletions paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -392,28 +392,40 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
 
 void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
     ir::Graph *result, const std::string &og) const {
+  OpHandleBase *op_handle = nullptr;
+
+  auto append_allreduce_op = [&](
+      const std::vector<Scope *> &scopes,
+      const std::vector<platform::Place> &places) -> OpHandleBase * {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_, nccl_ctxs_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places, nccl_ctxs_));
 #else
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
+    return result->Get<GraphOps>(kGraphOps).back();
+  };
+
+  if (!strategy_.enable_parallel_graph_)
+    op_handle = append_allreduce_op(local_scopes_, places_);
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto &p = places_[i];
-    SetCommunicationContext(op_handle, p);
+    if (strategy_.enable_parallel_graph_) {
+      op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]});
+    }
+
+    SetCommunicationContext(op_handle, places_[i]);
     auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad);
 
     auto var =
         new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
-                      vars.size(), i, og, p);
+                      vars.size(), i, og, places_[i]);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
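To see what this branching produces, a hedged, self-contained model follows; all types are simplified stand-ins, not Paddle's real handles. The default mode creates one all-reduce handle spanning every place, while `enable_parallel_graph_` creates one single-place handle per device so the graph can later be split:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-ins; Paddle's real Scope/Place/AllReduceOpHandle differ.
struct Scope {};
using Place = int;  // pretend a place is just a device id

struct AllReduceHandle {
  std::vector<Scope *> scopes;
  std::vector<Place> places;
};

// Mirrors the patched CreateAllReduceOp: one fused handle over all places in
// the default mode, one per-place handle when parallel-graph mode is on.
std::vector<std::unique_ptr<AllReduceHandle>> CreateAllReduce(
    bool parallel_graph, const std::vector<Scope *> &scopes,
    const std::vector<Place> &places) {
  std::vector<std::unique_ptr<AllReduceHandle>> handles;
  if (!parallel_graph) {
    handles.emplace_back(new AllReduceHandle{scopes, places});
  } else {
    for (size_t i = 0; i < places.size(); ++i) {
      handles.emplace_back(new AllReduceHandle{{scopes[i]}, {places[i]}});
    }
  }
  return handles;
}

int main() {
  Scope s0, s1;
  std::cout << CreateAllReduce(false, {&s0, &s1}, {0, 1}).size()  // 1 handle
            << " vs "
            << CreateAllReduce(true, {&s0, &s1}, {0, 1}).size()   // 2 handles
            << "\n";
  return 0;
}
```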
5 changes: 3 additions & 2 deletions paddle/fluid/framework/details/multi_devices_helper.h
@@ -36,13 +36,14 @@ namespace details {
 // map from variable name to variables. The variables, who have the same name,
 // will have a different version. The offset in the
 // `std::vector<VarHandle*>` is the version of variables.
-typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
+typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
     GraphVars;
 const char kGraphVars[] = "vars";
 
 // aux variables to represent dependency. Useful to resolve data hazard.
-typedef std::unordered_set<VarHandleBase*> GraphDepVars;
+typedef std::unordered_set<VarHandleBase *> GraphDepVars;
 const char kGraphDepVars[] = "dep_vars";
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
@@ -70,6 +70,9 @@ class OpHandleBase {
     auto it = dev_ctxes_.find(place);
     return it != dev_ctxes_.end() ? it->second : nullptr;
   }
+  const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
+    return dev_ctxes_;
+  }
 
   void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
     dev_ctxes_[place] = ctx_;
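A hedged sketch of why exposing this map is useful, using stand-in types (`DevCtx` is not Paddle's `platform::DeviceContext`): a single-device op handle has exactly one entry in `dev_ctxes_`, so a caller such as `SeparateMultiDevicesGraph` below can recover the op's place from it.

```cpp
#include <cassert>
#include <map>

// Stand-in types, not Paddle's real ones.
struct Place {
  int device;
  bool operator<(const Place &o) const { return device < o.device; }
};
struct DevCtx {};

struct OpHandle {
  std::map<Place, DevCtx *> dev_ctxes_;
  // Mirrors the new accessor: expose the place -> context map read-only.
  const std::map<Place, DevCtx *> &DeviceContext() { return dev_ctxes_; }
};

int main() {
  DevCtx ctx;
  OpHandle op;
  op.dev_ctxes_[Place{1}] = &ctx;
  // A single-device op has one entry; its key tells us the device id.
  int dev_id = op.DeviceContext().begin()->first.device;
  assert(dev_id == 1);
  return 0;
}
```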
76 changes: 73 additions & 3 deletions paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -13,22 +13,92 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
+std::vector<std::unique_ptr<ir::Graph>>
+ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
+    std::unique_ptr<ir::Graph> &&graph) {
+  std::vector<std::unique_ptr<ir::Graph>> graphs;
+  graphs.reserve(places_.size());
+  for (size_t i = 0; i < places_.size(); ++i) {
+    ProgramDesc empty;
+    graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
+    auto &g = graphs.back();
+    g->Set(kGraphVars, new GraphVars(1UL));
+    g->Set(kGraphDepVars, new GraphDepVars);
+  }
+  auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
+
+  for (auto &op : op_handles) {
+    auto &dev_ctx = op->DeviceContext();
+    auto &p = dev_ctx.begin()->first;
+    int dev_id = boost::get<platform::CUDAPlace>(p).device;
+    auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
+    graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());
+
+    for (auto &var : op->Inputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
+      }
+    }
+    for (auto &var : op->Outputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
+      }
+    }
+  }
+
+  for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
+    auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
+    auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
+    for (auto &name_pair : origin_vars) {
+      dev_vars.emplace(name_pair.first, name_pair.second);
+      for (auto &version_pair : name_pair.second) {
+        if (graph->Nodes().count(version_pair->Node())) {
+          graphs[dev_id]->AddNode(
+              graph->RemoveNode(version_pair->Node()).release());
+        }
+      }
+    }
+  }
+
+  return graphs;
+}
+
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
+    const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graphs_(std::move(graphs)) {
+      main_prog_(main_prog),
+      // TODO(Yancey1989): Copying graphs is not safe since it deletes the
+      // attrs.
+      graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 
+  auto seq_allreduce_pass =
+      ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+  seq_allreduce_pass->Erase(details::kAllOpDescs);
+  seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
+      details::kAllOpDescs,
+      new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+  }
+
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                                ? 1UL
@@ -37,7 +107,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
                << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
+        strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
   }
 }
 
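The constructor's new flow, as a hedged standalone model (stand-in types; the real code splits an `ir::Graph` and applies `all_reduce_deps_pass`): split one graph into per-device sub-graphs, then run the sequencing pass on each so every device issues its all-reduces in the same order.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for ir::Graph; only what the sketch needs.
struct Graph {
  int device_id;
  bool deps_sequenced = false;
};

// Stand-in for SeparateMultiDevicesGraph: one sub-graph per place.
std::vector<std::unique_ptr<Graph>> Separate(size_t num_places) {
  std::vector<std::unique_ptr<Graph>> graphs;
  for (size_t i = 0; i < num_places; ++i) {
    graphs.emplace_back(new Graph{static_cast<int>(i)});
  }
  return graphs;
}

// Stand-in for applying all_reduce_deps_pass: fixing the all-reduce order on
// every sub-graph prevents the per-device threads from deadlocking on
// collectives issued in different orders.
std::unique_ptr<Graph> ApplySeqAllReducePass(std::unique_ptr<Graph> g) {
  g->deps_sequenced = true;
  return g;
}

int main() {
  auto graphs = Separate(2);
  for (auto &g : graphs) {
    g = ApplySeqAllReducePass(std::move(g));
  }
  for (auto &g : graphs) {
    assert(g->deps_sequenced);
  }
  return 0;
}
```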
10 changes: 9 additions & 1 deletion paddle/fluid/framework/details/parallel_ssa_graph_executor.h
@@ -18,7 +18,9 @@
 #include <vector>
 
 #include "ThreadPool.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
+#include "paddle/fluid/framework/ir/graph.h"
 
 namespace paddle {
 namespace framework {
@@ -29,17 +31,23 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::vector<std::unique_ptr<ir::Graph>> &&graphs);
+                           const framework::ProgramDesc &main_prog,
+                           std::unique_ptr<ir::Graph> &&graph);
   ~ParallelSSAGraphExecutor() final = default;
 
   const ir::Graph &Graph() const override { return *graphs_[0]; }
 
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
  private:
+  std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+      std::unique_ptr<ir::Graph> &&graph);
+
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
+  framework::ProgramDesc main_prog_;
   std::vector<std::unique_ptr<ir::Graph>> graphs_;
 
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
       VLOG(10) << op << " " << op->Name() << " Done ";
       running_ops_--;
       ready_var_q->Extend(op->Outputs());
-      VLOG(10) << op << " " << op->Name() << "Signal posted";
+      VLOG(10) << op << " " << op->Name() << " Signal posted";
     } catch (...) {
       exception_holder_.Catch(std::current_exception());
     }
29 changes: 20 additions & 9 deletions paddle/fluid/framework/ir/graph.h
@@ -26,6 +26,14 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+
+namespace details {
+
+// This attr is not recommended, because the graph should not depend on
+// the program once it is built.
+constexpr char kAllOpDescs[] = "all_op_descs";
Review comment (Contributor): Document here that this is not recommended?

Reply (Contributor, Author): done.
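As an illustration of the pattern this attr enables (hedged and standalone; `boost::any` is replaced by a raw-pointer map, so this is not Paddle's real `Pass` API): a pass can carry the program's op descs under the `kAllOpDescs` key, with `Erase` before `Set` to avoid leaking a previous value, the same call sequence the new `ParallelSSAGraphExecutor` constructor uses.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

struct OpDesc {
  std::string type;
};

constexpr char kAllOpDescs[] = "all_op_descs";

// Hypothetical attribute holder; Paddle's ir::Pass stores boost::any instead.
struct PassAttrs {
  std::map<std::string, std::vector<OpDesc *> *> attrs;

  void Erase(const std::string &name) {
    auto it = attrs.find(name);
    if (it != attrs.end()) {
      delete it->second;  // the holder owns the attribute value
      attrs.erase(it);
    }
  }
  // Takes ownership of `v`; Erase first so a stale value cannot leak.
  void Set(const std::string &name, std::vector<OpDesc *> *v) {
    Erase(name);
    attrs[name] = v;
  }
  ~PassAttrs() {
    while (!attrs.empty()) Erase(attrs.begin()->first);
  }
};

int main() {
  OpDesc a{"allreduce"}, b{"sum"};
  PassAttrs pass;
  pass.Erase(kAllOpDescs);  // no-op when absent, mirrors the constructor
  pass.Set(kAllOpDescs, new std::vector<OpDesc *>{&a, &b});
  assert(pass.attrs.at(kAllOpDescs)->size() == 2);
  return 0;
}
```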
+}  // namespace details
+
 namespace ir {
 
 /*
@@ -168,10 +176,13 @@ class Graph {
     return ret;
   }
 
-  void RemoveNode(ir::Node *node) {
+  std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
-    node_set_.erase(node);
+    std::unique_ptr<ir::Node> ret;
+    ret.reset(nodes_.at(node).release());
     nodes_.erase(node);
+    node_set_.erase(node);
+    return ret;
   }
 
   // NOTE low performance, but simple and secure.
@@ -184,13 +195,6 @@
     return nullptr;
   }
 
-  void ResolveHazard(
-      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
-
- private:
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);
-
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
@@ -199,6 +203,13 @@
     return node;
   }
 
+  void ResolveHazard(
+      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
+
+ private:
+  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
+      const ProgramDesc &program);
+
   // NOTE: program_ shouldn't be exposed to user.
   const ProgramDesc program_;
   std::map<std::string, boost::any> attrs_;
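A hedged, standalone model of the `RemoveNode` change (stand-in `Node` and `Graph`, not the real `ir::` types): returning the erased node's `unique_ptr` lets callers transfer a node between graphs instead of destroying it, which is the `AddNode(graph->RemoveNode(...).release())` pattern used by `SeparateMultiDevicesGraph` above.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <set>

struct Node {
  int id;
};

// Stand-in graph with the same ownership layout as ir::Graph: nodes_ owns
// the nodes, node_set_ is a fast membership index.
struct Graph {
  std::map<Node *, std::unique_ptr<Node>> nodes_;
  std::set<Node *> node_set_;

  Node *AddNode(Node *node) {  // takes ownership of `node`
    assert(node_set_.count(node) == 0);
    nodes_[node].reset(node);
    node_set_.insert(node);
    return node;
  }

  // Mirrors the patch: hand ownership back to the caller instead of deleting.
  std::unique_ptr<Node> RemoveNode(Node *node) {
    assert(node_set_.count(node) == 1);
    std::unique_ptr<Node> ret = std::move(nodes_.at(node));
    nodes_.erase(node);
    node_set_.erase(node);
    return ret;
  }
};

int main() {
  Graph src, dst;
  Node *n = src.AddNode(new Node{7});
  dst.AddNode(src.RemoveNode(n).release());  // moved, not copied or deleted
  assert(src.node_set_.empty());
  assert(dst.node_set_.count(n) == 1);
  return 0;
}
```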