From a1c7d6cdd7038cb550a5d4bd3636c49fc456d779 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Sat, 23 Mar 2019 18:00:36 +0800 Subject: [PATCH 1/4] test fix fetch bar place for ce --- .../fluid/framework/details/multi_devices_graph_pass.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 125dbf746c3880..2e9b6b73793da0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -826,7 +826,11 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { } sharded_var_device_.emplace(send_param_grad[1], op_dev_id); } - } else if (node->Op()->Type() == "recv") { + } else if (node->Op()->Type() == "recv" || + node->Op()->Type() == "fetch_barrier") { + // **IMPORTANT** fetch barrier's output should have same place with recv, + // see + // distribute_transpiler.py std::vector output_var_names; for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); @@ -845,7 +849,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { sharded_var_device_.emplace(varname, op_dev_id); } } else { - // send_barrier, fetch_barrier will run on place 0; + // send_barrier will run on place 0; op_dev_id = 0; } From 4595808734acaff98e037cf0a5f1d74a93628d89 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Sun, 24 Mar 2019 16:11:14 +0800 Subject: [PATCH 2/4] fix ps mode dist train in develop test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../details/fetch_barrier_op_handle.cc | 66 +++++++++++++++++++ .../details/fetch_barrier_op_handle.h | 60 +++++++++++++++++ .../details/multi_devices_graph_pass.cc | 28 +++++--- .../fluid/framework/details/op_handle_base.cc | 4 +- 5 files changed, 148 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/framework/details/fetch_barrier_op_handle.cc create mode 100644 paddle/fluid/framework/details/fetch_barrier_op_handle.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 7a371af510b805..77e94e998c4db1 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,6 +5,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) +cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) @@ -72,7 +73,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle) cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle) diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.cc b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc new file mode 100644 index 00000000000000..019ecfbb610285 --- /dev/null +++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h" + +#include + +namespace paddle { +namespace framework { +namespace details { +FetchBarrierOpHandle::FetchBarrierOpHandle( + ir::Node *node, const std::vector &local_scopes, + const std::vector &places) + // fetch_barrier op always run on place0, but output on all places. + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(*node->Op())), + local_scopes_(local_scopes), + places_(places), + run_scope_(local_scopes[0]), + place_(places[0]) { + for (auto &p : places) { + this->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); + } +} + +bool FetchBarrierOpHandle::IsMultiDeviceTransfer() { + // override IsMultiDeviceTransfer to return true + return true; +} + +void FetchBarrierOpHandle::RunImpl() { + WaitInputVarGenerated(place_); + + auto run_func = [this]() { + op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get(), place_); + }; + + if (is_lock_and_record_event_free_) { + run_func(); + } else { + this->RunAndRecordEvent(run_func); + } +} + +bool FetchBarrierOpHandle::NeedWait(VarHandleBase *in_var) { + bool need_wait = + in_var && in_var->GeneratedOp() && + in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_); + return need_wait; +} + +std::string FetchBarrierOpHandle::Name() const { return op_->Type(); } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.h b/paddle/fluid/framework/details/fetch_barrier_op_handle.h new file mode 100644 index 00000000000000..1678c71784635c --- /dev/null +++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.h @@ -0,0 +1,60 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { +namespace details { + +// **NOTE**: fetch_barrier op is special it outputs all recved variables on +// all places if there are multiple places, must init with +// multiple dev_ctxes_ !!!! + +struct FetchBarrierOpHandle : public OpHandleBase { + public: + FetchBarrierOpHandle(ir::Node *node, const std::vector &local_scopes, + const std::vector &places); + + bool IsMultiDeviceTransfer() override; + + std::string Name() const override; + + protected: + void RunImpl() override; + + bool NeedWait(VarHandleBase *in_var) override; + + private: + std::unique_ptr op_; + std::vector local_scopes_; + std::vector places_; + Scope *run_scope_; + platform::Place place_; + + bool is_lock_and_record_event_free_{false}; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 2e9b6b73793da0..a6f03ec98417b8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h" #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -354,9 +355,12 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result, result->Get(kGraphVars).at(src_dev_id).at(p_name).back(); op_handle->AddInput(in); + // **NOTE** bcast op should run on src dev and outputs on all devs. + auto &p = places_[src_dev_id]; + SetCommunicationContext(op_handle, p); + for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; - SetCommunicationContext(op_handle, p); auto &vars = result->Get(kGraphVars).at(i).at(p_name); auto *out_var = new VarHandle( result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), @@ -826,11 +830,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { } sharded_var_device_.emplace(send_param_grad[1], op_dev_id); } - } else if (node->Op()->Type() == "recv" || - node->Op()->Type() == "fetch_barrier") { - // **IMPORTANT** fetch barrier's output should have same place with recv, - // see - // distribute_transpiler.py + } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); @@ -849,15 +849,23 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { sharded_var_device_.emplace(varname, op_dev_id); } } else { - // send_barrier will run on place 0; + // send_barrier, fetch_barrier will run on place 0; op_dev_id = 0; } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", node->Op()->Type()); - result->Get(kGraphOps).emplace_back(new RPCOpHandle( - result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], - node->Op()->Type(), places_[op_dev_id])); + + // Create fetch_barrier op handle to enable output on all devices. + // **NOTE** fetch_barrier should output variables list same as recv op does. + if (node->Op()->Type() == "fetch_barrier") { + result->Get(kGraphOps).emplace_back(new FetchBarrierOpHandle( + result->CreateOpNode(node->Op()), local_scopes_, places_)); + } else { + result->Get(kGraphOps).emplace_back(new RPCOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], + node->Op()->Type(), places_[op_dev_id])); + } if (node->Op()->Type() == "send") { CreateOpHandleIOs(result, node, op_dev_id); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 158da6f606f3f5..413b14961631b3 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -55,7 +55,7 @@ void OpHandleBase::Run(bool use_cuda) { if (out_var_handle) { int dev_id = boost::get(out_var_handle->place()).device; - out_var_handle->SetGenerateEvent(events_[dev_id]); + out_var_handle->SetGenerateEvent(events_.at(dev_id)); } } } else { @@ -71,7 +71,7 @@ void OpHandleBase::Run(bool use_cuda) { "The place of input(%s) is not consistent with the " "place of current op(%s).", out_var_handle->Name(), Name()); - out_var_handle->SetGenerateEvent(events_[dev_id]); + out_var_handle->SetGenerateEvent(events_.at(dev_id)); } } } From c13f39616958d95565059f1075c835b6ed843479 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Sun, 24 Mar 2019 16:31:07 +0800 Subject: [PATCH 3/4] fix style check test=develop --- paddle/fluid/framework/details/fetch_barrier_op_handle.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.h b/paddle/fluid/framework/details/fetch_barrier_op_handle.h index 1678c71784635c..b4d12785e0345c 100644 --- a/paddle/fluid/framework/details/fetch_barrier_op_handle.h +++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include From f36a43b54802bfb167f64c2ee1c7b033abf85734 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 25 Mar 2019 10:19:20 +0800 Subject: [PATCH 4/4] update test=develop --- paddle/fluid/framework/details/multi_devices_graph_pass.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index a6f03ec98417b8..253cf5b4a8221a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -355,12 +355,9 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result, result->Get(kGraphVars).at(src_dev_id).at(p_name).back(); op_handle->AddInput(in); - // **NOTE** bcast op should run on src dev and outputs on all devs. - auto &p = places_[src_dev_id]; - SetCommunicationContext(op_handle, p); - for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; + SetCommunicationContext(op_handle, p); auto &vars = result->Get(kGraphVars).at(i).at(p_name); auto *out_var = new VarHandle( result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),