Skip to content

Commit cb73fee

Browse files
authored
add Wait after TensorCopy (#34005)
1 parent cbf22d6 commit cb73fee

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

paddle/fluid/operators/uniform_random_op_npu.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
5656
"unsupport type: %s.",
5757
framework::ToTypeName(out_var->Type())));
5858
}
59-
T *data = tensor->mutable_data<T>(ctx.GetPlace());
60-
59+
tensor->mutable_data<T>(ctx.GetPlace());
6160
int64_t size = tensor->numel();
62-
std::unique_ptr<T[]> data_cpu(new T[size]);
61+
62+
Tensor cpu_tensor(tensor->type());
63+
cpu_tensor.Resize(tensor->dims());
64+
T *data_cpu = cpu_tensor.mutable_data<T>(platform::CPUPlace());
65+
6366
std::uniform_real_distribution<T> dist(
6467
static_cast<T>(ctx.Attr<float>("min")),
6568
static_cast<T>(ctx.Attr<float>("max")));
@@ -90,12 +93,10 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
9093
}
9194

9295
// copy to NPU
93-
auto stream =
94-
ctx.template device_context<paddle::platform::NPUDeviceContext>()
95-
.stream();
96-
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, ctx.GetPlace()), data,
97-
platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
98-
size * sizeof(T), stream);
96+
framework::TensorCopy(
97+
cpu_tensor, ctx.GetPlace(),
98+
ctx.template device_context<platform::DeviceContext>(), tensor);
99+
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
99100
}
100101
};
101102

python/paddle/fluid/tests/unittests/npu/test_uniform_random_op_npu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def init_dtype(self):
6767
self.dtype = np.float32
6868

6969
def test_check_output(self):
70-
self.check_output_customized(self.verify_output)
70+
self.check_output_customized(self.verify_output, self.place)
7171

7272
def verify_output(self, outs):
7373
hist, prob = self.output_hist(np.array(outs[0]))

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1357,8 +1357,10 @@ def check_output(self,
13571357
if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST:
13581358
self.check_compile_vs_runtime(fetch_list, outs)
13591359

1360-
def check_output_customized(self, checker):
1360+
def check_output_customized(self, checker, custom_place=None):
13611361
places = self._get_places()
1362+
if custom_place:
1363+
places.append(custom_place)
13621364
for place in places:
13631365
outs = self.calc_output(place)
13641366
outs = [np.array(out) for out in outs]

0 commit comments

Comments
 (0)