Skip to content

Commit 57fc341

Browse files
author
chengduozh
committed
create fuse_all_reduce_op_pass
test=develop
1 parent fcb9c81 commit 57fc341

File tree

8 files changed

+213
-135
lines changed

8 files changed

+213
-135
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
7272

7373
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
7474
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle
75-
data_balance_op_handle fused_broadcast_op_handle fused_all_reduce_op_handle)
75+
data_balance_op_handle fused_broadcast_op_handle)
76+
77+
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
7678

7779
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
7880
if (WITH_GPU)
@@ -102,4 +104,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
102104
multi_devices_graph_print_pass multi_devices_graph_check_pass
103105
fuse_elewise_add_act_pass multi_batch_merge_pass
104106
fuse_relu_depthwise_conv_pass
105-
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_adam_op_pass fuse_sgd_op_pass)
107+
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_all_reduce_op_pass)

paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ namespace paddle {
2626
namespace framework {
2727
namespace details {
2828

29-
class AllocContinuousSpaceForGrad : public ir::Pass {
29+
class AllocContinuousSpaceForGradPass : public ir::Pass {
3030
protected:
3131
std::unique_ptr<ir::Graph> ApplyImpl(
3232
std::unique_ptr<ir::Graph> graph) const override {
3333
ir::Graph& result = *graph;
3434
if (result.Has(kParamsAndGrads)) {
35-
VLOG(10) << kParamsAndGrads << " are reset.";
35+
VLOG(10) << kParamsAndGrads << " is reset.";
3636
result.Erase(kParamsAndGrads);
3737
}
3838
result.Set(kParamsAndGrads, new ParamsAndGrads);
@@ -161,4 +161,4 @@ class AllocContinuousSpaceForGrad : public ir::Pass {
161161
} // namespace paddle
162162

163163
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
164-
paddle::framework::details::AllocContinuousSpaceForGrad);
164+
paddle::framework::details::AllocContinuousSpaceForGradPass);

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
8686
}
8787

8888
if (strategy.fuse_all_optimizer_ops_) {
89-
if (!strategy.fuse_all_reduce_ops_) {
90-
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
91-
AppendPass("alloc_continuous_space_for_grad_pass");
92-
}
89+
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
90+
AppendPass("alloc_continuous_space_for_grad_pass");
9391
VLOG(10) << "Add fuse_adam_op_pass";
9492
AppendPass("fuse_adam_op_pass");
9593
VLOG(10) << "Add fuse_sgd_op_pass";
@@ -127,6 +125,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
127125

128126
AppendMultiDevPass(strategy);
129127

128+
if (strategy.fuse_all_reduce_ops_) {
129+
PADDLE_ENFORCE(strategy.reduce_ ==
130+
BuildStrategy::ReduceStrategy::kAllReduce);
131+
VLOG(10) << "Add fuse_all_reduce_op_pass";
132+
AppendPass("fuse_all_reduce_op_pass");
133+
}
134+
130135
// Add a graph print pass to record a graph with device info.
131136
if (!strategy_.debug_graphviz_path_.empty()) {
132137
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
@@ -160,12 +165,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
160165
VLOG(10) << "Add dist_multi_devices_pass";
161166
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
162167
} else {
163-
if (strategy.fuse_all_reduce_ops_) {
164-
VLOG(10) << "Add fused_all_reduce_mode_multi_devices_pass";
165-
multi_devices_pass =
166-
AppendPass("fused_all_reduce_mode_multi_devices_pass").get();
167-
} else if (strategy.reduce_ ==
168-
BuildStrategy::ReduceStrategy::kAllReduce) {
168+
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
169169
VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
170170
multi_devices_pass =
171171
AppendPass("all_reduce_mode_multi_devices_pass").get();
@@ -227,8 +227,19 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
227227

228228
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
229229
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
230-
pass->Erase("nccl_ctxs");
231-
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
230+
pass->Erase(kNCCLCtxs);
231+
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
232+
#endif
233+
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
234+
pass->Erase(kPlaces);
235+
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
236+
pass->Erase(kLocalScopes);
237+
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
238+
&local_scopes);
239+
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
240+
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
241+
pass->Erase(kNCCLCtxs);
242+
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
232243
#endif
233244
} else if (pass->Type() == "memory_optimize_pass") {
234245
if (graph->Has(kAllOpDescs)) {
@@ -300,3 +311,4 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
300311
USE_PASS(graph_to_program_pass);
301312
USE_PASS(fuse_adam_op_pass);
302313
USE_PASS(fuse_sgd_op_pass);
314+
USE_PASS(fuse_all_reduce_op_pass);
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// Graph pass that replaces every individual all_reduce op handle in the
// graph with a single FusedAllReduceOpHandle covering all of their inputs
// and outputs.
//
// Preconditions:
//   - kPlaces and kLocalScopes must be set on the pass (and kNCCLCtxs on
//     CUDA non-Windows builds); BuildStrategy::Apply wires these in.
//   - The graph must carry kParamsAndGrads (produced by
//     alloc_continuous_space_for_grad_pass); every all_reduce found in the
//     graph is expected to operate on one of those gradients.
class FuseAllReduceOpPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    ir::Graph &result = *graph;

    auto &places = Get<const std::vector<platform::Place>>(kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
#endif

    // Collect the gradient names recorded by
    // alloc_continuous_space_for_grad_pass. Only dense gradients are
    // recorded there, so sparse gradients never match below.
    std::unordered_set<std::string> grads;
    auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
    size_t num_of_all_reduce = params_grads.size();
    grads.reserve(num_of_all_reduce);
    for (const auto &p_g : params_grads) {
      grads.insert(p_g.second);
    }

    // Find every all_reduce op handle in the graph and check that each one
    // reduces a known gradient across all places.
    size_t num_place = places.size();
    std::vector<ir::Node *> all_reduce_ops;
    all_reduce_ops.reserve(grads.size());
    for (auto &node : result.Nodes()) {
      if (node->IsOp()) {
        PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
        auto *all_reduce_op_handle =
            dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
        if (all_reduce_op_handle) {
          auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
          PADDLE_ENFORCE_EQ(all_reduce_op_handle->NoDummyInputSize(),
                            num_place);
          // TODO(zcd): The inputs' name should be the same.
          PADDLE_ENFORCE_NE(grads.count(inputs.at(0)->name()), 0);
          all_reduce_ops.emplace_back(node);
        }
      }
    }

    VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
    if (all_reduce_ops.empty()) {
      // Nothing to fuse (e.g. a reduce-mode graph); leave the graph as-is.
      return std::move(graph);
    }

    PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size());
    VLOG(10) << "Insert fused_all_reduce";

    // Detach each original all_reduce op handle from the graph, gathering
    // its inputs and outputs so the fused op can inherit them.
    std::vector<VarHandleBase *> inputs;
    std::vector<VarHandleBase *> outputs;
    for (auto &op : all_reduce_ops) {
      auto &op_handle = op->Wrapper<OpHandleBase>();
      inputs.insert(inputs.end(), op_handle.Inputs().begin(),
                    op_handle.Inputs().end());
      // Drop the edge from each input var to this (soon-removed) op.
      std::for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
                    [&op_handle](VarHandleBase *var_handle) {
                      var_handle->RemoveOutput(&op_handle, op_handle.Node());
                    });

      outputs.insert(outputs.end(), op_handle.Outputs().begin(),
                     op_handle.Outputs().end());
      // Drop the generated-op link of each output var.
      std::for_each(
          op_handle.Outputs().begin(), op_handle.Outputs().end(),
          [](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); });

      result.RemoveNode(op_handle.Node());
    }

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, nccl_ctxs, &result);
#else
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, &result);
#endif

    return std::move(graph);
  }

 private:
  // Creates one FusedAllReduceOpHandle in *result that takes over the given
  // inputs/outputs of the removed all_reduce ops. Without NCCL contexts the
  // per-place device contexts are set explicitly.
  void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
                              const std::vector<VarHandleBase *> &outputs,
                              const size_t num_of_all_reduce,
                              const std::vector<platform::Place> &places,
                              const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                              const platform::NCCLContextMap *nccl_ctxs,
#endif
                              ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *op_handle = new FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else
    auto *op_handle = new FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce);
#endif

    for (auto in : inputs) {
      op_handle->AddInput(in);
    }

    for (auto out : outputs) {
      op_handle->AddOutput(out);
    }

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    // With NCCL contexts the op handle manages its own device contexts.
    if (!nccl_ctxs) {
      SetCommunicationContext(places, op_handle);
    }
#else
    SetCommunicationContext(places, op_handle);
#endif
  }

  // Binds the default device context of each place to the fused op handle.
  void SetCommunicationContext(const std::vector<platform::Place> &places,
                               FusedAllReduceOpHandle *op_handle) const {
    for (size_t i = 0; i < places.size(); ++i) {
      op_handle->SetDeviceContext(
          places[i], platform::DeviceContextPool::Instance().Get(places[i]));
    }
  }
};

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_all_reduce_op_pass,
              paddle::framework::details::FuseAllReduceOpPass);

paddle/fluid/framework/details/fuse_optimizer_op_pass.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ std::unique_ptr<ir::Graph> FuseOptimizerOpPass::ApplyImpl(
4444
return std::move(graph);
4545
}
4646

47+
if (result.Has(kFusedOptType)) {
48+
VLOG(10)
49+
<< "Currently only support fusing one type optimizer op. Has fused "
50+
<< result.Get<FusedOptType>(kFusedOptType);
51+
return std::move(graph);
52+
} else {
53+
result.Set(kFusedOptType, new FusedOptType);
54+
}
55+
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
56+
4757
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
4858
// initialized in scopes before execution.
4959
if (!result.Has(kFusedVars)) {

0 commit comments

Comments
 (0)