Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 11 additions & 29 deletions paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
Expand All @@ -42,36 +45,15 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
// we only support sum now
PADDLE_ENFORCE_EQ(combiner, "sum");

int64_t last_dim = table_dims[1];
for (int i = 1; i != ids_dims.size(); ++i) {
last_dim *= ids_dims[i];
}

if (ctx->IsRuntime()) {
framework::Variable* ids_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
// in compile time, the lod level of ids must be 1
framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);

// in run time, the LoD of ids must be 1
PADDLE_ENFORCE(ids_lod.size(), 1u,
"The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");

int64_t batch_size = ids_lod[0].size() - 1;

// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
} else {
// in compile time, the lod level of ids must be 1
framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);

// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
}
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
}

protected:
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ struct EmbeddingVSumFunctor {
}
};

inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims,
const framework::DDim &ids_dims) {
int64_t last_dim = table_dims[1];
for (int i = 1; i != ids_dims.size(); ++i) {
last_dim *= ids_dims[i];
}
return last_dim;
}

template <typename T>
class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
public:
Expand All @@ -70,6 +79,17 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const LoDTensor *table_var = context.Input<LoDTensor>("W");
const std::string &combiner_type = context.Attr<std::string>("combiner");

int64_t last_dim =
FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
const auto &ids_lod = ids_t->lod();
// in run time, the LoD of ids must be 1
PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
output_t->Resize({batch_size, last_dim});

if (combiner_type == "sum") {
EmbeddingVSumFunctor<T> functor;
functor(context, table_var, ids_t, output_t);
Expand Down
15 changes: 5 additions & 10 deletions paddle/fluid/operators/hash_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ limitations under the License. */

#include "paddle/fluid/operators/hash_op.h"
#include <string>
#include <vector>

namespace paddle {
namespace operators {
Expand All @@ -27,6 +26,9 @@ class HashOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of HashOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
Expand All @@ -36,15 +38,8 @@ class HashOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dims.size(), 2UL,
"The input of hash_op's dimensions must be 2");
std::vector<int64_t> out_dims;
out_dims.reserve(dims.size() + 1);
// copy all dims except the last one
for (int i = 0u; i != dims.size() - 1; ++i) {
out_dims.emplace_back(dims[i]);
}
int num_hash = ctx->Attrs().Get<int>("num_hash");
out_dims.emplace_back(num_hash);
// keep the last dim to 1
out_dims.emplace_back(1);
HashOutputSize(dims, out_dims, num_hash);

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
ctx->ShareLoD("X", /*->*/ "Out");
Expand All @@ -71,4 +66,4 @@ class HashOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker);
REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel<int>, ops::HashKerel<int64_t>);
REGISTER_OP_CPU_KERNEL(hash, ops::HashKernel<int>, ops::HashKernel<int64_t>);
25 changes: 22 additions & 3 deletions paddle/fluid/operators/hash_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,46 @@ limitations under the License. */
extern "C" {
#include <xxhash.h>
}
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
// template <typename DeviceContext, typename T>

inline void HashOutputSize(const framework::DDim& in_dims,
std::vector<int64_t>& out_dims, // NOLINT
int num_hash) {
out_dims.reserve(in_dims.size() + 1);
// copy all dims except the last one
for (int i = 0u; i != in_dims.size() - 1; ++i) {
out_dims.emplace_back(in_dims[i]);
}
out_dims.emplace_back(num_hash);
// keep the last dim to 1
out_dims.emplace_back(1);
}

template <typename T>
class HashKerel : public framework::OpKernel<T> {
class HashKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out_t = context.Output<framework::LoDTensor>("Out");
auto* in_t = context.Input<framework::LoDTensor>("X");
int mod_by = context.Attr<int>("mod_by");
int num_hash = context.Attr<int>("num_hash");
auto* output = out_t->mutable_data<T>(context.GetPlace());

auto in_dims = in_t->dims();
auto in_lod = in_t->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");

std::vector<int64_t> out_dims;
HashOutputSize(in_dims, out_dims, num_hash);
out_t->Resize(framework::make_ddim(out_dims));
auto* output = out_t->mutable_data<T>(context.GetPlace());

auto seq_length = in_dims[0];
auto last_dim = in_dims[in_dims.size() - 1];
auto* input = in_t->data<T>();
Expand All @@ -49,6 +67,7 @@ class HashKerel : public framework::OpKernel<T> {
}
input += last_dim;
}
out_t->set_lod(in_t->lod());
}
};

Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
Expand All @@ -33,9 +36,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_dims.size(), 2,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1.");

const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size});
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
// Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
out->set_lod(in->lod());
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
// Generate enumerate sequence set
auto lod0 = in_lod[0];
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
Expand All @@ -49,6 +50,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
}
}
}
out->set_lod(in->lod());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lod如果这个op没用,感觉可以尽量早点set比较好

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,会在下一个PR修改。

}
};

Expand Down