Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)

cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
Expand Down Expand Up @@ -72,7 +73,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)

cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)

cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)

Expand Down
66 changes: 66 additions & 0 deletions paddle/fluid/framework/details/fetch_barrier_op_handle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"

#include <string>

namespace paddle {
namespace framework {
namespace details {
FetchBarrierOpHandle::FetchBarrierOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
// fetch_barrier op always run on place0, but output on all places.
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
local_scopes_(local_scopes),
places_(places),
run_scope_(local_scopes[0]),
place_(places[0]) {
for (auto &p : places) {
this->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p));
}
}

bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
// override IsMultiDeviceTransfer to return true
return true;
}

void FetchBarrierOpHandle::RunImpl() {
WaitInputVarGenerated(place_);

auto run_func = [this]() {
op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};

if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}

bool FetchBarrierOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_);
return need_wait;
}

std::string FetchBarrierOpHandle::Name() const { return op_->Type(); }
} // namespace details
} // namespace framework
} // namespace paddle
61 changes: 61 additions & 0 deletions paddle/fluid/framework/details/fetch_barrier_op_handle.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
namespace details {

// **NOTE**: fetch_barrier op is special it outputs all recved variables on
// all places if there are multiple places, must init with
// multiple dev_ctxes_ !!!!

struct FetchBarrierOpHandle : public OpHandleBase {
public:
FetchBarrierOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);

bool IsMultiDeviceTransfer() override;

std::string Name() const override;

protected:
void RunImpl() override;

bool NeedWait(VarHandleBase *in_var) override;

private:
std::unique_ptr<OperatorBase> op_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
Scope *run_scope_;
platform::Place place_;

bool is_lock_and_record_event_free_{false};
};

} // namespace details
} // namespace framework
} // namespace paddle
15 changes: 12 additions & 3 deletions paddle/fluid/framework/details/multi_devices_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
Expand Down Expand Up @@ -851,9 +852,17 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {

PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type());
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));

// Create fetch_barrier op handle to enable output on all devices.
// **NOTE** fetch_barrier should output variables list same as recv op does.
if (node->Op()->Type() == "fetch_barrier") {
result->Get<GraphOps>(kGraphOps).emplace_back(new FetchBarrierOpHandle(
result->CreateOpNode(node->Op()), local_scopes_, places_));
} else {
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
}

if (node->Op()->Type() == "send") {
CreateOpHandleIOs(result, node, op_dev_id);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/op_handle_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void OpHandleBase::Run(bool use_cuda) {
if (out_var_handle) {
int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_[dev_id]);
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
} else {
Expand All @@ -71,7 +71,7 @@ void OpHandleBase::Run(bool use_cuda) {
"The place of input(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_[dev_id]);
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
}
Expand Down