
Commit cdab3a4

Fix nullptr to TestFuseGemmEpilogueReluBWDFP* (#48997) (#49090)
Co-authored-by: Ming-Xu Huang <[email protected]>
1 parent ddcd1b6 commit cdab3a4
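
Summary, as read from the diff below: the forward kernel used to size relu's ReserveSpace as a bit-mask and allocate it with the output's dtype, so the backward kernel's typed reserve_space->data<T>() could fail its dtype check and hand cuBLASLt a null aux pointer, which appears to be the nullptr the TestFuseGemmEpilogueReluBWDFP* tests hit. The fix allocates ReserveSpace with an explicit dtype (BOOL for relu), lets InferShape set its dims unconditionally, and reads the buffer back through the untyped data() accessor.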

2 files changed: 15 additions & 11 deletions

paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc

Lines changed: 2 additions & 3 deletions

@@ -139,9 +139,8 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     }
 
     ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-    // Note (Ming Huang): Reserve space of relu is a bit-mask,
-    // which cannot pass nan_and_inf checking if shape is set.
-    if (activation == "gelu" && ctx->HasOutput("ReserveSpace")) {
+
+    if (ctx->HasOutput("ReserveSpace")) {
      ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
    }
  }
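
The removed note explains the old gating: relu's reserve space was a bit-mask whose shape could not be declared without tripping nan/inf checks. With the allocation now dtype-aware (see the .cu hunks below), the shape can be set unconditionally. A minimal standalone sketch of the new InferShape behavior, with ShapeCtx as a hypothetical stand-in for Paddle's InferShapeContext:

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for Paddle's InferShapeContext; the real class
// exposes HasOutput/SetOutputDim with framework-specific types.
struct ShapeCtx {
  std::vector<std::int64_t> reserve_dims;
  bool HasOutput(const std::string& name) const {
    return name == "ReserveSpace";  // pretend the output was declared
  }
  void SetOutputDim(const std::string&, std::vector<std::int64_t> d) {
    reserve_dims = std::move(d);
  }
};

// After the fix, ReserveSpace gets a shape for every activation, not only
// "gelu"; the relu mask no longer needs to stay shapeless because the
// kernel now allocates it with an explicit BOOL dtype.
void InferReserveSpaceDim(ShapeCtx* ctx,
                          const std::vector<std::int64_t>& out_dims) {
  if (ctx->HasOutput("ReserveSpace")) {
    ctx->SetOutputDim("ReserveSpace", out_dims);
  }
}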

paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu

Lines changed: 13 additions & 8 deletions

@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
             sizeof(bias_data)));
 
     if (enable_auxiliary && activation != "none") {
-      size_t reserve_space_size = 0;
+      // Note (Ming Huang): The initialization of ReseveSpace is happened in the
+      // dev_ctx.Alloc. Therefore, we set real date type up here.
       if (activation == "relu") {
-        // Count in bits.
-        reserve_space_size = phi::product(out->dims()) / 8;
+        paddle::experimental::DataType rs_type =
+            paddle::experimental::DataType::BOOL;
+        size_t reserve_space_size =
+            phi::product(reserve_space->dims()) * SizeOf(rs_type);
+        dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
       } else {
-        reserve_space_size = phi::product(out->dims()) * sizeof(T);
+        size_t reserve_space_size =
+            phi::product(reserve_space->dims()) * sizeof(T);
+        dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
       }
-      dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
-      void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
+
+      void* aux_data = reserve_space->data();
 
       PADDLE_ENFORCE_GPU_SUCCESS(
           platform::dynload::cublasLtMatmulDescSetAttribute(
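
The size arithmetic in this hunk changes from bits to typed elements: before, the relu mask was counted in bits (product/8 bytes) yet allocated with the output's dtype; after, each element is one BOOL and the byte count follows the dtype. A self-contained sketch of the two computations, where NumElements stands in for phi::product and SizeOf(DataType::BOOL) == 1 is an assumption about Paddle's bool layout:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for phi::product(dims): the element count of a shape.
std::size_t NumElements(const std::vector<std::int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

// Old sizing: the relu mask was a bit-mask, one bit per output element,
// and the buffer was allocated with the *output's* dtype.
std::size_t OldReluReserveBytes(const std::vector<std::int64_t>& out_dims) {
  return NumElements(out_dims) / 8;  // "Count in bits."
}

// New sizing: one element of an explicit dtype per entry, so the byte
// count comes from the dtype itself.
std::size_t NewReluReserveBytes(const std::vector<std::int64_t>& rs_dims) {
  constexpr std::size_t kSizeOfBool = 1;  // assumed SizeOf(DataType::BOOL)
  return NumElements(rs_dims) * kSizeOfBool;
}

The explicit dtype is what lets InferShape declare the full shape in the .cc change above: the mask becomes an ordinary BOOL tensor rather than a packed bit buffer.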
@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
                            stream,
                            workspace->ptr(),
                            workspace_size);
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::cublasLtMatmul(lt_handle,
                                           operation_desc,
@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
             sizeof(epiloque_func_for_dx)));
 
     if (activation_grad != "none") {
-      auto* aux_data = reserve_space->data<T>();
+      auto* aux_data = reserve_space->data();
       PADDLE_ENFORCE_GPU_SUCCESS(
           platform::dynload::cublasLtMatmulDescSetAttribute(
               dx_operation_desc,
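
Both kernels now read the aux pointer through the untyped data() accessor. The hazard with the old typed call: once ReserveSpace is allocated as BOOL, data<T>() with T = float or half no longer matches the stored dtype. The toy tensor below illustrates that failure mode; the real phi::DenseTensor's exact behavior on a mismatch (enforce vs. returning nullptr) is an assumption here, chosen to match the commit title:

#include <cassert>
#include <cstddef>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Toy tensor with Paddle-like accessors. The real phi::DenseTensor checks
// the requested T against the dtype recorded at allocation; returning
// nullptr on mismatch is an assumption for illustration.
class Tensor {
 public:
  template <typename T>
  void Alloc(std::size_t n) {
    buf_.resize(n * sizeof(T));
    dtype_ = std::type_index(typeid(T));
  }
  template <typename T>
  T* data() {  // typed access: valid only if T matches the stored dtype
    if (dtype_ != std::type_index(typeid(T))) return nullptr;
    return reinterpret_cast<T*>(buf_.data());
  }
  void* data() {  // untyped access: the raw buffer, which is all the
    return buf_.data();  // cuBLASLt aux-pointer attribute needs
  }

 private:
  std::vector<unsigned char> buf_;
  std::type_index dtype_{typeid(void)};
};

int main() {
  Tensor reserve_space;
  reserve_space.Alloc<bool>(64);  // relu mask is now stored as BOOL
  assert(reserve_space.data<float>() == nullptr);  // old grad-kernel path
  assert(reserve_space.data() != nullptr);         // fixed path
  return 0;
}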
