Skip to content

Commit 58a4f9d

Browse files
author
chengduozh
committed
add fuse_grad_op
1 parent 5a8bd82 commit 58a4f9d

5 files changed

Lines changed: 274 additions & 0 deletions

File tree

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
1010
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
1111
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
1212

13+
# fuse_gradient_space_pass reads kParamsAndGrads/ParamsAndGrads declared in
# multi_devices_helper.h, so it must depend on multi_devices_helper too.
cc_library(fuse_gradient_space_pass SRCS fuse_gradient_space_pass.cc DEPS graph graph_helper multi_devices_helper)
14+
1315
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1416

1517
if(WITH_DISTRIBUTE)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "paddle/fluid/framework/details/fuse_gradient_space_pass.h"
15+
#include <algorithm>
16+
#include <fstream>
17+
#include <string>
18+
#include <utility>
19+
#include <vector>
20+
21+
#include "paddle/fluid/framework/ir/graph_helper.h"
22+
#include "paddle/fluid/framework/ir/node.h"
23+
#include "paddle/fluid/framework/op_info.h"
24+
#include "paddle/fluid/framework/op_registry.h"
25+
26+
namespace paddle {
27+
namespace framework {
28+
namespace details {
29+
30+
std::unique_ptr<ir::Graph> FuseGradientSpacePass::ApplyImpl(
31+
std::unique_ptr<ir::Graph> graph) const {
32+
ir::Graph& result = *graph;
33+
34+
result.Set(kParamsAndGrads, new ParamsAndGrads);
35+
std::unordered_map<std::string, ir::Node*> vars;
36+
std::unordered_map<std::string, ir::Node*> ops;
37+
// Get parameters and gradients
38+
for (ir::Node* node : graph->Nodes()) {
39+
if (node->IsVar()) {
40+
if (node->Var()) {
41+
auto var_name = node->Var()->Name();
42+
PADDLE_ENFORCE_EQ(vars.count(var_name), static_cast<size_t>(0));
43+
vars.emplace(var_name, node);
44+
}
45+
} else {
46+
try {
47+
bool is_bk_op =
48+
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
49+
OpProtoAndCheckerMaker::OpRoleAttrName())) &
50+
static_cast<int>(OpRole::kBackward));
51+
if (!is_bk_op) continue;
52+
53+
// Currently, we assume that once gradient is generated, it can be
54+
// broadcast, and each gradient is only broadcast once.
55+
auto backward_vars =
56+
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
57+
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
58+
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
59+
60+
for (size_t i = 0; i < backward_vars.size(); i += 2) {
61+
result.Get<ParamsAndGrads>(kParamsAndGrads)
62+
.emplace(backward_vars[i] /*param*/,
63+
backward_vars[i + 1] /*grad*/);
64+
ops.emplace(backward_vars[i + 1], node);
65+
}
66+
} catch (boost::bad_get e) {
67+
}
68+
}
69+
}
70+
71+
std::vector<std::string> grads_name;
72+
// Set Gradients as Persistable
73+
proto::VarType::Type fuse_space_type = static_cast<proto::VarType::Type>(0);
74+
auto& params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
75+
for (auto& p_g : params_grads) {
76+
auto iter = vars.find(p_g.second);
77+
PADDLE_ENFORCE(iter != vars.end());
78+
// Set Persistable
79+
iter->second->Var()->SetPersistable(true);
80+
81+
// The Gradient can be SeletedRows and LoDTensor
82+
bool valid_type =
83+
(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR);
84+
PADDLE_ENFORCE(valid_type);
85+
// Get Dtype
86+
auto dtype = iter->second->Var()->GetDataType();
87+
if (fuse_space_type == static_cast<proto::VarType::Type>(0)) {
88+
fuse_space_type = dtype;
89+
PADDLE_ENFORCE_NE(dtype, static_cast<proto::VarType::Type>(0));
90+
}
91+
PADDLE_ENFORCE_EQ(dtype, fuse_space_type);
92+
93+
grads_name.emplace_back(p_g.second);
94+
}
95+
96+
OpDesc desc;
97+
desc.SetType("alloc_space_for_vars");
98+
desc.SetInput("Input", grads_name);
99+
desc.SetOutput("Output", grads_name);
100+
101+
auto alloc_space_node = result.CreateOpNode(&desc);
102+
// Need Insert alloc_space_node's input
103+
104+
// Insert alloc_space_node's output
105+
for (auto& op : ops) {
106+
auto ctl_node = result.CreateControlDepVar();
107+
alloc_space_node->outputs.emplace_back(ctl_node);
108+
ctl_node->inputs.emplace_back(alloc_space_node);
109+
op.second->inputs.emplace_back(ctl_node);
110+
ctl_node->outputs.emplace_back(op.second);
111+
}
112+
return std::move(graph);
113+
}
114+
115+
} // namespace details
116+
} // namespace framework
117+
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <utility>
19+
#include <vector>
20+
21+
#include "paddle/fluid/framework/details/build_strategy.h"
22+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
23+
#include "paddle/fluid/framework/ir/graph.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
namespace details {
28+
29+
class FuseGradientSpacePass : public ir::Pass {
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(
32+
std::unique_ptr<ir::Graph> graph) const override;
33+
34+
void GetParamsAndGrads(const ir::Graph &graph) const;
35+
/*
36+
void SetGradientAsPersistable(
37+
ir::Graph &graph,
38+
const std::unordered_map<std::string, std::string> params_grads) const;
39+
40+
ir::Node CreateFusedSpaceOp(
41+
const std::unordered_map<std::string, std::string> params_grads) const;
42+
43+
ir::Node CreateGetVarSpaceOp(
44+
const std::unordered_map<std::string, std::string> params_grads) const;
45+
46+
private:
47+
std::unordered_map<std::string, std::string> params_grads;
48+
*/
49+
};
50+
51+
} // namespace details
52+
} // namespace framework
53+
} // namespace paddle

paddle/fluid/framework/details/multi_devices_helper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ const char kGraphVars[] = "vars";
4343
// aux variables to represent dependency. Useful to resolve data hazard.
4444
typedef std::unordered_set<VarHandleBase*> GraphDepVars;
4545
const char kGraphDepVars[] = "dep_vars";
46+
47+
typedef std::unordered_map<std::string, std::string> ParamsAndGrads;
48+
const char kParamsAndGrads[] = "params_grads";
4649
} // namespace details
4750
} // namespace framework
4851
} // namespace paddle
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <vector>
16+
#include "paddle/fluid/framework/lod_tensor_array.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/framework/operator.h"
19+
#include "paddle/fluid/framework/var_type.h"
20+
#include "paddle/fluid/operators/detail/safe_ref.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
class AllocSpaceForVarsOpOp : public framework::OperatorBase {
26+
public:
27+
AllocSpaceForVarsOpOp(const std::string &type,
28+
const framework::VariableNameMap &inputs,
29+
const framework::VariableNameMap &outputs,
30+
const framework::AttributeMap &attrs)
31+
: framework::OperatorBase(type, inputs, outputs, attrs) {}
32+
33+
private:
34+
void RunImpl(const framework::Scope &scope,
35+
const platform::Place &dev_place) const override {
36+
auto &in_var_names = Inputs("Input");
37+
PADDLE_ENFORCE_GT(in_var_names.size(), 0);
38+
size_t mem_size = 0;
39+
framework::proto::VarType::Type fuse_space_type =
40+
static_cast<framework::proto::VarType::Type>(0);
41+
for (auto &name : in_var_names) {
42+
auto var = scope.FindVar(name);
43+
PADDLE_ENFORCE_NOT_NULL(var);
44+
// Only support LoDTensor,
45+
bool valid_var = var->IsType<framework::LoDTensor>();
46+
PADDLE_ENFORCE(valid_var, "");
47+
auto tensor = var->Get<framework::LoDTensor>();
48+
auto dtype = tensor.type();
49+
if (fuse_space_type == static_cast<framework::proto::VarType::Type>(0)) {
50+
fuse_space_type = dtype;
51+
PADDLE_ENFORCE_NE(dtype,
52+
static_cast<framework::proto::VarType::Type>(0));
53+
}
54+
PADDLE_ENFORCE_EQ(dtype, fuse_space_type);
55+
auto size = tensor.numel();
56+
PADDLE_ENFORCE_GT(size, 0);
57+
mem_size += size;
58+
}
59+
auto out_var_names = Outputs("Input");
60+
61+
PADDLE_ENFORCE_EQ(in_var_names.size(), out_var_names.size());
62+
auto out_tensor =
63+
scope.FindVar(out_var_names[0])->GetMutable<framework::LoDTensor>();
64+
auto origin_dim = out_tensor->dims();
65+
auto offset = framework::product(origin_dim);
66+
67+
out_tensor->Resize(framework::make_ddim({static_cast<int64_t>(mem_size)}));
68+
out_tensor->mutable_data(dev_place, fuse_space_type);
69+
70+
for (size_t i = 1; i < out_var_names.size(); ++i) {
71+
PADDLE_ENFORCE_EQ(in_var_names[i], out_var_names[i]);
72+
auto out_t =
73+
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
74+
auto &origin_dim = out_t->dims();
75+
int64_t len = out_t->numel();
76+
out_t->ShareDataWith(out_tensor->Slice(offset, offset + len));
77+
offset += len;
78+
out_t->Resize(origin_dim);
79+
}
80+
out_tensor->Resize(origin_dim);
81+
}
82+
};
83+
84+
// Proto maker for alloc_space_for_vars: a duplicable set of input
// variables and the same set as outputs. The original AddComment was
// empty; operators should carry a meaningful description.
class AllocSpaceForVarsOpOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "A set of variables.").AsDuplicable();
    AddOutput("Output", "A set of variables.").AsDuplicable();
    AddComment(R"DOC(
AllocSpaceForVars operator.

Fuses the underlying storage of all input variables into one contiguous
memory block. The first output tensor holds the whole block; every other
output shares a slice of it, keeping its original shape. All inputs must
be initialized LoDTensors with the same data type.
)DOC");
  }
};
93+
94+
} // namespace operators
95+
} // namespace paddle
96+
97+
// Register alloc_space_for_vars without a gradient op: it is a memory
// management operator, not a differentiable computation.
REGISTER_OPERATOR(alloc_space_for_vars,
                  paddle::operators::AllocSpaceForVarsOpOp,
                  paddle::operators::AllocSpaceForVarsOpOpMaker);

0 commit comments

Comments
 (0)