File tree Expand file tree Collapse file tree 3 files changed +14
-5
lines changed
Expand file tree Collapse file tree 3 files changed +14
-5
lines changed Original file line number Diff line number Diff line change @@ -76,6 +76,7 @@ BufferedReader::BufferedReader(
7676 is_same_place_ = false ;
7777 cpu_buffer_.resize (buffer_size);
7878 cuda_buffer_.resize (buffer_size);
79+ npu_buffer_.resize (buffer_size);
7980 ReadTillBufferFullAsync ();
8081}
8182
@@ -254,7 +255,6 @@ void BufferedReader::ReadAsync(size_t i) {
254255 PADDLE_ENFORCE_NPU_SUCCESS (aclrtSynchronizeStream (stream_.get ()));
255256 }
256257#endif
257-
258258 return i;
259259 }));
260260}
@@ -286,9 +286,13 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
286286 return ;
287287 }
288288
289- *out = std::move ((platform::is_gpu_place (place_) && !is_same_place_)
290- ? cuda_buffer_[i]
291- : cpu_buffer_[i]);
289+ if (platform::is_gpu_place (place_) && !is_same_place_) {
290+ *out = cuda_buffer_[i];
291+ } else if (platform::is_npu_place (place_) && !is_same_place_) {
292+ *out = npu_buffer_[i];
293+ } else {
294+ *out = cpu_buffer_[i];
295+ }
292296
293297 // Do not push current position into ReadAsync. Push the previous position
294298 // Since all computation in fluid are async, change the data of
Original file line number Diff line number Diff line change @@ -135,6 +135,11 @@ if(WITH_GPU)
135135 target_link_libraries (device_context cuda_resource_pool)
136136endif ()
137137
138+ if (WITH_ASCEND_CL)
139+ cc_library(npu_resource_pool SRCS npu_resource_pool.cc DEPS npu_info)
140+ target_link_libraries (device_context npu_resource_pool)
141+ endif ()
142+
138143nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
139144
140145cc_test(init_test SRCS init_test.cc DEPS device_context)
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ NpuEventResourcePool::NpuEventResourcePool() {
6969 };
7070
7171 auto deleter = [dev_idx](aclrtEvent event) {
72- platform::SetDeviceId (dev_idx);
72+ platform::SetNPUDeviceId (dev_idx);
7373 PADDLE_ENFORCE_NPU_SUCCESS (aclrtDestroyEvent (event));
7474 };
7575
You can’t perform that action at this time.
0 commit comments