@@ -11,7 +11,6 @@ limitations under the License. */
1111
1212#pragma once
1313#include < glog/logging.h>
14- #include < algorithm>
1514#include < string>
1615#include < unordered_set>
1716#include < utility>
@@ -25,7 +24,6 @@ limitations under the License. */
2524#include " paddle/fluid/framework/eigen.h"
2625#include " paddle/fluid/framework/op_registry.h"
2726#include " paddle/fluid/operators/detail/safe_ref.h"
28- #include " paddle/fluid/operators/math/blas.h"
2927#include " paddle/fluid/platform/float16.h"
3028
3129#ifdef PADDLE_WITH_MKLDNN
@@ -303,28 +301,8 @@ template <typename T>
303301struct GeluFunctor : public BaseActivationFunctor <T> {
304302 template <typename Device, typename X, typename Out>
305303 void operator ()(Device d, X x, Out out) const {
306- // Because the execute or device context can not be deliver here, it keep the
307- // marco for NVCC.
308- #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
309- !defined (__OSX__) && !defined (PADDLE_WITH_CUDA)
310- auto x_data = x.data ();
311- auto out_data = out.data ();
312- int n = std::min (x.size (), out.size ());
313-
314- std::memset (out_data, 0 , n * sizeof (T));
315- math::CBlas<T>::AXPY (n, static_cast <T>(M_SQRT1_2), x_data, 1 , out_data, 1 );
316- math::CBlas<T>::VMERF (n, out_data, out_data, VML_LA);
317- for (int i = 0 ; i < n; i++) {
318- out_data[i] += static_cast <T>(1 );
319- }
320- math::CBlas<T>::VMUL (n, x_data, out_data, out_data);
321- for (int i = 0 ; i < n; i++) {
322- out_data[i] *= static_cast <T>(0.5 );
323- }
324- #else
325304 auto temp = (x * static_cast <T>(M_SQRT1_2)).erf ();
326305 out.device (d) = x * static_cast <T>(0.5 ) * (static_cast <T>(1 ) + temp);
327- #endif
328306 }
329307};
330308
0 commit comments