Merged. Changes from 6 commits.
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/all_reduce_deps_pass.cc
@@ -30,8 +30,6 @@ namespace paddle {
namespace framework {
namespace details {

-static constexpr char kAllOpDescs[] = "all_op_descs";
-
VarHandle* GetValidInput(const OpHandleBase* a) {
  for (auto p : a->Inputs()) {
    VarHandle* b = dynamic_cast<VarHandle*>(p);
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/all_reduce_deps_pass.h
@@ -21,6 +21,8 @@ namespace paddle {
namespace framework {
namespace details {

+constexpr char kAllOpDescs[] = "all_op_descs";
Contributor: This is also defined in memory_optimizer_helper.h; unify them?

Contributor Author: Done, moved to graph.h.

// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
 protected:
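For context, kAllOpDescs is the key under which a pass can stash the program's op descriptions on the graph itself, through the same Set/Get attribute API used elsewhere in this diff. A minimal sketch of the pattern (assumed: Graph::Has exists alongside Set/Get, and the helper names here are illustrative, not Paddle APIs):

#include <vector>
#include "paddle/fluid/framework/ir/graph.h"

// Attach the program's op descriptions under kAllOpDescs; Graph::Set takes
// ownership of the heap-allocated attribute.
void AttachAllOpDescs(ir::Graph *graph, const std::vector<OpDesc *> &all_ops) {
  graph->Set(kAllOpDescs, new std::vector<OpDesc *>(all_ops));
}

// Read the attribute back in a later pass.
size_t NumAllOpDescs(const ir::Graph &graph) {
  return graph.Has(kAllOpDescs)
             ? graph.Get<std::vector<OpDesc *>>(kAllOpDescs).size()
             : 0;
}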
6 changes: 3 additions & 3 deletions paddle/fluid/framework/details/build_strategy.cc
@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
  // Should fix the allreduce op order if scheduling
  // them in multiple threads or processes to avoid hangs.
  return (!strategy.enable_sequential_execution_ &&
-         strategy.num_trainers_ > 1) ||
-        strategy.enable_parallel_graph_;
+         strategy.num_trainers_ > 1) &&
+        !strategy.enable_parallel_graph_;
Contributor: Does parallel_graph not need to add these deps?

Contributor Author: PG needs to execute this pass on each graph.

}

class ParallelExecutorPassBuilder : public ir::PassBuilder {
@@ -118,7 +118,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    }

    // Verify that the graph is correct for multi-device executor.
-    AppendPass("multi_devices_check_pass");
+    auto multi_devices_pass = AppendPass("multi_devices_check_pass");

    if (SeqOnlyAllReduceOps(strategy)) {
      AppendPass("all_reduce_deps_pass");
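Since the operator-precedence flip above is easy to misread, here are the old and new predicates side by side (a paraphrase of this hunk; the *Old/*New names are ours):

// Old: parallel-graph mode unconditionally forced a fixed all-reduce order
// on the single merged graph.
static bool SeqOnlyAllReduceOpsOld(const BuildStrategy &s) {
  return (!s.enable_sequential_execution_ && s.num_trainers_ > 1) ||
         s.enable_parallel_graph_;
}

// New: parallel-graph mode opts out here because, per the thread above,
// ParallelGraph (PG) executes all_reduce_deps_pass on each per-device
// graph instead of on the merged one.
static bool SeqOnlyAllReduceOpsNew(const BuildStrategy &s) {
  return !s.enable_sequential_execution_ && s.num_trainers_ > 1 &&
         !s.enable_parallel_graph_;
}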
37 changes: 23 additions & 14 deletions paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -36,11 +36,6 @@ namespace framework {
namespace details {

namespace {
-// TODO(panyx0718): Clean this up as well.
-// all operators. NOTE that even though we use a vector here, the
-// operators are unordered.
-typedef std::vector<OpHandleBase *> GraphOps;
-const char kGraphOps[] = "ops";

bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
  return boost::get<int>(
@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
   * Only variables should be the leaves of graph.
   */
  AddOutputToLeafOps(&result);
-  result.Erase(kGraphOps);
+  // result.Erase(kGraphOps);
Contributor: Why not remove this?

Contributor Author: It makes it easy to iterate over all the ops, and I don't know of a reason to delete this attribute; I think it's useful for doing things with the graph.

  return graph;
}

@@ -392,19 +387,33 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,

void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
    ir::Graph *result, const std::string &og) const {
+  OpHandleBase *op_handle = nullptr;
+
+  auto append_allreduce_op = [&](
+      std::vector<Scope *> &scopes,
+      std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_, nccl_ctxs_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places, nccl_ctxs_));
#else
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places));
#endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
+    return result->Get<GraphOps>(kGraphOps).back();
+  };
+
+  if (!strategy_.enable_parallel_graph_)
+    op_handle = append_allreduce_op(local_scopes_, places_);
+
  for (size_t i = 0; i < places_.size(); ++i) {
-    auto &p = places_[i];
+    auto p = places_[i];
+    std::vector<Scope *> ss{local_scopes_[i]};
+    std::vector<platform::Place> ps{p};
Contributor: Put these in the if?

Contributor Author: Done.

+    if (strategy_.enable_parallel_graph_)
+      op_handle = append_allreduce_op(ss, ps);

    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
    PADDLE_ENFORCE(!vars.empty());
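The net control flow of the rewritten CreateAllReduceOp, as a sketch (member and lambda names come from the diff; node wiring and error checks are elided):

OpHandleBase *op_handle = nullptr;

// Normal mode: a single AllReduceOpHandle spanning every scope and place,
// created once before the loop.
if (!strategy_.enable_parallel_graph_) {
  op_handle = append_allreduce_op(local_scopes_, places_);
}

for (size_t i = 0; i < places_.size(); ++i) {
  // Parallel-graph mode: one AllReduceOpHandle per device, each over a
  // single scope/place pair, so the graph can later be split per device.
  if (strategy_.enable_parallel_graph_) {
    std::vector<Scope *> ss{local_scopes_[i]};
    std::vector<platform::Place> ps{places_[i]};
    op_handle = append_allreduce_op(ss, ps);
  }
  SetCommunicationContext(op_handle, places_[i]);
  // ... connect the i-th gradient's VarHandles as input and output ...
}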
11 changes: 9 additions & 2 deletions paddle/fluid/framework/details/multi_devices_helper.h
@@ -36,13 +36,20 @@ namespace details {
// map from variable name to variables. The variables that have the same name
// will have a different version. The offset in the
// `std::vector<VarHandle*>` is the version of the variables.
-typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
+typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
    GraphVars;
const char kGraphVars[] = "vars";

// aux variables to represent dependency. Useful to resolve data hazard.
-typedef std::unordered_set<VarHandleBase*> GraphDepVars;
+typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";

+// TODO(panyx0718): Clean this up as well.
+// all operators. NOTE that even though we use a vector here, the
+// operators are unordered.
+typedef std::vector<OpHandleBase *> GraphOps;
+const char kGraphOps[] = "ops";
Contributor: You don't need to expose these; you can get the graph ops with:

for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph))

Contributor Author: Done.


} // namespace details
} // namespace framework
} // namespace paddle
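Expanded into a self-contained form, the reviewer's FilterByNodeWrapper suggestion looks roughly like this (a sketch; CollectOps is an illustrative helper, and we assume ir::FilterByNodeWrapper<T> returns the wrapped handles as quoted above):

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

// Recover the op handles from the graph's nodes instead of reading a stored
// kGraphOps attribute, so GraphOps/kGraphOps need not be exposed.
std::vector<OpHandleBase *> CollectOps(const ir::Graph &graph) {
  std::vector<OpHandleBase *> ops;
  for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
    ops.push_back(op);
  }
  return ops;
}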
3 changes: 3 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
@@ -70,6 +70,9 @@ class OpHandleBase {
    auto it = dev_ctxes_.find(place);
    return it != dev_ctxes_.end() ? it->second : nullptr;
  }
+  const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
+    return dev_ctxes_;
+  }

  void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
    dev_ctxes_[place] = ctx_;
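The new accessor exists so that graph-splitting code can read back which device an op was bound to. A minimal sketch of that lookup (DeviceIdOf is an illustrative name; it assumes a single-device, CUDA-placed op, exactly as SeparateMultiDevicesGraph below does):

int DeviceIdOf(OpHandleBase *op) {
  // Each op in a per-device graph carries exactly one place -> context entry.
  const auto &dev_ctxes = op->DeviceContext();
  const platform::Place &place = dev_ctxes.begin()->first;
  return boost::get<platform::CUDAPlace>(place).device;
}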
61 changes: 60 additions & 1 deletion paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -13,11 +13,70 @@
// limitations under the License.

#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

+std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
Contributor: Should this be in a pass?

Contributor: Let's hide this in the parallel graph executor for now; splitting the graph is not safely supported yet (add a comment). We might want to remove this.

Contributor: This function is not safe: copying a graph can drop graph attrs. Maybe we shouldn't copy the graph; instead, pass the selected graph ops and vars to the threaded graph executor (changing the threaded_graph_executor interface).
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<ir::Graph> graph) {
+  std::vector<std::unique_ptr<ir::Graph>> graphs;
+  graphs.reserve(places.size());
+  for (size_t i = 0; i < places.size(); ++i) {
+    ProgramDesc empty;
+    graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
+    auto &g = graphs.back();
+    g->Set(kGraphVars, new GraphVars(1UL));
+    g->Set(kGraphDepVars, new GraphDepVars);
+    g->Set(kGraphOps, new GraphOps);
+  }
+
+  for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
+    auto &dev_ctx = op->DeviceContext();
+    auto &p = dev_ctx.begin()->first;
+    int dev_id = boost::get<platform::CUDAPlace>(p).device;
+    auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
+    auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
+    dev_ops.emplace_back(op);
+    graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());
+
+    for (auto &var : op->Inputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
+      }
+    }
+    for (auto &var : op->Outputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
+      }
+    }
+  }
+
+  for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
+    auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
+    auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
+    for (auto &name_pair : origin_vars) {
+      dev_vars.emplace(name_pair.first, name_pair.second);
+      for (auto &version_pair : name_pair.second) {
+        if (graph->Nodes().count(version_pair->Node())) {
+          graphs[dev_id]->AddNode(
+              graph->RemoveNode(version_pair->Node()).release());
+        }
+      }
+    }
+  }
+
+  return graphs;
+}

ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
@@ -37,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
+        strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
  }
}

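Taken together, the intended call pattern is roughly the following (a sketch; merged_graph, exec_strategy, and fetch_tensor_names are assumed caller-side names, while the two signatures come from this diff):

// Cut the merged multi-device SSA graph into one graph per place ...
std::vector<std::unique_ptr<ir::Graph>> graphs =
    SeparateMultiDevicesGraph(places, std::move(merged_graph));

// ... then drive one ThreadedSSAGraphExecutor per device.
ParallelSSAGraphExecutor executor(exec_strategy, local_scopes, places,
                                  std::move(graphs));
FeedFetchList fetches = executor.Run(fetch_tensor_names);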
9 changes: 9 additions & 0 deletions paddle/fluid/framework/details/parallel_ssa_graph_executor.h
@@ -14,23 +14,32 @@

#pragma once

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "ThreadPool.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace details {

+std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<ir::Graph> graph);

class ParallelSSAGraphExecutor : public SSAGraphExecutor {
 public:
  ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                           const std::vector<Scope *> &local_scopes,
                           const std::vector<platform::Place> &places,
                           std::vector<std::unique_ptr<ir::Graph>> &&graphs);
  ~ParallelSSAGraphExecutor() final = default;

  const ir::Graph &Graph() const override { return *graphs_[0]; }

  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
1 change: 1 addition & 1 deletion paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
      VLOG(10) << op << " " << op->Name() << " Done ";
      running_ops_--;
      ready_var_q->Extend(op->Outputs());
-      VLOG(10) << op << " " << op->Name() << "Signal posted";
+      VLOG(10) << op << " " << op->Name() << " Signal posted";
    } catch (...) {
      exception_holder_.Catch(std::current_exception());
    }
21 changes: 12 additions & 9 deletions paddle/fluid/framework/ir/graph.h
@@ -168,10 +168,13 @@ class Graph {
    return ret;
  }

-  void RemoveNode(ir::Node *node) {
+  std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
    PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
-    node_set_.erase(node);
+    std::unique_ptr<ir::Node> ret;
+    ret.reset(nodes_.at(node).release());
    nodes_.erase(node);
+    node_set_.erase(node);
+    return ret;
  }

  // NOTE low performance, but simple and secure.
@@ -184,13 +187,6 @@
    return nullptr;
  }

-  void ResolveHazard(
-      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
-
- private:
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);

  // This method takes ownership of `node`.
  ir::Node *AddNode(ir::Node *node) {
    PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
@@ -199,6 +195,13 @@
    return node;
  }

+  void ResolveHazard(
+      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
+
+ private:
+  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
+      const ProgramDesc &program);

  // NOTE: program_ shouldn't be exposed to user.
  const ProgramDesc program_;
  std::map<std::string, boost::any> attrs_;
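The reshuffling above makes AddNode public and RemoveNode ownership-returning; together they are what lets SeparateMultiDevicesGraph move nodes between graphs without copying. A minimal sketch of the transfer (src and dst are assumed existing graphs, op an op handle taken from src):

// Detach the underlying ir::Node: RemoveNode now returns the owning
// unique_ptr instead of destroying the node.
std::unique_ptr<ir::Node> detached = src->RemoveNode(op->Node());

// Re-attach it: AddNode takes ownership of the raw pointer, so release()
// hands the node over rather than letting the unique_ptr delete it.
dst->AddNode(detached.release());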