-
Notifications
You must be signed in to change notification settings - Fork 5.9k
overlap rpc op memcpy in distributed training #11221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
93401c9
6d69ae0
82d741c
cb38615
e533a4b
15913d9
23433de
d5a88b9
4444e79
6d752ba
f52d78d
3d875b6
7d1b146
7e6518e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -161,38 +161,75 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( | |
| auto send_vars = FindDistTrainSendVars(program); | ||
| auto recv_vars = FindDistTrainRecvVars(program); | ||
|
|
||
| std::vector<std::unordered_set<std::string>> var_name_on_devices; | ||
| std::vector<std::unordered_set<std::string>> bcast_var_name_set; | ||
| var_name_on_devices.resize(places_.size()); | ||
| bcast_var_name_set.resize(places_.size()); | ||
|
|
||
| size_t cur_device_id = 0; | ||
| std::vector<int64_t> balance_grads(places_.size(), 0); | ||
|
|
||
| auto get_appropriate_dev = [&](std::string &g_name) -> size_t { | ||
| auto var_desc = all_vars.at(g_name); | ||
| PADDLE_ENFORCE_NOT_NULL(var_desc); | ||
| auto dim = framework::make_ddim(var_desc->GetShape()); | ||
| int64_t numel = framework::product(dim); | ||
| PADDLE_ENFORCE_GE(numel, 0); | ||
| auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t { | ||
| int64_t numel_all = 0; | ||
| for (auto var_name : var_names) { | ||
| auto var_desc = all_vars.at(var_name); | ||
| PADDLE_ENFORCE_NOT_NULL(var_desc); | ||
| auto dim = framework::make_ddim(var_desc->GetShape()); | ||
| int64_t numel = framework::product(dim); | ||
| PADDLE_ENFORCE_GT(numel, 0); | ||
| numel_all += numel; | ||
| } | ||
|
|
||
| auto smallest = | ||
| std::min_element(std::begin(balance_grads), std::end(balance_grads)); | ||
| size_t dev_id = | ||
| static_cast<size_t>(std::distance(std::begin(balance_grads), smallest)); | ||
| balance_grads[dev_id] += numel; | ||
| balance_grads[dev_id] += numel_all; | ||
| return dev_id; | ||
| }; | ||
|
|
||
| bool is_forwarding = true; | ||
|
|
||
| for (auto *op : program.Block(0).AllOps()) { | ||
| if (boost::get<int>( | ||
| op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == | ||
| static_cast<int>(OpRole::kRPC)) { | ||
| // append rpc op if program is distributed trainer main program. | ||
| // always use the first device | ||
| CreateRPCOp(&result, *op); | ||
| if (op->Type() == "send_vars") { | ||
| int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]); | ||
| if (op_dev_id == -1) { | ||
| op_dev_id = get_appropriate_dev(op->InputArgumentNames()); | ||
| for (auto &varname : op->InputArgumentNames()) { | ||
| var_name_on_devices_.emplace(varname, op_dev_id); | ||
| } | ||
| } | ||
| CreateRPCOp(&result, *op, op_dev_id); | ||
| } else if (op->Type() == "recv") { | ||
|
||
| int op_dev_id = get_appropriate_dev(op->OutputArgumentNames()); | ||
| for (auto &varname : op->OutputArgumentNames()) { | ||
| var_name_on_devices_.emplace(varname, op_dev_id); | ||
| } | ||
| CreateRPCOp(&result, *op, op_dev_id); | ||
| } else { | ||
| // send_barrier and fetch_barrier op would run on device 0 | ||
| CreateRPCOp(&result, *op, 0); | ||
| } | ||
| } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { | ||
| CreateDistTrainOp(&result, *op); | ||
| if (op->Type() == "split_byref") { | ||
| int op_dev_id = get_appropriate_dev(op->OutputArgumentNames()); | ||
| for (auto &varname : op->OutputArgumentNames()) { | ||
| var_name_on_devices_.emplace(varname, op_dev_id); | ||
| } | ||
| CreateDistTrainOp(&result, *op, op_dev_id); | ||
| } else if (op->Type() == "concat") { | ||
| int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]); | ||
| PADDLE_ENFORCE(op_dev_id != -1, | ||
| "can not find right place to concatenate received var."); | ||
| CreateDistTrainOp(&result, *op, op_dev_id); | ||
| } else { | ||
| PADDLE_ENFORCE( | ||
| "the distribute training related op should be in [split_byref, " | ||
| "concat]."); | ||
| } | ||
| } else if (IsScaleLossOp(*op)) { | ||
| // user can customize loss@grad if not use_default_grad_scale_ | ||
| if (strategy_.gradient_scale_ != | ||
|
|
@@ -201,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( | |
| } | ||
| is_forwarding = false; | ||
| } else { | ||
| int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); | ||
| int op_dev_id = GetOpDeviceID(*op); | ||
| if (op_dev_id == -1) { // var on all device | ||
| CreateComputationalOps(&result, *op, places_.size()); | ||
| } else { | ||
| CreateComputationalOp(&result, *op, op_dev_id); | ||
| for (auto &var_name : op->OutputArgumentNames()) { | ||
| var_name_on_devices[op_dev_id].emplace(var_name); | ||
| var_name_on_devices_.emplace(var_name, op_dev_id); | ||
| } | ||
| } | ||
| if (!is_forwarding && places_.size() > 1) { | ||
|
|
@@ -230,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( | |
|
|
||
| switch (strategy_.reduce_) { | ||
| case BuildStrategy::ReduceStrategy::kReduce: | ||
| cur_device_id = get_appropriate_dev(g_name); | ||
| cur_device_id = get_appropriate_dev({g_name}); | ||
| CreateReduceOp(&result, g_name, cur_device_id); | ||
| var_name_on_devices[cur_device_id].emplace(g_name); | ||
| var_name_on_devices_.emplace(g_name, cur_device_id); | ||
| bcast_var_name_set[cur_device_id].emplace(p_name); | ||
| break; | ||
| case BuildStrategy::ReduceStrategy::kAllReduce: | ||
|
|
@@ -363,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( | |
| return is_pg_once; | ||
| } | ||
|
|
||
| int MultiDevSSAGraphBuilder::GetOpDeviceID( | ||
| const std::vector<std::unordered_set<std::string>> &var_name_on_devices, | ||
| const OpDesc &op) const { | ||
| int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { | ||
| if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { | ||
| return -1; | ||
| } | ||
|
|
||
| int var_dev_id = -1; | ||
| for (auto &var_name : op.InputArgumentNames()) { | ||
| if (var_dev_id != -1) break; | ||
| for (size_t i = 0; i < var_name_on_devices.size(); ++i) { | ||
| if (var_name_on_devices[i].count(var_name)) { | ||
| var_dev_id = static_cast<int>(i); | ||
| break; | ||
| } | ||
| for (auto &varname : op.InputArgumentNames()) { | ||
| int dev_id = GetVarDeviceID(varname); | ||
| if (dev_id != -1) { | ||
| return dev_id; | ||
| } | ||
| } | ||
| return var_dev_id; | ||
| return -1; | ||
| } | ||
|
|
||
| int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { | ||
| auto got = var_name_on_devices_.find(varname); | ||
| return got == var_name_on_devices_.end() ? -1 : got->second; | ||
| } | ||
|
|
||
| void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { | ||
|
|
@@ -462,17 +498,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, | |
| } | ||
|
|
||
| void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, | ||
| const OpDesc &op) const { | ||
| CreateComputationalOp(result, op, 0); | ||
| const OpDesc &op, | ||
| int place_id) const { | ||
| CreateComputationalOp(result, op, place_id); | ||
| if (op.Type() == "concat") { | ||
| ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); | ||
| } | ||
| } | ||
|
|
||
| void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, | ||
| const OpDesc &op) const { | ||
| result->ops_.emplace_back( | ||
| new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); | ||
| void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op, | ||
| int device_id) const { | ||
| result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id], | ||
| op.Type(), places_[device_id])); | ||
|
|
||
| if (op.Type() == "send_barrier") { | ||
| ConnectOp(result, result->ops_.back().get(), "send"); | ||
|
|
@@ -490,7 +527,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, | |
|
|
||
| // TODO(Yancey1989): schedule rpc op on different place may | ||
| // increase throughput | ||
| CreateOpHandleIOs(result, op, 0); | ||
| CreateOpHandleIOs(result, op, device_id); | ||
| } | ||
|
|
||
| bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { | |
| #endif | ||
|
|
||
| std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; | ||
| int GetVarDeviceID(const std::string &varname) const; | ||
|
|
||
| private: | ||
| void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, | ||
| size_t place_id) const; | ||
| size_t device_id) const; | ||
|
|
||
| private: | ||
| std::string loss_var_name_; | ||
|
|
@@ -64,8 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { | |
|
|
||
| bool IsScaleLossOp(const OpDesc &op) const; | ||
|
|
||
| void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; | ||
| void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; | ||
| void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const; | ||
| void CreateDistTrainOp(SSAGraph *result, const OpDesc &op, | ||
| int place_id) const; | ||
|
|
||
| /** | ||
| * Is this operator as the end-point operator before/after send operator. | ||
|
|
@@ -96,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { | |
| const std::string &og, | ||
| std::unordered_set<std::string> *og_has_been_broadcast) const; | ||
|
|
||
| int GetOpDeviceID( | ||
| const std::vector<std::unordered_set<std::string>> &var_name_on_devices, | ||
| const OpDesc &op) const; | ||
| int GetOpDeviceID(const OpDesc &op) const; | ||
|
|
||
| void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; | ||
|
|
||
|
|
@@ -111,6 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { | |
|
|
||
| private: | ||
| BuildStrategy strategy_; | ||
| mutable std::unordered_map<std::string, int> var_name_on_devices_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not use unordered_map to record the var_name on devices, because the same var_name may be on different devices.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may not, because this does not record all variables; it is only used for the Reduce strategy and for distributed training. For the Reduce strategy, we schedule the Reduce Op on a different device and record the gradient variable name. For distributed training, the same as the Reduce strategy, we schedule |
||
|
|
||
| void SetCommunicationContext(OpHandleBase *op_handle, | ||
| const platform::Place &p) const; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor( | |
|
|
||
| // Step 3. Convert main_program to SSA form and dependency graph. Also, insert | ||
| // ncclOp | ||
|
|
||
| details::SSAGraphBuilderFactory builder_factory( | ||
| member_->places_, loss_var_name, params, member_->local_scopes_, | ||
| build_strategy); | ||
|
|
@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor( | |
| #endif | ||
| } | ||
|
|
||
| builder_ = std::move(builder_factory.Create()); | ||
| member_->executor_.reset(new details::ThreadedSSAGraphExecutor( | ||
| exec_strategy, member_->local_scopes_, places, | ||
| builder_factory.Create()->Build(main_program))); | ||
| builder_->Build(main_program))); | ||
|
|
||
| member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( | ||
| exec_strategy, member_->local_scopes_, std::move(var_infos), | ||
|
|
@@ -160,8 +160,15 @@ void ParallelExecutor::BCastParamsToGPUs( | |
| buffer = t->mutable_data(place, main_tensor.type()); | ||
| } | ||
| auto &nccl_ctx = member_->nccl_ctxs_->at(place); | ||
| platform::dynload::ncclBcast(buffer, numel, data_type, 0, | ||
| nccl_ctx.comm_, nccl_ctx.stream()); | ||
|
|
||
| if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) { | ||
|
||
| int place_id = builder_->GetVarDeviceID(var); | ||
| platform::dynload::ncclBcast(buffer, numel, data_type, place_id, | ||
| nccl_ctx.comm_, nccl_ctx.stream()); | ||
| } else { | ||
| platform::dynload::ncclBcast(buffer, numel, data_type, 0, | ||
| nccl_ctx.comm_, nccl_ctx.stream()); | ||
| } | ||
| } | ||
| member_->nccl_ctxs_->WaitAll(); | ||
| #else | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Yancey1989 as we discussed, one concern: the order of the calls to get_appropriate_dev must be the same for the reduce op and the split op, or the device id assigned to the variable may differ.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done.