Skip to content

Commit cb28428

Browse files
committed
Replace LoDTensor in elementwise_mul_op, pad_op and recurrent_op_utils.
1 parent 30a58b5 commit cb28428

File tree

5 files changed

+51
-39
lines changed

5 files changed

+51
-39
lines changed

paddle/framework/operator.cc

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,7 @@ void OperatorBase::GenerateTemporaryNames() {
189189
template <>
190190
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
191191
auto* var = InputVar(name);
192-
if (var == nullptr) return nullptr;
193-
if (var->IsType<LoDTensor>()) {
194-
return &var->Get<LoDTensor>();
195-
}
196-
PADDLE_ENFORCE(var->IsType<Tensor>(),
197-
"The Input(%s) must be LoDTensor or Tensor.");
198-
return &var->Get<Tensor>();
192+
return var == nullptr ? nullptr : GetTensorFromVar(var);
199193
}
200194

201195
template <>
@@ -204,22 +198,19 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
204198
auto names = op().Inputs(name);
205199
std::vector<const Tensor*> res;
206200
res.reserve(names.size());
207-
std::transform(
208-
names.begin(), names.end(), std::back_inserter(res),
209-
[&](const std::string& sub_name) { return Input<Tensor>(sub_name); });
201+
std::transform(names.begin(), names.end(), std::back_inserter(res),
202+
[&](const std::string& sub_name) {
203+
auto var = scope_.FindVar(sub_name);
204+
return var == nullptr ? nullptr : GetTensorFromVar(var);
205+
});
210206
return res;
211207
}
212208

213209
template <>
214210
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
215211
auto* var = OutputVar(name);
216212
if (var == nullptr) return nullptr;
217-
if (var->IsType<LoDTensor>()) {
218-
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
219-
}
220-
PADDLE_ENFORCE(var->IsType<Tensor>(),
221-
"The Input(%s) must be LoDTensor or Tensor.");
222-
return const_cast<Tensor*>(&var->Get<Tensor>());
213+
return GetTensorFromVar(var);
223214
}
224215

225216
template <>
@@ -228,9 +219,11 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
228219
auto names = op().Outputs(name);
229220
std::vector<Tensor*> res;
230221
res.reserve(names.size());
231-
std::transform(
232-
names.begin(), names.end(), std::back_inserter(res),
233-
[&](const std::string& sub_name) { return Output<Tensor>(sub_name); });
222+
std::transform(names.begin(), names.end(), std::back_inserter(res),
223+
[&](const std::string& sub_name) {
224+
auto var = scope().FindVar(sub_name);
225+
return var == nullptr ? nullptr : GetTensorFromVar(var);
226+
});
234227
return res;
235228
}
236229

paddle/framework/operator.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,11 @@ class InferShapeContext {
306306
auto names = op_.Inputs(name);
307307
std::vector<const T*> res;
308308
res.reserve(names.size());
309-
std::transform(
310-
names.begin(), names.end(), std::back_inserter(res),
311-
[&](const std::string& sub_name) { return Input<T>(sub_name); });
309+
std::transform(names.begin(), names.end(), std::back_inserter(res),
310+
[&](const std::string& sub_name) {
311+
auto var = scope_.FindVar(sub_name);
312+
return var == nullptr ? nullptr : &var->Get<T>();
313+
});
312314
return res;
313315
}
314316

@@ -317,12 +319,23 @@ class InferShapeContext {
317319
auto names = op_.Outputs(name);
318320
std::vector<T*> res;
319321
res.reserve(names.size());
320-
std::transform(
321-
names.begin(), names.end(), std::back_inserter(res),
322-
[&](const std::string& sub_name) { return Output<T>(sub_name); });
322+
std::transform(names.begin(), names.end(), std::back_inserter(res),
323+
[&](const std::string& sub_name) {
324+
auto var = scope_.FindVar(sub_name);
325+
return var == nullptr ? nullptr : var->GetMutable<T>();
326+
});
323327
return res;
324328
}
325329

330+
Tensor* GetTensorFromVar(const Variable* var) const {
331+
if (var->IsType<LoDTensor>()) {
332+
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
333+
}
334+
PADDLE_ENFORCE(var->IsType<Tensor>(),
335+
"The Input(%s) must be LoDTensor or Tensor.");
336+
return const_cast<Tensor*>(&var->Get<Tensor>());
337+
}
338+
326339
private:
327340
const OperatorBase& op_;
328341
const Scope& scope_;

paddle/operators/elementwise_mul_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
3131
auto y_dim = ctx.Input<Tensor>("Y")->dims();
3232
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
3333
"Rank of first input must >= rank of second input.")
34-
ctx.Output<Tensor>("Out")->Resize(x_dim);
34+
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
3535
}
3636
};
3737

@@ -80,8 +80,8 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
8080
auto x_dims = ctx.Input<Tensor>("X")->dims();
8181
auto y_dims = ctx.Input<Tensor>("Y")->dims();
8282
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
83-
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
84-
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
83+
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
84+
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
8585

8686
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
8787
"Rank of first input must >= rank of second input.")

paddle/operators/pad_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class PadOp : public framework::OperatorWithKernel {
3434
for (int i = 0; i < x_dim.size(); ++i) {
3535
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
3636
}
37-
ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(out_dims));
37+
ctx.Output<framework::LoDTensor>("Out")->Resize(
38+
framework::make_ddim(out_dims));
3839
}
3940
};
4041

@@ -95,9 +96,9 @@ class PadOpGrad : public framework::OperatorWithKernel {
9596
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
9697
"Input(Out@GRAD) should not be null");
9798
auto x_dims = ctx.Input<Tensor>("X")->dims();
98-
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
99-
if (x_grad != nullptr) {
100-
x_grad->Resize(x_dims);
99+
auto *x_g = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
100+
if (x_g != nullptr) {
101+
x_g->Resize(x_dims);
101102
}
102103
}
103104
};

paddle/operators/rnn/recurrent_op_utils.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace rnn {
2121
namespace f = paddle::framework;
2222

2323
using Tensor = framework::Tensor;
24+
using LoDTensor = framework::LoDTensor;
2425

2526
void SegmentInputs(const std::vector<Scope*>& step_scopes,
2627
const std::vector<Link>& inlinks, const size_t seq_len,
@@ -31,7 +32,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
3132
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
3233
inlinks[i].external);
3334

34-
Tensor* input = input_var->GetMutable<Tensor>();
35+
LoDTensor* input = input_var->GetMutable<LoDTensor>();
3536
f::DDim dims = input->dims();
3637
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
3738
"all the inlinks must have same length");
@@ -40,6 +41,8 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
4041
Tensor* step_input =
4142
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
4243
if (!infer_shape_mode) {
44+
// The input of operators of each step is Tensor here.
45+
// Maybe need to modify Slice function.
4346
*step_input = input->Slice<float>(j, j + 1);
4447
}
4548
step_input->Resize(step_dims);
@@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
5457
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
5558
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
5659
outlinks[i].external);
57-
Tensor* output = output_var->GetMutable<Tensor>();
60+
LoDTensor* output = output_var->GetMutable<LoDTensor>();
5861

5962
if (infer_shape_mode) {
6063
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
6164
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
6265
outlinks[i].internal);
63-
f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
66+
f::DDim step_dims =
67+
step_scope_var->template GetMutable<LoDTensor>()->dims();
6468
std::vector<int64_t> dims_vec = vectorize(step_dims);
6569
dims_vec.insert(dims_vec.begin(), seq_len);
6670
output->Resize(f::make_ddim(dims_vec));
6771
} else {
6872
output->mutable_data<float>(platform::CPUPlace());
6973
for (size_t j = 0; j < seq_len; j++) {
70-
Tensor* step_output =
71-
step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
74+
LoDTensor* step_output = step_scopes[j]
75+
->FindVar(outlinks[i].internal)
76+
->GetMutable<LoDTensor>();
7277
// TODO(luotao02) data type and platform::DeviceContext() should set
7378
// correctly
7479
(output->Slice<float>(j, j + 1))
@@ -94,8 +99,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
9499
auto scope = scopes[step_id];
95100
auto linked_scope = scopes[step_id + offset];
96101
for (auto& attr : memories) {
97-
auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
98-
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
102+
auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
103+
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
99104
if (infer_shape_mode) {
100105
mem->Resize(linked_mem->dims());
101106
} else {

0 commit comments

Comments
 (0)