Skip to content

Commit 444a735

Browse files
authored
Optimize Matmul_v2 (#37037)
Optimize dot product of Matmul_v2
1 parent 6b0cc2b commit 444a735

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

paddle/pten/kernels/functions/math/matmul_func.h

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/operators/math/blas.h"
1818
#include "paddle/fluid/operators/math/complex_functors.h"
1919

20-
#include "paddle/fluid/operators/eigen/eigen_function.h"
2120
#include "paddle/pten/core/dense_tensor.h"
22-
#include "paddle/pten/kernels/functions/eigen/common.h"
2321

2422
namespace pten {
2523
namespace math {
@@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx,
105103
const T* x_data = X.data<T>();
106104
const T* y_data = Y.data<T>();
107105

106+
auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
107+
108108
if (x_ndim == 1 && y_ndim == 1) {
109+
const int M = X.numel();
110+
const int N = Y.numel();
109111
PADDLE_ENFORCE_EQ(
110-
X.numel(),
111-
Y.numel(),
112+
M,
113+
N,
112114
paddle::platform::errors::InvalidArgument(
113115
"X's numbers must be equal to Y's numbers,"
114116
"when X/Y's dims =1. But received X has [%d] elements,"
115117
"received Y has [%d] elements",
116-
X.numel(),
117-
Y.numel()));
118+
M,
119+
N));
118120
VLOG(3) << "MatMul's case 1";
119-
Out->Resize({1});
120-
Out->mutable_data<T>();
121-
auto out_eigen = EigenScalar<T>::From(*Out);
122-
auto x_eigen = EigenVector<T>::Flatten(X);
123-
auto y_eigen = EigenVector<T>::Flatten(Y);
124-
125-
auto& dev = *dev_ctx.eigen_device();
126-
if (flag) {
127-
out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
128-
} else {
129-
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
130-
}
121+
blas.GEMM(CblasNoTrans,
122+
CblasTrans,
123+
1,
124+
1,
125+
M,
126+
static_cast<T>(1),
127+
y_data,
128+
x_data,
129+
static_cast<T>(flag),
130+
Out->mutable_data<T>());
131131
return;
132132
}
133133

134-
auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
135-
136134
if (x_ndim == 1) {
137135
const int N = X.numel();
138136
if (trans_y) {

0 commit comments

Comments
 (0)