Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions paddle/operators/multiplex_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,37 @@ namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

class MultiplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
"Input(Ids) shouldn't be null.");
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
"MultiInput(X) shouldn't be empty.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) shouldn't be null.");
auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
PADDLE_ENFORCE(
ids_dim.size() == 2 && ids_dim[1] == 1,
"The index tensor must be a vector with size batchSize x 1.");

auto ins = ctx.MultiInput<Tensor>("X");
auto *out = ctx.Output<LoDTensor>("Out");
auto *out = ctx.Output<Tensor>("Out");
auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 2,
"multiplex operator should have more than 2 inputs.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
"The first input must be a index vector.");
auto in_dim = ins[1]->dims();
PADDLE_ENFORCE(num_ins > 1,
"multiplex operator should have more than "
"one candidate input tensors.");

for (size_t i = 2; i < num_ins; i++) {
auto in_dim = ins[0]->dims();
PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Candidate tensors不是必需2维吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the limitation

for (size_t i = 1; i < num_ins; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(
in_dim == dim,
"All the input tensors except the first one must have the same size");
PADDLE_ENFORCE(in_dim == dim,
"All the candidate tensors must have the same size.");
}
out->Resize(in_dim);
}
Expand All @@ -54,25 +59,26 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
MultiplexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.")
.AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator

Multiplex multiple tensors according to the index provided by the index
tensor.

ins[0]: the index tensor.
ins[1:N]: the candidate output tensors.
Ids: the index tensor.
X[0 : N - 1]: the candidate tensors for output (N >= 2).
For each index i from 0 to batchSize - 1, the output is the i-th row of
the (index[i] + 1)-th tensor.
the (Ids[i])-th tensor.

For i-th row of the output tensor:

y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1)

where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = x{0}[i] + 1`.

and `k = Ids[i]`.
)DOC");
}
};
Expand All @@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
"Input(X) should not be null.");
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null");
"Output(X@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<Tensor>("X");
// don't compute gradient for index (ins[0])
for (size_t i = 1; i < ins.size(); i++) {
// No need to compute gradient for Input(Ids)
for (size_t i = 0; i < ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->Resize(ins[i]->dims());
}
Expand Down
43 changes: 23 additions & 20 deletions paddle/operators/multiplex_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,30 @@
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");

auto ins = ctx.MultiInput<Tensor>("X");
auto* ids = ctx.Input<Tensor>("Ids");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
// copy index to cpu
framework::Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>();
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
auto* index = index_t_cpu.data<int32_t>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
PADDLE_ENFORCE_LT(k, ins.size(),
"index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place,
Expand All @@ -51,31 +54,31 @@ template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (size_t i = 1; i < d_ins.size(); i++) {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<Tensor>("X");
auto* ids = ctx.Input<Tensor>("Ids");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
}

auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
// copy index to cpu
framework::Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>();
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
auto* index = index_t_cpu.data<int32_t>();

auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
size_t k = static_cast<size_t>(index[i]);
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T), stream);
Expand Down
23 changes: 13 additions & 10 deletions paddle/operators/multiplex_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto ids = ctx.Input<framework::Tensor>("Ids");
auto* out = ctx.Output<framework::Tensor>("Out");

out->mutable_data<T>(ctx.GetPlace());

auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto* index = ins[0]->data<T>();
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
"index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place,
Expand All @@ -50,23 +52,24 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (size_t i = 1; i < d_ins.size(); i++) {
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
}

auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto* index = ins[0]->data<T>();
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto* index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
size_t k = static_cast<size_t>(index[i]);
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T));
Expand Down
12 changes: 7 additions & 5 deletions python/paddle/v2/framework/tests/test_multiplex_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
class TestMultiplexOp(OpTest):
def setUp(self):
self.op_type = "multiplex"
rows = 3
index = np.array([3, 1, 0])
rows = 4
index = np.arange(0, rows).astype('int32')
np.random.shuffle(index)
index = np.reshape(index, (rows, 1))
ins1 = np.random.random((rows, 10)).astype("float32")
ins2 = np.random.random((rows, 10)).astype("float32")
ins3 = np.random.random((rows, 10)).astype("float32")
ins4 = np.random.random((rows, 10)).astype("float32")
self.inputs = {
'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
('x4', ins4)]
'Ids': index,
'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)]
}
# multiplex output
output = np.zeros_like(ins1)
for i in range(0, rows):
k = index[i] + 1
k = index[i][0]
output[i] = self.inputs['X'][k][1][i]
self.outputs = {'Out': output}

Expand Down