@@ -209,19 +209,19 @@ void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
209209
210210 platform::SetNPUDeviceId (dst_place.device );
211211
212- // NOTE(ascendrc): NPU memcpy async from host to device is a "real" async,
213- // which is different from CUDA. In Paddle, when async is called, "sync"
214- // is run actually, which means Paddle doesn't fully supported async.
215- // TODO(ascendrc): Support NPU memcpy async for better performance.
216- stream = nullptr ;
217-
218212 VLOG (4 ) << " memory::Copy " << num << " Bytes from " << src_place << " to "
219213 << dst_place << " by thream(" << stream << " )" ;
220214
221215 if (stream) {
222216 platform::RecordEvent record_event (" NpuMemcpyAsync:CPU->NPU" );
223217 platform::NPUMemcpyAsync (dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
224218 } else {
219+ // On NPU, an async operation after a sync operation is OK, but a sync
220+ // operation after an async one is not, since the async operation may not be
221+ // done yet. So we need to wait before performing the sync operation.
222+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
223+ static_cast <platform::NPUDeviceContext*>(pool.Get (dst_place))->Wait ();
224+
225225 platform::RecordEvent record_event (" NpuMemcpySync:CPU->NPU" );
226226 platform::NPUMemcpySync (dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
227227 }
@@ -237,19 +237,16 @@ void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
237237
238238 platform::SetNPUDeviceId (src_place.device );
239239
240- // NOTE(ascendrc): NPU memcpy async from device to host is a "real" async,
241- // which is different from CUDA. In Paddle, when async is called, "sync"
242- // is run actually, which means Paddle doesn't fully supported async.
243- // TODO(ascendrc): Support NPU memcpy async for better performance.
244- stream = nullptr ;
245-
246240 VLOG (4 ) << " memory::Copy " << num << " Bytes from " << src_place << " to "
247241 << dst_place << " by thream(" << stream << " )" ;
248242
249243 if (stream) {
250244 platform::RecordEvent record_event (" NpuMemcpyAsync:NPU->CPU" );
251245 platform::NPUMemcpyAsync (dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
252246 } else {
247+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
248+ static_cast <platform::NPUDeviceContext*>(pool.Get (dst_place))->Wait ();
249+
253250 platform::RecordEvent record_event (" GpuMemcpySync:NPU->CPU" );
254251 platform::NPUMemcpySync (dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
255252 }
@@ -272,6 +269,10 @@ void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
272269 platform::NPUMemcpyAsync (dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
273270 stream);
274271 } else {
272+ platform::DeviceContextPool& pool =
273+ platform::DeviceContextPool::Instance ();
274+ static_cast <platform::NPUDeviceContext*>(pool.Get (dst_place))->Wait ();
275+
275276 platform::RecordEvent record_event (" NpuMemcpySync(same_npu):NPU->NPU" );
276277 platform::NPUMemcpySync (dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
277278 }
@@ -286,6 +287,10 @@ void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
286287 platform::NPUMemcpyAsync (dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
287288 stream);
288289 } else {
290+ platform::DeviceContextPool& pool =
291+ platform::DeviceContextPool::Instance ();
292+ static_cast <platform::NPUDeviceContext*>(pool.Get (dst_place))->Wait ();
293+
289294 platform::RecordEvent record_event (" NpuMemcpyPeerSync:NPU->NPU" );
290295 platform::NPUMemcpySync (dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
291296 }
0 commit comments