Skip to content

Commit 09e965f

Browse files
committed
fix the ut,test=develop, test=kunlun
1 parent 5acf466 commit 09e965f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

paddle/fluid/operators/gather_op_xpu.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -87,13 +87,21 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
8787
auto *index = ctx.Input<Tensor>("Index");
8888
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
8989
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
90+
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
9091

9192
if (ctx.HasInput("Axis")) {
9293
PADDLE_THROW(platform::errors::InvalidArgument(
9394
"Now, it doesn't support XPU with Axis."));
9495
}
9596

9697
dx->mutable_data<T>(ctx.GetPlace());
98+
const int zero = 0;
99+
int r_dx = xpu::memset(dev_ctx.x_context(), dx->data<T>(), zero,
100+
dx->numel() * sizeof(T));
101+
PADDLE_ENFORCE_EQ(
102+
r_dx, xpu::Error_t::SUCCESS,
103+
platform::errors::External("XPU kernel error! error code=%d", r_dx));
104+
97105
if (dout->numel() == 0) {
98106
return;
99107
}
@@ -127,7 +135,6 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
127135
int index_size = index_dims[0];
128136
int slice_size = dout->numel() / dout->dims()[0];
129137

130-
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
131138
int r = xpu::scatter<T>(dev_ctx.x_context(), dout->data<T>(),
132139
index->data<int>(), index_size, slice_size,
133140
dx->data<T>(), overwrite);

0 commit comments

Comments
 (0)