Skip to content

Commit e6af7c0

Browse files
authored
[NPU] fix some bugs of npu op (#31739)
* fix softmax
* fix mean
* fix lookup_table_v2
1 parent 17862b7 commit e6af7c0

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,4 @@ endif()
187 187
188 188   if(WITH_ASCEND_CL)
189 189   cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
190     - cc_test(mean_op_npu_test SRCS mean_op_npu_test.cc DEPS op_registry mean_op scope device_context enforce executor)
191 190   endif()
192     -

paddle/fluid/operators/lookup_table_v2_op_npu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
54 54   auto *table_t = ctx.Input<framework::LoDTensor>("W");
55 55   auto *table_grad_t =
56 56   ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
   57 + table_grad_t->mutable_data<T>(ctx.GetPlace());
57 58   framework::NPUAttributeMap attr_input = {{"use_locking", true}};
58 59
59 60   auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t},

paddle/fluid/operators/softmax_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
206 206   auto input_data_type = OperatorWithKernel::IndicateVarDataType(
207 207   ctx, framework::GradVarName("Out"));
208 208   if (input_data_type == framework::proto::VarType::FP16) {
209     - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
210     - platform::errors::InvalidArgument(
211     - "float16 can only be used on GPU place"));
    209 + if (!(platform::is_gpu_place(ctx.GetPlace()) ||
    210 + platform::is_npu_place(ctx.GetPlace())))
    211 + PADDLE_THROW(platform::errors::InvalidArgument(
    212 + "float16 can only be used on GPU/NPU place"));
212 213   }
213 214
214 215   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,

0 commit comments

Comments
 (0)