Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions paddle/framework/lod_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,15 @@ bool operator==(const LoD& a, const LoD& b);
* LoDTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LoDTensor {
class LoDTensor : public Tensor {
public:
LoDTensor() {}
LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {}

void set_lod(const LoD& lod) { lod_ = lod; }

void set_tensor(Tensor* tensor) { tensor_ = tensor; }
explicit LoDTensor(const LoD& lod) : lod_(lod) {}

Tensor& tensor() { return *tensor_; }
void set_lod(const LoD& lod) { lod_ = lod; }

LoD lod() { return lod_; }
LoD lod() const { return lod_; }

/*
* Get a element from LoD.
Expand Down Expand Up @@ -104,7 +101,6 @@ class LoDTensor {

private:
LoD lod_;
Tensor* tensor_; // not owned
};
} // namespace framework
} // namespace paddle
45 changes: 20 additions & 25 deletions paddle/framework/lod_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,69 +36,64 @@ class LoDTensorTester : public ::testing::Test {

ASSERT_EQ(lod.size(), 3UL);

tensor.Resize({20 /*batch size*/, 128 /*dim*/});
lod_tensor_.Resize({20 /*batch size*/, 128 /*dim*/});
// malloc memory
tensor.mutable_data<float>(place);
lod_tensor_.mutable_data<float>(place);

lod_tensor.set_lod(lod);
lod_tensor.set_tensor(&tensor);
lod_tensor_.set_lod(lod);
}

protected:
platform::CPUPlace place;
Tensor tensor;
LoDTensor lod_tensor;
LoDTensor lod_tensor_;
};

TEST_F(LoDTensorTester, NumLevels) { ASSERT_EQ(lod_tensor.NumLevels(), 3UL); }
TEST_F(LoDTensorTester, NumLevels) { ASSERT_EQ(lod_tensor_.NumLevels(), 3UL); }

TEST_F(LoDTensorTester, NumElements) {
ASSERT_EQ(lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(lod_tensor_.NumElements(0), 2UL);
ASSERT_EQ(lod_tensor_.NumElements(1), 4UL);
ASSERT_EQ(lod_tensor_.NumElements(2), 8UL);
}

TEST_F(LoDTensorTester, SliceLevels) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
LoDTensor new_lod_tensor = lod_tensor;
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.SliceLevels(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level));
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
LoDTensor new_lod_tensor = lod_tensor;
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.SliceLevels(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor_.NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
}
}

TEST_F(LoDTensorTester, SliceInLevel) {
size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor;
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.SliceInLevel(level, 0, 2);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());

level = 1;
new_lod_tensor = lod_tensor;
new_lod_tensor = lod_tensor_;
new_lod_tensor.SliceInLevel(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
}

} // namespace framework
Expand Down
6 changes: 2 additions & 4 deletions paddle/framework/lod_tensor_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,16 @@ __global__ void test(size_t* a, int size) {
}

TEST(LoDTensor, LoDInGPU) {
paddle::framework::Tensor tensor;
paddle::framework::LoDTensor lod_tensor;
paddle::platform::GPUPlace place(0);

paddle::framework::LoD src_lod;
src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});

tensor.Resize({14, 16});
tensor.mutable_data<float>(place);
lod_tensor.Resize({14, 16});
lod_tensor.mutable_data<float>(place);

lod_tensor.set_lod(src_lod);
lod_tensor.set_tensor(&tensor);
CHECK_EQ(lod_tensor.lod_element(0, 2), 4);
CHECK_EQ(lod_tensor.lod_element(0, 4), 8);

Expand Down
42 changes: 42 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,48 @@ void OperatorBase::GenerateTemporaryNames() {
}
}

template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
}

template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specialized template function InferShapeContext::Input and InferShapeContext::MultiInput, thus, the InferShape of non-sequence operators can call Input<Tensor> to avoid to use LoDTensor.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this specialization is great to hide the difference between LoDTensor and Tensor. I can also accept that we don't have this and expose LoDTensor to non-sequence operator contributors, because I feel that LoDTensor is a key difference from TensorFlow and could be a PaddlePaddle flag.

auto names = op().Inputs(name);
std::vector<const Tensor*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
});
return res;
}

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto* var = OutputVar(name);
return var == nullptr ? nullptr : const_cast<Tensor*>(GetTensorFromVar(var));
}

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const {
auto names = op().Outputs(name);
std::vector<Tensor*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope().FindVar(sub_name);
return var == nullptr
? nullptr
: const_cast<Tensor*>(GetTensorFromVar(var));
});
return res;
}

void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
Expand Down
45 changes: 45 additions & 0 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
Expand Down Expand Up @@ -326,11 +327,27 @@ class InferShapeContext {
return res;
}

const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}

private:
const OperatorBase& op_;
const Scope& scope_;
};

template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const;

template <typename T>
struct EigenDeviceConverter;

Expand Down Expand Up @@ -363,9 +380,37 @@ class ExecutionContext : public InferShapeContext {
return device_context_;
}

// redefine Output function,
// use Variable::Get instead of Variable::GetMutable
template <typename T>
T* Output(const std::string& name) const {
auto var = OutputVar(name);
return var == nullptr ? nullptr : const_cast<T*>(&var->Get<T>());
}

// redefine MultiOutput function.
// use Variable::Get instead of Variable::GetMutable
template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const {
auto names = op().Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<T>(sub_name); });
return res;
}

const platform::DeviceContext* device_context_;
};

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;

class OpKernel {
public:
/**
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
"inference size must be the same as label size");

ctx.Output<Tensor>("Accuracy")->Resize({1});
ctx.Output<framework::LoDTensor>("Accuracy")->Resize({1});
}
};

Expand Down
3 changes: 2 additions & 1 deletion paddle/operators/add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class AddOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
"Two input of Add Op's dimension must be same.");
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>("Out")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ConcatOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *out = ctx.Output<framework::LoDTensor>("Out");
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t n = ins.size();

Expand Down
12 changes: 7 additions & 5 deletions paddle/operators/cos_sim_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class CosSimOp : public framework::OperatorWithKernel {
" just 1 (which will be broadcasted to match Input(X)).");

// resize tensor
ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
ctx.Output<Tensor>("XNorm")->Resize({x_dims[0], 1});
ctx.Output<Tensor>("YNorm")->Resize({y_dims[0], 1});
ctx.Output<framework::LoDTensor>("Out")->Resize({x_dims[0], 1});
ctx.Output<framework::LoDTensor>("XNorm")->Resize({x_dims[0], 1});
ctx.Output<framework::LoDTensor>("YNorm")->Resize({y_dims[0], 1});
}
};

Expand Down Expand Up @@ -131,8 +131,10 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");

// resize tensor
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
}
Expand Down
8 changes: 5 additions & 3 deletions paddle/operators/elementwise_mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
ctx.Output<Tensor>("Out")->Resize(x_dim);
ctx.Output<framework::LoDTensor>("Out")->Resize(x_dim);
}
};

Expand Down Expand Up @@ -80,8 +80,10 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));

PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/fill_zeros_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Dst")->Resize(
ctx.Output<framework::LoDTensor>("Dst")->Resize(
ctx.Input<framework::Tensor>("Src")->dims());
}
};
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class GatherOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>("Out")->Resize(output_dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(output_dims);
}
};

Expand All @@ -38,7 +38,7 @@ class GatherGradOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");

X_grad->Resize(X->dims());
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
auto* tensor = context.Output<framework::LoDTensor>("Out");
auto dims = Attr<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
Expand Down
5 changes: 3 additions & 2 deletions paddle/operators/lookup_table_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &context) const override {
auto table_t = context.Input<Tensor>("W");
auto ids_t = context.Input<Tensor>("Ids");
auto output_t = context.Output<Tensor>("Out");
auto output_t = context.Output<framework::LoDTensor>("Out");

output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
}
Expand Down Expand Up @@ -56,7 +56,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &context) const override {
auto table = context.Input<Tensor>("W");
auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
auto d_table =
context.Output<framework::LoDTensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
}
};
Expand Down
Loading