Skip to content

Commit 60af285

Browse files
authored
[NPU] Support dataloader on npu place. (#31867)
1 parent 2a672f6 commit 60af285

File tree

7 files changed

+312
-9
lines changed

7 files changed

+312
-9
lines changed

paddle/fluid/operators/reader/buffered_reader.cc

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ BufferedReader::BufferedReader(
4343
buffer_size_(buffer_size),
4444
pin_memory_(pin_memory) {
4545
VLOG(1) << "BufferedReader";
46+
4647
#ifdef PADDLE_WITH_CUDA
4748
if (platform::is_gpu_place(place_) && !pin_memory) {
4849
int dev_idx = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
@@ -57,9 +58,25 @@ BufferedReader::BufferedReader(
5758
stream_ = platform::CudaStreamResourcePool::Instance().New(dev_idx);
5859
}
5960
#endif
61+
62+
#ifdef PADDLE_WITH_ASCEND_CL
63+
if (platform::is_npu_place(place_)) {
64+
int dev_idx = BOOST_GET_CONST(platform::NPUPlace, place_).device;
65+
compute_stream_ =
66+
((platform::NPUDeviceContext *)(platform::DeviceContextPool::Instance()
67+
.Get(place_)))
68+
->stream();
69+
events_.resize(buffer_size);
70+
for (auto &event : events_) {
71+
event = platform::NpuEventResourcePool::Instance().New(dev_idx);
72+
}
73+
stream_ = platform::NpuStreamResourcePool::Instance().New(dev_idx);
74+
}
75+
#endif
6076
is_same_place_ = false;
6177
cpu_buffer_.resize(buffer_size);
6278
cuda_buffer_.resize(buffer_size);
79+
npu_buffer_.resize(buffer_size);
6380
ReadTillBufferFullAsync();
6481
}
6582

@@ -186,6 +203,58 @@ void BufferedReader::ReadAsync(size_t i) {
186203
}
187204
}
188205
#endif
206+
207+
#ifdef PADDLE_WITH_ASCEND_CL
208+
if (platform::is_npu_place(place_)) {
209+
TensorVec &npu = npu_buffer_[i];
210+
if (npu.empty()) {
211+
npu.resize(cpu.size());
212+
} else {
213+
PADDLE_ENFORCE_EQ(
214+
npu.size(), cpu.size(),
215+
platform::errors::InvalidArgument(
216+
"Input tensor number on NPU and CPU devices are not matched. "
217+
"The number on NPU is %d, on CPU is %d",
218+
npu.size(), cpu.size()));
219+
}
220+
221+
std::vector<void *> npu_ptrs;
222+
npu_ptrs.reserve(cpu.size());
223+
for (size_t i = 0; i < cpu.size(); ++i) {
224+
npu[i].Resize(cpu[i].dims());
225+
npu[i].set_layout(cpu[i].layout());
226+
npu_ptrs.emplace_back(npu[i].mutable_data(place_, cpu[i].type()));
227+
}
228+
229+
platform::SetNPUDeviceId(
230+
BOOST_GET_CONST(platform::NPUPlace, place_).device);
231+
PADDLE_ENFORCE_NPU_SUCCESS(
232+
aclrtRecordEvent(events_[i].get(), compute_stream_));
233+
PADDLE_ENFORCE_NPU_SUCCESS(
234+
aclrtStreamWaitEvent(stream_.get(), events_[i].get()));
235+
236+
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
237+
for (size_t i = 0; i < cpu.size(); ++i) {
238+
auto cpu_place = cpu[i].place();
239+
auto cpu_ptr = cpu[i].data<void>();
240+
auto npu_ptr = npu_ptrs[i];
241+
auto size =
242+
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
243+
if ((platform::is_npu_place(cpu_place))) {
244+
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place_), npu_ptr,
245+
BOOST_GET_CONST(platform::NPUPlace, cpu_place), cpu_ptr,
246+
size, stream_.get());
247+
} else {
248+
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place_), npu_ptr,
249+
BOOST_GET_CONST(platform::CPUPlace, cpu_place), cpu_ptr,
250+
size, stream_.get());
251+
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_.get()));
252+
}
253+
npu[i].set_lod(cpu[i].lod());
254+
}
255+
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_.get()));
256+
}
257+
#endif
189258
return i;
190259
}));
191260
}
@@ -217,9 +286,13 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
217286
return;
218287
}
219288

220-
*out = std::move((platform::is_gpu_place(place_) && !is_same_place_)
221-
? cuda_buffer_[i]
222-
: cpu_buffer_[i]);
289+
if (platform::is_gpu_place(place_) && !is_same_place_) {
290+
*out = std::move(cuda_buffer_[i]);
291+
} else if (platform::is_npu_place(place_) && !is_same_place_) {
292+
*out = std::move(npu_buffer_[i]);
293+
} else {
294+
*out = std::move(cpu_buffer_[i]);
295+
}
223296

224297
// Do not push current position into ReadAsync. Push the previous position
225298
// Since all computation in fluid are async, change the data of

paddle/fluid/operators/reader/buffered_reader.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
#include "paddle/fluid/platform/cuda_resource_pool.h"
2626
#include "paddle/fluid/platform/gpu_info.h"
2727
#endif
28-
28+
#ifdef PADDLE_WITH_ASCEND_CL
29+
#include "paddle/fluid/platform/npu_info.h"
30+
#include "paddle/fluid/platform/npu_resource_pool.h"
31+
#endif
2932
namespace paddle {
3033
namespace operators {
3134
namespace reader {
@@ -67,12 +70,20 @@ class BufferedReader : public framework::DecoratedReader {
6770
bool is_same_place_;
6871
std::vector<TensorVec> cpu_buffer_;
6972
std::vector<TensorVec> cuda_buffer_;
73+
std::vector<TensorVec> npu_buffer_;
7074
size_t prev_pos_{-1UL};
75+
7176
#ifdef PADDLE_WITH_CUDA
7277
cudaStream_t compute_stream_;
7378
std::shared_ptr<platform::CudaStreamObject> stream_;
7479
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
7580
#endif
81+
82+
#ifdef PADDLE_WITH_ASCEND_CL
83+
aclrtStream compute_stream_;
84+
std::shared_ptr<platform::NpuStreamObject> stream_;
85+
std::vector<std::shared_ptr<platform::NpuEventObject>> events_;
86+
#endif
7687
};
7788

7889
} // namespace reader

paddle/fluid/platform/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ if(WITH_GPU)
135135
target_link_libraries(device_context cuda_resource_pool)
136136
endif()
137137

138+
if(WITH_ASCEND_CL)
139+
cc_library(npu_resource_pool SRCS npu_resource_pool.cc DEPS npu_info)
140+
target_link_libraries(device_context npu_resource_pool)
141+
endif()
142+
138143
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
139144

140145
cc_test(init_test SRCS init_test.cc DEPS device_context)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifdef PADDLE_WITH_ASCEND_CL
16+
#include "paddle/fluid/platform/npu_resource_pool.h"
17+
#include "paddle/fluid/platform/npu_info.h"
18+
19+
namespace paddle {
20+
namespace platform {
21+
22+
// Build one lazily-filled stream pool per visible NPU device.
NpuStreamResourcePool::NpuStreamResourcePool() {
  const int device_count = platform::GetNPUDeviceCount();
  pool_.reserve(device_count);
  for (int dev_idx = 0; dev_idx < device_count; ++dev_idx) {
    pool_.emplace_back(ResourcePool<NpuStreamObject>::Create(
        // Creator: bind the owning device, then ask ACL for a new stream.
        [dev_idx] {
          platform::SetNPUDeviceId(dev_idx);
          aclrtStream stream;
          PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateStream(&stream));
          return stream;
        },
        // Deleter: mirror creation — re-bind the device before destroying.
        [dev_idx](aclrtStream stream) {
          platform::SetNPUDeviceId(dev_idx);
          PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyStream(stream));
        }));
  }
}
41+
42+
// Process-wide singleton accessor (Meyers singleton: constructed on first
// use; initialization is thread-safe since C++11).
NpuStreamResourcePool& NpuStreamResourcePool::Instance() {
  static NpuStreamResourcePool instance;
  return instance;
}
46+
47+
// Borrow a stream bound to device `dev_idx`. The returned shared_ptr hands
// the stream back to the pool when released.
// Raises InvalidArgument/OutOfRange if dev_idx is outside [0, device count).
std::shared_ptr<NpuStreamObject> NpuStreamResourcePool::New(int dev_idx) {
  PADDLE_ENFORCE_GE(
      dev_idx, 0,
      platform::errors::InvalidArgument(
          "The dev_idx should be not less than 0, but got %d.", dev_idx));
  // Cast size() to int: dev_idx >= 0 is already enforced, and this avoids a
  // signed/unsigned comparison as well as passing a size_t where the "%d"
  // format specifier expects an int.
  PADDLE_ENFORCE_LT(
      dev_idx, static_cast<int>(pool_.size()),
      platform::errors::OutOfRange(
          "The dev_idx should be less than device count %d, but got %d.",
          static_cast<int>(pool_.size()), dev_idx));
  return pool_[dev_idx]->New();
}
59+
60+
// Build one lazily-filled event pool per visible NPU device.
NpuEventResourcePool::NpuEventResourcePool() {
  const int device_count = platform::GetNPUDeviceCount();
  pool_.reserve(device_count);
  for (int dev_idx = 0; dev_idx < device_count; ++dev_idx) {
    pool_.emplace_back(ResourcePool<NpuEventObject>::Create(
        // Creator: bind the owning device, then allocate an ACL event on it.
        [dev_idx] {
          platform::SetNPUDeviceId(dev_idx);
          aclrtEvent event;
          PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateEvent(&event));
          return event;
        },
        // Deleter: mirror creation — re-bind the device before destroying.
        [dev_idx](aclrtEvent event) {
          platform::SetNPUDeviceId(dev_idx);
          PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyEvent(event));
        }));
  }
}
79+
80+
// Process-wide singleton accessor (Meyers singleton: constructed on first
// use; initialization is thread-safe since C++11).
NpuEventResourcePool& NpuEventResourcePool::Instance() {
  static NpuEventResourcePool instance;
  return instance;
}
84+
85+
// Borrow an event bound to device `dev_idx`. The returned shared_ptr hands
// the event back to the pool when released.
// Raises InvalidArgument/OutOfRange if dev_idx is outside [0, device count).
std::shared_ptr<NpuEventObject> NpuEventResourcePool::New(int dev_idx) {
  PADDLE_ENFORCE_GE(
      dev_idx, 0,
      platform::errors::InvalidArgument(
          "The dev_idx should be not less than 0, but got %d.", dev_idx));
  // Cast size() to int: dev_idx >= 0 is already enforced, and this avoids a
  // signed/unsigned comparison as well as passing a size_t where the "%d"
  // format specifier expects an int.
  PADDLE_ENFORCE_LT(
      dev_idx, static_cast<int>(pool_.size()),
      platform::errors::OutOfRange(
          "The dev_idx should be less than device count %d, but got %d.",
          static_cast<int>(pool_.size()), dev_idx));
  return pool_[dev_idx]->New();
}
97+
98+
} // namespace platform
99+
} // namespace paddle
100+
101+
#endif
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#ifdef PADDLE_WITH_ASCEND_CL
18+
#include <memory>
19+
#include <type_traits>
20+
#include <vector>
21+
22+
#include "acl/acl.h"
23+
#include "paddle/fluid/platform/resource_pool.h"
24+
25+
namespace paddle {
26+
namespace platform {
27+
28+
using NpuStreamObject = std::remove_pointer<aclrtStream>::type;
29+
using NpuEventObject = std::remove_pointer<aclrtEvent>::type;
30+
31+
class NpuStreamResourcePool {
32+
public:
33+
std::shared_ptr<NpuStreamObject> New(int dev_idx);
34+
35+
static NpuStreamResourcePool &Instance();
36+
37+
private:
38+
NpuStreamResourcePool();
39+
40+
DISABLE_COPY_AND_ASSIGN(NpuStreamResourcePool);
41+
42+
private:
43+
std::vector<std::shared_ptr<ResourcePool<NpuStreamObject>>> pool_;
44+
};
45+
46+
class NpuEventResourcePool {
47+
public:
48+
std::shared_ptr<NpuEventObject> New(int dev_idx);
49+
50+
static NpuEventResourcePool &Instance();
51+
52+
private:
53+
NpuEventResourcePool();
54+
55+
DISABLE_COPY_AND_ASSIGN(NpuEventResourcePool);
56+
57+
private:
58+
std::vector<std::shared_ptr<ResourcePool<NpuEventObject>>> pool_;
59+
};
60+
61+
} // namespace platform
62+
} // namespace paddle
63+
64+
#endif
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import division
16+
17+
import sys
18+
import unittest
19+
import numpy as np
20+
21+
import paddle
22+
from ..unittests.test_multiprocess_dataloader_static import TestStaticDataLoader
23+
24+
paddle.enable_static()
25+
26+
27+
class TestStaticDataLoader(TestStaticDataLoader):
    # NPU variant of the static DataLoader test: reruns the base scenario on
    # NPUPlace(0), skipping CompiledProgram/ParallelExecutor (use_pe=False).
    def test_main(self):
        results = []
        places = [paddle.NPUPlace(0)]

        for num_workers in [0, 2]:
            print(self.__class__.__name__, places, num_workers)
            sys.stdout.flush()
            # BUG FIX: the base class defines run_main(num_workers, places,
            # use_pe=...); _run_main does not exist and raised AttributeError.
            ret = self.run_main(
                num_workers=num_workers, places=places, use_pe=False)
            results.append(ret)

        # Relative difference between the 0-worker and 2-worker losses must
        # stay within ~1%.
        diff = np.max(
            np.abs(results[0]['loss'] - results[1]['loss']) /
            np.abs(results[0]['loss']))
        self.assertLess(diff, 1e-2)
43+
44+
45+
if __name__ == '__main__':
46+
unittest.main()

python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
101101

102102

103103
class TestStaticDataLoader(unittest.TestCase):
104-
def run_main(self, num_workers, places):
104+
def run_main(self, num_workers, places, use_pe=True):
105105
scope = fluid.Scope()
106106
with fluid.scope_guard(scope):
107107
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
@@ -120,10 +120,13 @@ def run_main(self, num_workers, places):
120120
exe = fluid.Executor(place=places[0])
121121
exe.run(startup_prog)
122122

123-
prog = fluid.CompiledProgram(main_prog)
124-
if len(places) > 1:
125-
prog = prog.with_data_parallel(
126-
loss_name=loss.name, places=places)
123+
if use_pe:
124+
prog = fluid.CompiledProgram(main_prog)
125+
if len(places) > 1:
126+
prog = prog.with_data_parallel(
127+
loss_name=loss.name, places=places)
128+
else:
129+
prog = main_prog
127130

128131
step_list = []
129132
loss_list = []

0 commit comments

Comments
 (0)