/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sequence_concat_op.h"

namespace paddle {
namespace operators {

class SequenceConcatOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInputs("X"),
                   "Inputs(X) of SequenceConcatOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SequenceConcatOp should not be null.");
    const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
    const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
    PADDLE_ENFORCE(level == 0UL || level == 1UL,
                   "The sequence_concat operator only accepts a sequence "
                   "or a nested sequence as its input.");
    // The output keeps the dimensions of the first input, except along the
    // concatenation axis, where the extents of all inputs are accumulated.
    auto ins_dims = ctx->GetInputsDim("X");
    framework::DDim out_dims = ins_dims[0];
    const size_t n = ins_dims.size();
    for (size_t i = 1; i < n; ++i) {
      out_dims[axis] += ins_dims[i][axis];
    }
    ctx->SetOutputDim("Out", out_dims);
  }
};

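// A minimal sketch of the shape arithmetic in InferShape above, kept in
// comment form so it plays no part in compilation. The numbers are taken
// from Case 1 of the operator documentation below (axis = 1):
//
//   Dims(x0) = (4, 3, 4), Dims(x1) = (4, 4, 4)
//   out_dims        = Dims(x0)       -> (4, 3, 4)
//   out_dims[1]    += Dims(x1)[1]    -> (4, 3 + 4, 4) = (4, 7, 4)
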
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SequenceConcatOpMaker(framework::OpProto* proto,
                        framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(vector<LoDTensor>) The input is a vector of LoDTensor, "
             "each of which is a variable-length sequence or nested sequence.")
        .AsDuplicable();
    AddOutput("Out",
              "(LoDTensor) The variable-length output of the "
              "sequence_concat operator.");
    AddAttr<int>("axis",
                 "(int, default 0) "
                 "The axis along which the inputs will be joined. "
                 "If the axis is 0, the inputs will be joined according to "
                 "the LoD index.")
        .SetDefault(0);
    AddAttr<int>("level",
                 "(int, default 0) "
                 "The level at which the inputs will be joined. "
                 "If the level is 0, the inputs will be joined at the nested "
                 "sequence level. "
                 "If the level is 1, the inputs will be joined at the "
                 "sequence level. "
                 "The level should be smaller than the number of LoD levels "
                 "of the inputs.")
        .SetDefault(0);
    AddComment(R"DOC(
    The sequence_concat operator concatenates multiple LoDTensors.
    It only supports a sequence (LoDTensor with LoD level 1) or a nested
    sequence (LoDTensor with LoD level 2) as its input.
    - Case1:
      If the axis is not 0 (here, the axis is 1 and the level is 1), all
      inputs must share the same LoD information, and the LoD information
      of the output is the same as that of the inputs.

      LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
      LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4)
      LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)

    - Case2:
      If the axis is 0 (here, the level is 0), the inputs are concatenated
      along the time steps: the i-th sequence of the output is the
      concatenation of the i-th sequences of the inputs, so the LoD
      information of the output has to be re-computed.

      LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
      LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4)
      LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4)

    - Case3:
      If the axis is 0 (here, the level is 1), the inputs are joined at the
      sequence level and the LoD of the output is re-computed as well.

      LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
      LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
      LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)

    NOTE: The LoD levels of all the inputs must be the same.
    )DOC");
  }
};

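// How the Case 2 LoD above is re-computed (a worked reading of the example
// numbers, not something stated explicitly in the code):
//
//   Top level:    x0 holds sequences of 2 and 2 time steps, x1 of 3 and 2;
//                 pairwise concatenation gives 2+3=5 and 2+2=4 -> {0,5,9}.
//   Second level: per top-level sequence, the sub-sequence lengths of x0
//                 (1,1 | 1,1) are followed by those of x1 (1,1,1 | 2),
//                 giving 1,1,1,1,1 | 1,1,2 -> {0,1,2,3,4,5,6,7,9}.
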
class SequenceConcatGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "The gradient of Out should not be null.");
    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
                   "The gradients of X should not be null.");
    // Each input gradient has exactly the shape of the corresponding input.
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }
};

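// A note on the backward pass (an informal sketch, not part of the code):
// concatenation only copies data, so its gradient routes each slice of
// Grad(Out) back to the input it came from. This is why InferShape above can
// assign every Grad(X[i]) the shape of X[i] without any further arithmetic.
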
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker,
            sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_CPU_KERNEL(
    sequence_concat,
    ops::SequenceConcatOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    sequence_concat_grad,
    ops::SequenceConcatGradOpKernel<paddle::platform::CPUPlace, float>);
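
// Registration notes (a brief sketch, not a normative description of the
// framework internals): REGISTER_OP ties the forward operator, its proto
// maker, and its gradient operator together under the names
// "sequence_concat" and "sequence_concat_grad". REGISTER_OP_CPU_KERNEL
// registers only float CPU kernels here; a GPU kernel would conventionally
// be registered in a matching sequence_concat_op.cu file.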