-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add axis for mul_op and rowwise_add_op
#3888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
e76fa85
86655cb
af0264a
69fbc54
d71396b
e168fc4
256d6a3
f2a66ff
823bdd6
3d62c6d
0c13660
5aacd64
d7c8bdc
b744430
1d9a4d2
f6e72c9
b6a4666
856611c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,6 +51,18 @@ class LargerThanChecker { | |
| T lower_bound_; | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class EqualLargerThanChecker { | ||
| public: | ||
| explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} | ||
| void operator()(T& value) const { | ||
| PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); | ||
|
||
| } | ||
|
|
||
| private: | ||
| T lower_bound_; | ||
| }; | ||
|
|
||
| // we can provide users more common Checker, like 'LessThanChecker', | ||
| // 'BetweenChecker'... | ||
|
|
||
|
|
@@ -114,6 +126,11 @@ class TypedAttrChecker { | |
| return *this; | ||
| } | ||
|
|
||
| TypedAttrChecker& EqualLargerThan(const T& lower_bound) { | ||
| value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound)); | ||
| return *this; | ||
| } | ||
|
|
||
| // we can add more common limits, like LessThan(), Between()... | ||
|
|
||
| TypedAttrChecker& SetDefault(const T& default_value) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -195,18 +195,6 @@ std::vector<int> vectorize(const DDim& ddim) { | |
| return result; | ||
| } | ||
|
|
||
| struct ProductVisitor : public boost::static_visitor<ssize_t> { | ||
| template <int D> | ||
| ssize_t operator()(const Dim<D>& dim) { | ||
| return product(dim); | ||
| } | ||
| }; | ||
|
|
||
| ssize_t product(const DDim& ddim) { | ||
| ProductVisitor visitor; | ||
| return boost::apply_visitor(visitor, ddim); | ||
| } | ||
|
|
||
| struct SliceVectorizeVisitor : public boost::static_visitor<> { | ||
| std::vector<int>& vector; | ||
| int begin; | ||
|
|
@@ -247,6 +235,18 @@ DDim slice_ddim(const DDim& dim, int begin, int end) { | |
| return make_ddim(vec); | ||
| } | ||
|
|
||
| struct ProductVisitor : public boost::static_visitor<ssize_t> { | ||
| template <int D> | ||
| ssize_t operator()(const Dim<D>& dim) { | ||
| return product(dim); | ||
| } | ||
| }; | ||
|
|
||
| ssize_t product(const DDim& ddim) { | ||
| ProductVisitor visitor; | ||
| return boost::apply_visitor(visitor, ddim); | ||
| } | ||
|
|
||
| /// \cond HIDDEN | ||
|
|
||
| struct ArityVisitor : boost::static_visitor<int> { | ||
|
|
@@ -283,5 +283,17 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { | |
| DDim::DDim(std::initializer_list<int> init_list) { | ||
| *this = make_ddim(init_list); | ||
| } | ||
|
|
||
| DDim flatten_to_2d(const DDim& src, int num_row_dims) { | ||
| int rank = src.size(); | ||
| return make_ddim( | ||
| {static_cast<int>(product(slice_ddim(src, 0, rank - num_row_dims))), | ||
| static_cast<int>(product(slice_ddim(src, rank - num_row_dims, rank)))}); | ||
| } | ||
|
|
||
| DDim flatten_to_1d(const DDim& src) { | ||
| return make_ddim({static_cast<int>(product(src))}); | ||
|
||
| } | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,20 +63,35 @@ struct EigenTensor { | |
|
|
||
| template <typename T, int MajorType = Eigen::RowMajor, | ||
| typename IndexType = Eigen::DenseIndex> | ||
| struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {}; | ||
| struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { | ||
| static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { | ||
| int rank = tensor.dims_.size(); | ||
| PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, | ||
| "`num_row_dims` must be between (0, rank_of_tensor)."); | ||
| return EigenMatrix::From(tensor, | ||
| flatten_to_2d(tensor.dims(), num_row_dims)); | ||
| } | ||
|
|
||
| static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, | ||
| int num_row_dims) { | ||
| int rank = tensor.dims_.size(); | ||
| PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, | ||
| "`num_row_dims` must be between (0, rank_of_tensor)."); | ||
| return EigenMatrix::From(tensor, | ||
| flatten_to_2d(tensor.dims(), num_row_dims)); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T, int MajorType = Eigen::RowMajor, | ||
| typename IndexType = Eigen::DenseIndex> | ||
| struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> { | ||
| // Flatten reshapes a Tensor into an EigenVector. | ||
| static typename EigenVector::Type Flatten(Tensor& tensor) { | ||
| return EigenVector::From( | ||
| tensor, make_ddim({static_cast<int>(product(tensor.dims_))})); | ||
| return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))}); | ||
|
||
| } | ||
|
|
||
| static typename EigenVector::ConstType Flatten(const Tensor& tensor) { | ||
| return EigenVector::From( | ||
| tensor, make_ddim({static_cast<int>(product(tensor.dims_))})); | ||
| return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))}); | ||
|
||
| } | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,5 +108,25 @@ TEST(Eigen, Matrix) { | |
| } | ||
| } | ||
|
|
||
| TEST(Eigen, MatrixReshape) { | ||
| Tensor t; | ||
| float* p = | ||
| t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace()); | ||
|
||
| for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { | ||
| p[i] = static_cast<float>(i); | ||
| } | ||
|
|
||
| EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2); | ||
|
|
||
| ASSERT_EQ(2 * 3, em.dimension(0)); | ||
| ASSERT_EQ(6 * 4, em.dimension(1)); | ||
|
|
||
| for (int i = 0; i < 2 * 3; i++) { | ||
| for (int j = 0; j < 6 * 4; j++) { | ||
| ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) { | |
|
|
||
| inline const DDim& Tensor::dims() const { return dims_; } | ||
|
|
||
| template <typename T> | ||
| inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { | ||
|
||
| Tensor res; | ||
| res.ShareDataWith<T>(src); | ||
| res.Resize(flatten_to_2d(src.dims(), num_row_dims)); | ||
| return res; | ||
| } | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) { | |
| } | ||
| #endif | ||
| } | ||
|
|
||
| TEST(Tensor, FlattenToMatrix) { | ||
| using namespace paddle::framework; | ||
| using namespace paddle::platform; | ||
| Tensor src; | ||
| int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace()); | ||
|
||
| for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { | ||
| src_ptr[i] = i; | ||
| } | ||
| Tensor res = FlattenToMatrix<int>(src, 2); | ||
| ASSERT_EQ(res.dims()[0], 2 * 3); | ||
| ASSERT_EQ(res.dims()[1], 4 * 9); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel { | |
|
|
||
| protected: | ||
| void InferShape(const framework::InferShapeContext &ctx) const override { | ||
| auto dim0 = ctx.Input<Tensor>("X")->dims(); | ||
| auto dim1 = ctx.Input<Tensor>("Y")->dims(); | ||
| PADDLE_ENFORCE_EQ(dim0.size(), 2, | ||
| "input X(%s) should be a tensor with 2 dims, a matrix", | ||
| ctx.op().Input("X")); | ||
| PADDLE_ENFORCE_EQ(dim1.size(), 2, | ||
| "input Y(%s) should be a tensor with 2 dims, a matrix", | ||
| ctx.op().Input("Y")); | ||
| auto x_dims = ctx.Input<Tensor>("X")->dims(); | ||
| auto y_dims = ctx.Input<Tensor>("Y")->dims(); | ||
| int x_num_row_dims = GetAttr<int>("x_num_row_dims"); | ||
| int y_num_row_dims = GetAttr<int>("y_num_row_dims"); | ||
|
|
||
| PADDLE_ENFORCE(x_dims.size() > x_num_row_dims, | ||
| "The rank of input tensor X(%s) should be larger than " | ||
| "`mul_op`'s `x_num_row_dims`.", | ||
| ctx.op().Input("X")); | ||
| PADDLE_ENFORCE(y_dims.size() > y_num_row_dims, | ||
| "The rank of input tensor Y(%s) should be larger than " | ||
| "`mul_op`'s `y_num_row_dims`.", | ||
| ctx.op().Input("Y")); | ||
|
|
||
| auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims); | ||
| auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims); | ||
|
|
||
| PADDLE_ENFORCE_EQ( | ||
| dim0[1], dim1[0], | ||
| x_mat_dims[1], y_mat_dims[0], | ||
| "First matrix's width must be equal with second matrix's height."); | ||
| ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]}); | ||
| ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { | |
| AddInput("X", "The first input of mul op"); | ||
| AddInput("Y", "The second input of mul op"); | ||
| AddOutput("Out", "The output of mul op"); | ||
| AddAttr<int>( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a very useful syntax — a raw string literal — in AddAttr<int>("x_num_col_dims", R"DOC(mul_op can take ...
....
)DOC");
See http://en.cppreference.com/w/cpp/language/string_literal
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, Thank you! |
||
| "x_num_row_dims", | ||
| "mul_op can take tensors with more than two dimensions as input `X`, " | ||
| "in that case, tensors will be flattened to a matrix. The matrix's " | ||
|
||
| "second dimension(row length) will be the product of tensor's last " | ||
| "`num_row_dims` dimensions, and the matrix's first dimension(column " | ||
|
||
| "length) will be the product of tensor's first `rank - num_row_dims` " | ||
| "dimensions.") | ||
| .SetDefault(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 依据上面的描述,和最常用的情况不符合,最常用的是reshape成:height = dims[0], width = product(dims[1:])
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经修改,把参数从 |
||
| .EqualLargerThan(1); | ||
| AddAttr<int>( | ||
| "y_num_row_dims", | ||
| "mul_op can take tensors with more than two dimensions as input `Y`, " | ||
| "in that case, tensors will be flattened to a matrix. Just like input " | ||
| "`X`.") | ||
| .SetDefault(1) | ||
| .EqualLargerThan(1); | ||
| AddComment(R"DOC( | ||
| Two Element Mul Operator. | ||
|
|
||
|
|
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel { | |
| auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); | ||
| auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
| auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
| PADDLE_ENFORCE(x_dims[0] == out_dims[0], | ||
| "Out@GRAD M X N must equal to X dims 0, M "); | ||
| PADDLE_ENFORCE(y_dims[1] == out_dims[1], | ||
| "Out@GRAD M X N must equal to Y dims 1, N "); | ||
|
|
||
| auto x_mat_dims = | ||
| framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_row_dims")); | ||
| auto y_mat_dims = | ||
| framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_row_dims")); | ||
|
|
||
| PADDLE_ENFORCE_EQ( | ||
| x_mat_dims[0], out_dims[0], | ||
| "The first dimension of Out@GRAD must equal to the first dimension of " | ||
| "the first operand."); | ||
| PADDLE_ENFORCE_EQ( | ||
| y_mat_dims[1], out_dims[1], | ||
| "The second dimension of Out@GRAD must equal to the second " | ||
| "dimension of the second operand."); | ||
|
|
||
| if (x_grad) x_grad->Resize(x_dims); | ||
| if (y_grad) y_grad->Resize(y_dims); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,13 @@ | |
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| You may obtain a copy of the License at | ||
|
||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
|
||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
|
|
@@ -31,37 +31,65 @@ template <typename Place, typename T> | |
| class MulKernel : public framework::OpKernel { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* x = context.Input<Tensor>("X"); | ||
| auto* y = context.Input<Tensor>("Y"); | ||
| auto* z = context.Output<Tensor>("Out"); | ||
| z->mutable_data<T>(context.GetPlace()); | ||
| const Tensor* x = context.Input<Tensor>("X"); | ||
| const Tensor* y = context.Input<Tensor>("Y"); | ||
| Tensor* Z = context.Output<Tensor>("Out"); | ||
|
||
| const Tensor x_matrix = | ||
| x->dims().size() > 2 | ||
| ? framework::FlattenToMatrix<T>( | ||
| *x, context.template GetAttr<int>("x_num_row_dims")) | ||
| : *x; | ||
| const Tensor y_matrix = | ||
| y->dims().size() > 2 | ||
| ? framework::FlattenToMatrix<T>( | ||
| *y, context.template GetAttr<int>("y_num_row_dims")) | ||
| : *y; | ||
|
|
||
| Z->mutable_data<T>(context.GetPlace()); | ||
| auto* device_context = | ||
| const_cast<platform::DeviceContext*>(context.device_context_); | ||
| math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context); | ||
| math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, Z, 0, | ||
| device_context); | ||
| } | ||
| }; | ||
|
|
||
| template <typename Place, typename T> | ||
| class MulGradKernel : public framework::OpKernel { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const override { | ||
| auto* x = ctx.Input<Tensor>("X"); | ||
| auto* y = ctx.Input<Tensor>("Y"); | ||
| auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
| int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims"); | ||
| int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims"); | ||
| const Tensor* x = ctx.Input<Tensor>("X"); | ||
| const Tensor* y = ctx.Input<Tensor>("Y"); | ||
| const Tensor x_matrix = | ||
| x->dims().size() > 2 ? framework::FlattenToMatrix<T>(*x, x_num_row_dims) | ||
| : *x; | ||
| const Tensor y_matrix = | ||
| y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*y, y_num_row_dims) | ||
| : *y; | ||
| const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
|
|
||
| auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
| auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
| Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
| Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
| auto* device_context = | ||
| const_cast<platform::DeviceContext*>(ctx.device_context_); | ||
| if (dx) { | ||
| dx->mutable_data<T>(ctx.GetPlace()); | ||
| Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix<T>( | ||
| *dx, x_num_row_dims) | ||
| : *dx; | ||
| // dx = dout * y'. dx: M x K, dout : M x N, y : K x N | ||
| math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context); | ||
| math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0, | ||
| device_context); | ||
| } | ||
| if (dy) { | ||
| dy->mutable_data<T>(ctx.GetPlace()); | ||
| Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix<T>( | ||
| *dy, y_num_row_dims) | ||
| : *dy; | ||
| // dy = x' * dout. dy K x N, dout : M x N, x : M x K | ||
| math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context); | ||
| math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0, | ||
| device_context); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better for the name to be compatible with gtest, such as
CHECK_GE or something similar? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EqualLargerThan is a function, not a macro, so the name should not be too short.