|
1 | | -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
| 1 | +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
2 | 2 |
|
3 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | you may not use this file except in compliance with the License. |
@@ -87,13 +87,21 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> { |
87 | 87 | auto *index = ctx.Input<Tensor>("Index"); |
88 | 88 | auto *dx = ctx.Output<Tensor>(framework::GradVarName("X")); |
89 | 89 | auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 90 | + auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); |
90 | 91 |
|
91 | 92 | if (ctx.HasInput("Axis")) { |
92 | 93 | PADDLE_THROW(platform::errors::InvalidArgument( |
93 | 94 | "Now, it doesn't support XPU with Axis.")); |
94 | 95 | } |
95 | 96 |
|
96 | 97 | dx->mutable_data<T>(ctx.GetPlace()); |
| 98 | + const int zero = 0; |
| 99 | + int r_dx = xpu::memset(dev_ctx.x_context(), dx->data<T>(), zero, |
| 100 | + dx->numel() * sizeof(T)); |
| 101 | + PADDLE_ENFORCE_EQ( |
| 102 | + r_dx, xpu::Error_t::SUCCESS, |
| 103 | + platform::errors::External("XPU kernel error! error code=%d", r_dx)); |
| 104 | + |
97 | 105 | if (dout->numel() == 0) { |
98 | 106 | return; |
99 | 107 | } |
@@ -127,7 +135,6 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> { |
127 | 135 | int index_size = index_dims[0]; |
128 | 136 | int slice_size = dout->numel() / dout->dims()[0]; |
129 | 137 |
|
130 | | - auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); |
131 | 138 | int r = xpu::scatter<T>(dev_ctx.x_context(), dout->data<T>(), |
132 | 139 | index->data<int>(), index_size, slice_size, |
133 | 140 | dx->data<T>(), overwrite); |
|
0 commit comments