Commit e11ecfc

zyfncg, chenwhql, MingMingShangTian, YuanRisheng, and Shixiaowei02 authored
Add matmul_v2 kernel in pten (PaddlePaddle#36844)
* initial tensor design & sign kernel demo
* add move constructor for meta & add lodtensor
* add dirs & sign xpu kernel
* add mean cpu&cuda kernel impl
* move sign & mean xpu & npu kernel
* add selected_rows basic impl
* refactor design, BaseTensor to DenseTensor, etc.
* add scale mkldnn kernel
* polish xpu & npu impl details
* fix mkldnn reuse compile failed
* change tensor operation lib name
* rename util filename
* add more comments
* change TensorImplInterface to TensorInterface
* add kernel key and factory
* remove MKLDNNTensorMeta, add MKLDNNDenseTensor
* change XXDeviceContext to XXContext
* add base kernel registrar utils & test on sign
* replace boost::any by paddle::any
* fix several ci failed
* fix npu compile error
* add ordered map util
* fix multiple ordered_map compile errors
* move dev into include dir
* support sign op in static op run
* fix static op run error
* fix new executor compile failed
* add dygraph branch & remove sign_op.h
* fix test_infer_no_need_buffer_slots
* fix rocm compile link error
* fix unitybuild error & clear glog
* fix npu compile failed
* skip quant trans test
* fix part windows compile problem
* fix xpu enforce error
* fix inference test failed
* remove ordered_map to solve quant failed
* fix part of rcom compile faild
* add more register kernels
* revert scale kernel temporarily
* fix code format error
* add new kernel registrar marco
* rename top to tcmpt
* revert xpu, npu, mkldnn impl & remove op def
* add kernel args parse functor to auto parse args
* revert some change & add scale kernels
* add op proto in dygraph kernelcontext building
* polish kernel dispatch logic & nameing rule
* fix scale kernel match error
* fix scale test failed
* add mean API and unittest
* test mean api success
* add branch to solve compiled error
* skip clang format error
* add mean skip rule in op_library
* add dot kernel, api and unittest (#6)
* remove old kernel and add symbol link
* fix dot compiled failed
* add merco for module declare
* fix npu and xpu compile error
* revert sign, mean, scale, dot kernel removing
* add comment for keeping old kernel impl
* fix mutable_data error
* fix bfloat16 conflit
* fix inference undef error
* adapt to msvc compile rules
* polish comment for template inst
* add cmake template instantiation for win
* fix backend to place device id bug
* fix ifdef error
* Op2functor (#7)
* add kernel args maker class
* make args maker non-const
* remove debug log
* modify codes by review options
* split constructPrKernelContext function
* fix output name bug
* fix test_mean_op test_sign_op failed
* fill_any_like kernel refactor (#10)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* skip dtype for fill_any_like
* add attrs for kernel key constrcut
* add use_pt_kernel Flags to control whether to use pt kernel (#13)
* add use_pt_kernel Flags to control whether to use pt kernel
* change the default value to true for cheking pt kernels
* fix mutable_data cuda place error
* move high level apis into hapi
* remove selectedrows adapting temporarily
* Support Scalar in Tensor Compute Library (#14)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* remove mkldnn tensor & polish details
* use flat_hash_map and small_vector in kernel factory
* Refactor flatten kernel (#12)
* refactor flatten kernel
* update infershape function
* fix compile bugs
* fix bugs when merge
* fix compiler bugs
* fix bugs when run test_flatten_api
* fix bugs when run test
* Revert "use flat_hash_map and small_vector in kernel factory" This reverts commit 2309149.
* Move cpu, cuda and other device code into kernels (#15)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Perfect unitests (#16)
* perfect unittest
* update license
* replace with flat_hash_map, small_vector (#19)
* fix small_vector build error on windows platform
* replace with flat_hash_map, small_vector
* remove todo
* Perfect unitests (#20)
* perfect unittest
* update license
* fix bug when run tcmpt_utils_test
* refactor execution adapting impl
* fix insert conflit
* Fix CI bug of test_yolov3 (#21)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Fix CI bug of test_yolov3
* add the tensor base class, test=develop (#17)
* update the tensor base class, test=develop
* remove two funcs, test=develop
* update the error msg, test=develop
Co-authored-by: Chen Weihang <[email protected]>
* [no-verify] commit backend and tensor signature changes
* Rename tcmpt to pten (#23)
* rename tcmpt to pten
* update omitted files for rename to pten
* update omitted file for rename to pten
* remove k of all enum var
* remove kernel_instantiate (#26)
* remove symbols and spatial_tensor
* change common to functions
* readd share tensor impl methods
* add a candidate dense tensor class, test=develop (#28)
* change all Pt to Pten
* resolve conflit with xiaowei
* Op2functor opt1 (#27)
* replace to small vector and change to const &
* add std::move
Co-authored-by: Chen Weihang <[email protected]>
* polish kernel factory and kernel registry
* fix operator test error msg mismatch
* remove tensor signature and backend set member
* move scalar and polish enforce
* revert dtype layout change to fix error
* fix enum operator override error
* add several base unittests
* add pten utils tests
* polish some details
* Dev/op2func refactor 3 (#30)
* add a candidate dense tensor class, test=develop
* remove TensorBase::backend(), test=develop
* remove some ops, test=develop
* cherry-pick the pr of tensor meta, test=develop
* moves the dense tensor and some ops, test=develop
* update the linalg operator, test=develop
* update other operators, test=develop
* fix errors, test=develop
* fix bugs, test=develop
* try to resolve the problem of windows ci, test=develop
* updates codes, test=develop
* fix the tensor_utils.cc, test=develop
* modify the dense tensor, test=develop
* fix the data type, test=develop
Co-authored-by: shixiaowei02 <[email protected]>
* polish some details
* polish kernel signature details
* fix a bug about offsets of the tensor, test=develop (#31)
Co-authored-by: shixiaowei02 <[email protected]>
* add matmul kernel in pten
* add unittest for new matmul_v2 kernel
* fix bug of CI compile
* fix bug of CI compile
* merge conflict
* remove useless file
Co-authored-by: Chen Weihang <[email protected]>
Co-authored-by: chentianyu03 <[email protected]>
Co-authored-by: YuanRisheng <[email protected]>
Co-authored-by: 石晓伟 <[email protected]>
1 parent e5aa145 commit e11ecfc

16 files changed, 866 additions and 12 deletions


paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/op_call_stack.h"
+#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/transfer_scope_cache.h"
 #include "paddle/fluid/framework/unused_var_check.h"

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
+#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/utils/small_vector.h"

paddle/fluid/operators/matmul_v2_op.h

Lines changed: 16 additions & 9 deletions
@@ -25,6 +25,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 
+// only can include the headers in paddle/pten/api dirs
+#include "paddle/pten/api/include/core.h"
+#include "paddle/pten/api/include/linalg.h"
+#include "paddle/pten/hapi/lib/utils/tensor_utils.h"
+
 #if defined(__NVCC__) || defined(__HIPCC__)
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 #endif
@@ -380,15 +385,17 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
     auto* Out = ctx.Output<Tensor>("Out");
     bool trans_x = ctx.Attr<bool>("trans_x");
     bool trans_y = ctx.Attr<bool>("trans_y");
-    PADDLE_ENFORCE_NE(framework::product(X->dims()), 0,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) dims size must not be equal 0,"
-                          " but reviced dims size is 0. "));
-    PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0,
-                      platform::errors::InvalidArgument(
-                          "The Input(Y) dims size must not be equal 0,"
-                          " but reviced dims size is 0. "));
-    MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);
+
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+    Out->mutable_data<T>(X->place());
+
+    auto pt_x = paddle::experimental::MakePtenDenseTensor(*X);
+    auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y);
+    auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);
+
+    // call new kernel
+    pten::Matmul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y,
+                    pt_out.get());
   }
 };
 

paddle/pten/hapi/include/linalg.h

Lines changed: 5 additions & 0 deletions
@@ -21,5 +21,10 @@ namespace experimental {
 
 Tensor dot(const Tensor& x, const Tensor& y);
 
+Tensor matmul(const Tensor& x,
+              const Tensor& y,
+              bool transpose_x,
+              bool transpose_y);
+
 } // namespace experimental
 } // namespace paddle

paddle/pten/hapi/lib/linalg.cc

Lines changed: 42 additions & 1 deletion
@@ -25,7 +25,6 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_context.h"
 #include "paddle/pten/hapi/lib/kernel_dispatch.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"
-#include "paddle/pten/infershape/binary.h"
 
 namespace paddle {
 namespace experimental {
@@ -65,5 +64,47 @@ Tensor dot(const Tensor& x, const Tensor& y) {
   return out;
 }
 
+Tensor matmul(const Tensor& x,
+              const Tensor& y,
+              bool transpose_x,
+              bool transpose_y) {
+  // 1. Get kernel signature and kernel
+  auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
+  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      "matmul_v2", kernel_key);
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(*dev_ctx);
+
+  // 3. Auto data transform
+  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
+  auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
+  kernel_context.EmplaceBackInput(dense_x);
+  kernel_context.EmplaceBackInput(dense_y);
+  kernel_context.EmplaceBackAttr(transpose_x);
+  kernel_context.EmplaceBackAttr(transpose_y);
+  // TODO(chenweihang): add transform impl
+
+  // 4. InferShape
+  auto out_meta = MatmulInferShape(
+      dense_x->meta(), dense_y->meta(), transpose_x, transpose_y);
+
+  // 5. Prepare outputs
+  const auto allocator = std::make_shared<DefaultAllocator>(
+      pten::TransToFluidPlace(kernel_key.backend()));
+  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
+  kernel_context.EmplaceBackOutput(dense_out);
+
+  Tensor out;
+  out.set_impl(dense_out);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
 } // namespace experimental
 } // namespace paddle
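
The `matmul` API above follows a select, pack, call pattern: look up a kernel by name and key, pack inputs and outputs into a context object, then invoke the type-erased kernel. A minimal standalone sketch of that pattern is given below. Every name in it (`ToyContext`, `ToyKernelFactory`, `ToyKey`, `RegisterToyDotKernel`, `ToyDot`, and the two enums) is a hypothetical stand-in written for illustration; none of it is pten code, and the real factory very likely differs in its key type and priority handling.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// All names below are hypothetical stand-ins, not pten types.
enum class Backend { CPU, CUDA };
enum class DType { FLOAT32, FLOAT64 };

// Carries the inputs and the output slot to a type-erased kernel.
struct ToyContext {
  std::vector<const std::vector<float>*> inputs;
  float* output = nullptr;
};

using ToyKernel = std::function<void(ToyContext*)>;
using ToyKey = std::tuple<std::string, Backend, DType>;

class ToyKernelFactory {
 public:
  static ToyKernelFactory& Instance() {
    static ToyKernelFactory factory;
    return factory;
  }
  void Register(const ToyKey& key, ToyKernel kernel) {
    kernels_[key] = std::move(kernel);
  }
  // Throws when no kernel matches the (name, backend, dtype) key.
  const ToyKernel& SelectKernelOrThrowError(const ToyKey& key) const {
    auto it = kernels_.find(key);
    if (it == kernels_.end()) throw std::runtime_error("kernel not registered");
    return it->second;
  }

 private:
  std::map<ToyKey, ToyKernel> kernels_;
};

// Registers a CPU float dot-product kernel under the name "dot".
void RegisterToyDotKernel() {
  ToyKernelFactory::Instance().Register(
      ToyKey("dot", Backend::CPU, DType::FLOAT32), [](ToyContext* ctx) {
        const auto& x = *ctx->inputs[0];
        const auto& y = *ctx->inputs[1];
        float sum = 0.0f;
        for (std::size_t i = 0; i < x.size() && i < y.size(); ++i) {
          sum += x[i] * y[i];
        }
        *ctx->output = sum;
      });
}

// API-level function: select the kernel, pack the context, prepare the
// output slot, then call the kernel.
float ToyDot(const std::vector<float>& x, const std::vector<float>& y) {
  const ToyKernel& kernel =
      ToyKernelFactory::Instance().SelectKernelOrThrowError(
          ToyKey("dot", Backend::CPU, DType::FLOAT32));
  ToyContext ctx;
  ctx.inputs.push_back(&x);
  ctx.inputs.push_back(&y);
  float out = 0.0f;
  ctx.output = &out;
  kernel(&ctx);
  return out;
}
```

Calling `RegisterToyDotKernel()` once and then `ToyDot(a, b)` exercises the whole path; asking the factory for a key that was never registered throws, which mirrors the intent suggested by the name `SelectKernelOrThrowError`.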

paddle/pten/infershape/binary.cc

Lines changed: 70 additions & 0 deletions
@@ -59,4 +59,74 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
   return return_meta;
 }
 
+DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
+                                 const DenseTensorMeta& y_meta,
+                                 bool trans_x,
+                                 bool trans_y) {
+  std::vector<int64_t> dims_x = paddle::framework::vectorize(x_meta.dims);
+  std::vector<int64_t> dims_y = paddle::framework::vectorize(y_meta.dims);
+  auto ndims_x = dims_x.size();
+  auto ndims_y = dims_y.size();
+  PADDLE_ENFORCE_GT(ndims_x,
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(x) dims size must be greater than 0,"
+                        " but reviced dims size is 0. "));
+  PADDLE_ENFORCE_GT(ndims_y,
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(y) dims size must be greater than 0,"
+                        " but reviced dims size is 0. "));
+
+  bool x_broadcasted = false, y_broadcasted = false;
+  if (ndims_x == 1) {
+    dims_x.insert(dims_x.begin(), 1);
+    ndims_x = 2;
+    x_broadcasted = true;
+  }
+
+  if (ndims_y == 1) {
+    dims_y.push_back(1);
+    ndims_y = 2;
+    y_broadcasted = true;
+  }
+
+  size_t M, N;
+  if (trans_x) {
+    M = dims_x[ndims_x - 1];
+  } else {
+    M = dims_x[ndims_x - 2];
+  }
+  if (trans_y) {
+    N = dims_y[ndims_y - 2];
+  } else {
+    N = dims_y[ndims_y - 1];
+  }
+
+  std::vector<int64_t> new_dims;
+  if (ndims_x > ndims_y) {
+    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
+  } else if (ndims_x < ndims_y) {
+    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
+  } else {
+    new_dims.reserve(ndims_x);
+    for (size_t i = 0; i < ndims_x - 2; ++i) {
+      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+    }
+  }
+  if (!x_broadcasted) {
+    new_dims.push_back(M);
+  }
+  if (!y_broadcasted) {
+    new_dims.push_back(N);
+  }
+  if (x_broadcasted && y_broadcasted) {
+    new_dims.push_back(1);
+  }
+
+  auto ddim_out = paddle::framework::make_ddim(new_dims);
+
+  return {x_meta.type, ddim_out, x_meta.layout};
+}
+
 } // namespace pten
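
To make the shape rules in `MatmulInferShape` easier to check by hand, here is a standalone sketch that applies the same steps to plain `std::vector<int64_t>` shapes. The function name `MatmulOutDims` is hypothetical and the sketch drops `DenseTensorMeta`, dtype, and layout entirely. It assumes the rules are exactly those visible in the diff: a rank-1 `x` is treated as a row vector and a rank-1 `y` as a column vector, the batch dimensions come from the higher-rank operand (or the elementwise maximum when ranks match), and the vector-induced dimensions are dropped from the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical standalone helper mirroring the rules of MatmulInferShape.
std::vector<int64_t> MatmulOutDims(std::vector<int64_t> x,
                                   std::vector<int64_t> y,
                                   bool trans_x,
                                   bool trans_y) {
  if (x.empty() || y.empty()) {
    throw std::invalid_argument("input rank must be greater than 0");
  }
  // Promote rank-1 operands to matrices and remember that we did so.
  bool x_vec = false, y_vec = false;
  if (x.size() == 1) {
    x.insert(x.begin(), 1);  // [K] -> [1, K]
    x_vec = true;
  }
  if (y.size() == 1) {
    y.push_back(1);  // [K] -> [K, 1]
    y_vec = true;
  }
  const std::size_t nx = x.size(), ny = y.size();
  const int64_t M = trans_x ? x[nx - 1] : x[nx - 2];
  const int64_t N = trans_y ? y[ny - 2] : y[ny - 1];

  // Batch dimensions: taken from the higher-rank operand, or the
  // elementwise maximum when both ranks are equal.
  std::vector<int64_t> out;
  if (nx > ny) {
    out.assign(x.begin(), x.end() - 2);
  } else if (nx < ny) {
    out.assign(y.begin(), y.end() - 2);
  } else {
    for (std::size_t i = 0; i + 2 < nx; ++i) {
      out.push_back(std::max(x[i], y[i]));
    }
  }
  // Drop the dimension that only exists because of vector promotion.
  if (!x_vec) out.push_back(M);
  if (!y_vec) out.push_back(N);
  if (x_vec && y_vec) out.push_back(1);
  return out;
}
```

Under these rules `[2, 3] x [3, 4]` gives `[2, 4]`, `[5, 2, 3] x [3, 4]` gives `[5, 2, 4]`, a vector times a matrix `[3] x [3, 4]` gives `[4]`, and two vectors `[3] x [3]` give `[1]`. Like the code in the diff, the sketch does not verify that the inner dimensions agree or that unequal batch dimensions are broadcast-compatible.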

paddle/pten/infershape/binary.h

Lines changed: 5 additions & 0 deletions
@@ -36,4 +36,9 @@ namespace pten {
 DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
                               const DenseTensorMeta& y_meta);
 
+DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
+                                 const DenseTensorMeta& y_meta,
+                                 bool trans_x,
+                                 bool trans_y);
+
 } // namespace pten

paddle/pten/kernels/cpu/linalg.cc

Lines changed: 27 additions & 0 deletions
@@ -21,6 +21,8 @@
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/complex.h"
 
+#include "paddle/pten/kernels/functions/math/matmul_func.h"
+
 namespace pten {
 
 template <typename T>
@@ -45,6 +47,27 @@ void Dot(const CPUContext& dev_ctx,
   }
 }
 
+template <typename T>
+void Matmul(const CPUContext& dev_ctx,
+            const DenseTensor& x,
+            const DenseTensor& y,
+            bool transpose_x,
+            bool transpose_y,
+            DenseTensor* out) {
+  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(X) dims size must not be equal 0,"
+                        " but reviced dims size is 0. "));
+  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(Y) dims size must not be equal 0,"
+                        " but reviced dims size is 0. "));
+  math::MatMulFunction<CPUContext, T>(
+      dev_ctx, x, y, out, transpose_x, transpose_y);
+}
+
 } // namespace pten
 
 PT_REGISTER_MODULE(LinalgCPU);
@@ -62,3 +85,7 @@ PT_REGISTER_KERNEL("dot",
                    int64_t,
                    complex64,
                    complex128) {}
+
+PT_REGISTER_KERNEL(
+    "matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) {
+}
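
The two `PADDLE_ENFORCE_NE` guards in the kernel test the product of the dimensions, so despite the "dims size" wording in the message they reject tensors with zero elements, not tensors of rank zero. A small sketch of that distinction follows. The helpers `NumElements` and `CheckNonEmpty` are hypothetical names operating on a plain shape vector; they are not the Paddle `framework::product` implementation, only an assumption about what an element-count check of this kind looks like.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: number of elements implied by a shape.
int64_t NumElements(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical guard: rejects any shape that contains a zero-sized dimension.
void CheckNonEmpty(const std::vector<int64_t>& dims, const std::string& name) {
  if (NumElements(dims) == 0) {
    throw std::invalid_argument("Input(" + name + ") has zero elements");
  }
}
```

A shape such as `[2, 0, 3]` has rank 3 but zero elements, so it fails this guard, whereas `[2, 3]` passes.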

paddle/pten/kernels/cpu/linalg.h

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ void Dot(const CPUContext& dev_ctx,
          DenseTensor* out);
 
 template <typename T>
-void matmul(const CPUContext& dev_ctx,
+void Matmul(const CPUContext& dev_ctx,
             const DenseTensor& x,
             const DenseTensor& y,
             bool transpose_x,

paddle/pten/kernels/cuda/linalg.cu

Lines changed: 33 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/kernels/functions/eigen/dot.h"
+#include "paddle/pten/kernels/functions/math/matmul_func.h"
 
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/platform/complex.h"
@@ -30,10 +31,32 @@ void Dot(const CUDAContext& dev_ctx,
   eigen::Dot<CUDAContext, T>(dev_ctx, x, y, out);
 }
 
+template <typename T>
+void Matmul(const CUDAContext& dev_ctx,
+            const DenseTensor& x,
+            const DenseTensor& y,
+            bool transpose_x,
+            bool transpose_y,
+            DenseTensor* out) {
+  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(X) dims size must not be equal 0,"
+                        " but reviced dims size is 0. "));
+  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
+                    0,
+                    paddle::platform::errors::InvalidArgument(
+                        "The Input(Y) dims size must not be equal 0,"
+                        " but reviced dims size is 0. "));
+  math::MatMulFunction<CUDAContext, T>(
+      dev_ctx, x, y, out, transpose_x, transpose_y);
+}
+
 } // namespace pten
 
 PT_REGISTER_MODULE(LinalgCUDA);
 
+using float16 = paddle::platform::float16;
 using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
@@ -47,3 +70,13 @@ PT_REGISTER_KERNEL("dot",
                    int64_t,
                    complex64,
                    complex128) {}
+
+PT_REGISTER_KERNEL("matmul_v2",
+                   CUDA,
+                   ANY,
+                   pten::Matmul,
+                   float,
+                   double,
+                   float16,
+                   complex64,
+                   complex128) {}
