@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
107107 sizeof (bias_data)));
108108
109109 if (enable_auxiliary && activation != " none" ) {
110- size_t reserve_space_size = 0 ;
110+ // Note (Ming Huang): The initialization of ReserveSpace happens in
111+ // dev_ctx.Alloc. Therefore, we set the real data type up here.
111112 if (activation == " relu" ) {
112- // Count in bits.
113- reserve_space_size = phi::product (out->dims ()) / 8 ;
113+ paddle::experimental::DataType rs_type =
114+ paddle::experimental::DataType::BOOL;
115+ size_t reserve_space_size =
116+ phi::product (reserve_space->dims ()) * SizeOf (rs_type);
117+ dev_ctx.Alloc (reserve_space, rs_type, reserve_space_size);
114118 } else {
115- reserve_space_size = phi::product (out->dims ()) * sizeof (T);
119+ size_t reserve_space_size =
120+ phi::product (reserve_space->dims ()) * sizeof (T);
121+ dev_ctx.Alloc <T>(reserve_space, reserve_space_size);
116122 }
117- dev_ctx. Alloc (reserve_space, out-> type (), reserve_space_size);
118- void * aux_data = reinterpret_cast < void *>( reserve_space->data <T>() );
123+
124+ void * aux_data = reserve_space->data ( );
119125
120126 PADDLE_ENFORCE_GPU_SUCCESS (
121127 platform::dynload::cublasLtMatmulDescSetAttribute (
@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
185191 stream,
186192 workspace->ptr (),
187193 workspace_size);
188-
189194 PADDLE_ENFORCE_GPU_SUCCESS (
190195 platform::dynload::cublasLtMatmul (lt_handle,
191196 operation_desc,
@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
478483 sizeof (epiloque_func_for_dx)));
479484
480485 if (activation_grad != " none" ) {
481- auto * aux_data = reserve_space->data <T> ();
486+ auto * aux_data = reserve_space->data ();
482487 PADDLE_ENFORCE_GPU_SUCCESS (
483488 platform::dynload::cublasLtMatmulDescSetAttribute (
484489 dx_operation_desc,
0 commit comments