@@ -17,9 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/complex_functors.h"
 
-#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/functions/eigen/common.h"
 
 namespace pten {
 namespace math {
@@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx,
   const T* x_data = X.data<T>();
   const T* y_data = Y.data<T>();
 
+  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
+
   if (x_ndim == 1 && y_ndim == 1) {
+    const int M = X.numel();
+    const int N = Y.numel();
     PADDLE_ENFORCE_EQ(
-        X.numel(),
-        Y.numel(),
+        M,
+        N,
         paddle::platform::errors::InvalidArgument(
             "X's numbers must be equal to Y's numbers,"
             "when X/Y's dims =1. But received X has [%d] elements,"
             "received Y has [%d] elements",
-            X.numel(),
-            Y.numel()));
+            M,
+            N));
     VLOG(3) << "MatMul's case 1";
-    Out->Resize({1});
-    Out->mutable_data<T>();
-    auto out_eigen = EigenScalar<T>::From(*Out);
-    auto x_eigen = EigenVector<T>::Flatten(X);
-    auto y_eigen = EigenVector<T>::Flatten(Y);
-
-    auto& dev = *dev_ctx.eigen_device();
-    if (flag) {
-      out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
-    } else {
-      out_eigen.device(dev) = (x_eigen * y_eigen).sum();
-    }
+    blas.GEMM(CblasNoTrans,
+              CblasTrans,
+              1,
+              1,
+              M,
+              static_cast<T>(1),
+              y_data,
+              x_data,
+              static_cast<T>(flag),
+              Out->mutable_data<T>());
     return;
   }
 
-  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
-
   if (x_ndim == 1) {
     const int N = X.numel();
     if (trans_y) {
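For context on the new vector-vector branch above: a dot product can be phrased as a 1x1 GEMM, with one input treated as a 1xK row, the other as a Kx1 column, and `flag` passed through as GEMM's beta so an existing value in `Out` is either accumulated onto (beta = 1) or discarded (beta = 0). A minimal standalone sketch of that reduction, assuming nothing beyond standard C++ (`dot_as_gemm` is a hypothetical helper for illustration, not the Paddle BLAS wrapper):

```cpp
#include <iostream>
#include <vector>

// Sketch of the identity behind the GEMM call above: a 1xK by Kx1 product
// collapses to a dot product, with beta scaling whatever is already in *out.
template <typename T>
void dot_as_gemm(const T* x, const T* y, int k, T alpha, T beta, T* out) {
  T acc = 0;
  for (int i = 0; i < k; ++i) {
    acc += x[i] * y[i];  // the 1xK * Kx1 inner product
  }
  *out = alpha * acc + beta * (*out);  // beta plays the role of `flag`
}

int main() {
  std::vector<float> x{1, 2, 3};
  std::vector<float> y{4, 5, 6};
  float out = 10.0f;
  dot_as_gemm(x.data(), y.data(), 3, 1.0f, 0.0f, &out);  // flag == false
  std::cout << out << "\n";  // 32: previous value of out is discarded
  dot_as_gemm(x.data(), y.data(), 3, 1.0f, 1.0f, &out);  // flag == true
  std::cout << out << "\n";  // 64: dot product accumulated onto out
}
```

With beta = 0 the previous contents of `Out` are ignored, matching the old `flag == false` Eigen path; beta = 1 accumulates the dot product onto `Out`, matching `flag == true`.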