Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 71 additions & 34 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,38 +161,75 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);

std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());

size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);

auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
int64_t numel_all = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_all += numel;
}

auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
balance_grads[dev_id] += numel_all;
return dev_id;
};

bool is_forwarding = true;

for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op);
if (op->Type() == "send_vars") {
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
if (op_dev_id == -1) {
op_dev_id = get_appropriate_dev(op->InputArgumentNames());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yancey1989 as we discussed, one concern, the order when calling get_appropriate_dev must be the same to reduce and split_op or the device id for the variable may be different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

for (auto &varname : op->InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
CreateRPCOp(&result, *op, op_dev_id);
} else if (op->Type() == "recv") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will only 1 device perform broadcast in Reduce mode? So recv should be done on that device before broadcast? Perhaps take a look at get_appropriate_dev? I'm not quite sure the details

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll take a look at get_appropriate_dev and find the relationship with this PR.

int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateRPCOp(&result, *op, op_dev_id);
} else {
// send_barrier and fetch_barrier op would run on device 0
CreateRPCOp(&result, *op, 0);
}
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
if (op->Type() == "split_byref") {
int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateDistTrainOp(&result, *op, op_dev_id);
} else if (op->Type() == "concat") {
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place to concatenate received var.");
CreateDistTrainOp(&result, *op, op_dev_id);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
Expand All @@ -201,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name);
var_name_on_devices_.emplace(var_name, op_dev_id);
}
}
if (!is_forwarding && places_.size() > 1) {
Expand All @@ -230,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
cur_device_id = get_appropriate_dev({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
Expand Down Expand Up @@ -363,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once;
}

int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}

int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
for (auto &varname : op.InputArgumentNames()) {
int dev_id = GetVarDeviceID(varname);
if (dev_id != -1) {
return dev_id;
}
}
return var_dev_id;
return -1;
}

// Returns the device id recorded for `varname` in var_name_on_devices_,
// or -1 when the variable has not been pinned to any particular device.
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
  const auto it = var_name_on_devices_.find(varname);
  if (it == var_name_on_devices_.end()) {
    return -1;
  }
  return it->second;
}

void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
Expand Down Expand Up @@ -462,17 +498,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}

void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
const OpDesc &op,
int place_id) const {
CreateComputationalOp(result, op, place_id);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}

void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
result->ops_.emplace_back(
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
int device_id) const {
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
op.Type(), places_[device_id]));

if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send");
Expand All @@ -490,7 +527,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,

// TODO(Yancey1989): scheduling rpc ops on different places may
// increase throughput
CreateOpHandleIOs(result, op, 0);
CreateOpHandleIOs(result, op, device_id);
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/framework/details/multi_devices_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif

std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const;

private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const;
size_t device_id) const;

private:
std::string loss_var_name_;
Expand All @@ -64,8 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

bool IsScaleLossOp(const OpDesc &op) const;

void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
int place_id) const;

/**
* Is this operator as the end-point operator before/after send operator.
Expand Down Expand Up @@ -96,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;

int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
int GetOpDeviceID(const OpDesc &op) const;

void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;

Expand All @@ -111,6 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

private:
BuildStrategy strategy_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not use unordered_map to record the var_name on devices, because the same var_name may be on different devices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May not, this does not record all variables, only used for Reduce strategy and distributed training.

For the Reduce strategy, we schedule Reduce Op on the different device and record the gradient variable name in var_name_on_devices_ , so it would only appear on only one device.

For the distributed training, the same as Reduce strategy, we schedule send_op and recv_op on the different device, the variable name would not appear on the different device also.


void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/details/ssa_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }

DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);

Expand Down
15 changes: 11 additions & 4 deletions paddle/fluid/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(

// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp

details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy);
Expand All @@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor(
#endif
}

builder_ = std::move(builder_factory.Create());
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program)));
builder_->Build(main_program)));

member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
Expand Down Expand Up @@ -160,8 +160,15 @@ void ParallelExecutor::BCastParamsToGPUs(
buffer = t->mutable_data(place, main_tensor.type());
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());

if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

builder_.get() != nullptr -> builder_

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that GetVarDeviceID should probably be a method of a built graph or executor. This avoids making builder_ a private member. But I guess it's ok to leave it as TOOD for now.

int place_id = builder_->GetVarDeviceID(var);
platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
}
member_->nccl_ctxs_->WaitAll();
#else
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/parallel_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

Expand Down Expand Up @@ -68,6 +70,7 @@ class ParallelExecutor {

private:
ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
};

} // namespace framework
Expand Down