@@ -106,15 +106,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
106106 sizeof (bias_data)));
107107
108108 if (enable_auxiliary && activation != " none" ) {
109- size_t reserve_space_size = 0 ;
109+ // Note (Ming Huang): The initialization of ReserveSpace happens in
110+ // dev_ctx.Alloc. Therefore, we set the real data type up here.
110111 if (activation == " relu" ) {
111- // Count in bits.
112- reserve_space_size = phi::product (out->dims ()) / 8 ;
112+ paddle::experimental::DataType rs_type =
113+ paddle::experimental::DataType::BOOL;
114+ size_t reserve_space_size =
115+ phi::product (reserve_space->dims ()) * SizeOf (rs_type);
116+ dev_ctx.Alloc (reserve_space, rs_type, reserve_space_size);
113117 } else {
114- reserve_space_size = phi::product (out->dims ()) * sizeof (T);
118+ size_t reserve_space_size =
119+ phi::product (reserve_space->dims ()) * sizeof (T);
120+ dev_ctx.Alloc <T>(reserve_space, reserve_space_size);
115121 }
116- dev_ctx. Alloc (reserve_space, out-> type (), reserve_space_size);
117- void * aux_data = reinterpret_cast < void *>( reserve_space->data <T>() );
122+
123+ void * aux_data = reserve_space->data ( );
118124
119125 PADDLE_ENFORCE_GPU_SUCCESS (
120126 platform::dynload::cublasLtMatmulDescSetAttribute (
@@ -184,7 +190,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
184190 stream,
185191 workspace->ptr (),
186192 workspace_size);
187-
188193 PADDLE_ENFORCE_GPU_SUCCESS (
189194 platform::dynload::cublasLtMatmul (lt_handle,
190195 operation_desc,
@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
478483 sizeof (epiloque_func_for_dx)));
479484
480485 if (activation_grad != " none" ) {
481- auto * aux_data = reserve_space->data <T> ();
486+ auto * aux_data = reserve_space->data ();
482487 PADDLE_ENFORCE_GPU_SUCCESS (
483488 platform::dynload::cublasLtMatmulDescSetAttribute (
484489 dx_operation_desc,
0 commit comments