-
Notifications
You must be signed in to change notification settings - Fork 6k
Expand file tree
/
Copy pathdevice_context.cc
More file actions
532 lines (456 loc) · 17 KB
/
device_context.cc
File metadata and controls
532 lines (456 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "glog/logging.h"
namespace paddle {
namespace platform {
// Out-of-class definition of the static DeviceContextPool::pool pointer; it
// starts out as nullptr. NOTE(review): presumably assigned by an init helper
// declared in the header -- confirm there.
DeviceContextPool* DeviceContextPool::pool = nullptr;
// Returns the device context registered for `place`.
//
// Raises through PADDLE_THROW when `place` has no entry in the pool.
// Otherwise the stored shared_future is resolved (the contexts are registered
// with std::launch::deferred, so the first resolution runs the constructor)
// and the raw pointer held by the resulting unique_ptr is returned.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  auto entry = device_contexts_.find(place);
  const bool is_missing = (entry == device_contexts_.end());
  if (is_missing) {
    PADDLE_THROW(
        "Place %s is not supported, Please re-compile with WITH_GPU "
        "option",
        place);
  }
  auto& pending_ctx = entry->second;
  return pending_ctx.get().get();
}
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
map_ptr,
platform::Place p) {
using PtrType = std::unique_ptr<DeviceContext>;
map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
// lazy evaluation. i.e., only create device context at
// first `Get`
return PtrType(new DevCtx(boost::get<PlaceType>(p)));
}));
}
// Builds the pool for the given places.
//
// `places` must be non-empty (checked with PADDLE_ENFORCE_GT). Duplicates are
// removed first, then one lazily-constructed context is registered per place:
//   - CPU place         -> MKLDNNDeviceContext when PADDLE_WITH_MKLDNN is
//                          defined, CPUDeviceContext otherwise;
//   - GPU place         -> CUDADeviceContext (PADDLE_WITH_CUDA builds only);
//   - CUDA pinned place -> CUDAPinnedDeviceContext (PADDLE_WITH_CUDA only).
// Requesting a GPU or pinned place in a build without PADDLE_WITH_CUDA raises
// through PADDLE_THROW. Places of any other kind are silently ignored.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  // De-duplicate up front so every place is registered exactly once.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested; the message used to
      // say 'CUDAPlace' here.
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}
// Out-of-class definition of the static DeviceTemporaryAllocator::allocators
// pointer; it starts out as nullptr. NOTE(review): presumably populated by
// the Instance() accessor used in CUDADeviceContext::Wait -- confirm in the
// header.
DeviceTemporaryAllocator* DeviceTemporaryAllocator::allocators = nullptr;
#ifdef PADDLE_WITH_CUDA
// Returns the temporary allocator bound to the (place, stream) pair, creating
// it on first use.
//
// `place` must be a GPU place (checked with PADDLE_ENFORCE). The lookup and
// the insertion both happen while holding mtx_. A newly created allocator
// gets a callback that synchronizes `stream` and then checks
// cudaGetLastError().
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::Place& place, const cudaStream_t& stream) {
  PADDLE_ENFORCE(platform::is_gpu_place(place));
  auto place_stream = std::make_pair(place, stream);
  std::unique_lock<std::mutex> lock(mtx_);
  auto it = device_allocator_.find(place_stream);
  if (it != device_allocator_.end()) {
    return *it->second;
  }
  // Keep the new allocator owned until it is stored in the map, so it is
  // released if SetCallback or the map insertion throws.
  std::unique_ptr<TemporaryAllocator> fresh(new TemporaryAllocator(place));
  fresh->SetCallback([stream]() {
    PADDLE_ENFORCE(cudaStreamSynchronize(stream));
    PADDLE_ENFORCE(cudaGetLastError());
  });
  // Create the map slot first: if that throws, `fresh` still owns the object.
  auto& slot = device_allocator_[place_stream];
  slot.reset(fresh.release());
  return *slot;
}
// CUDADeviceContext specialization: forwards to the (place, stream) overload
// using the place and stream reported by `dev_ctx`.
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::CUDADeviceContext& dev_ctx) {
  const auto ctx_place = dev_ctx.GetPlace();
  const auto ctx_stream = dev_ctx.stream();
  return Get(ctx_place, ctx_stream);
}
#endif
// CPUDeviceContext specialization: every CPU context shares the single
// cpu_allocator_ member, so `dev_ctx` itself is not inspected.
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::CPUDeviceContext& dev_ctx) {
  platform::TemporaryAllocator& shared_cpu_allocator = cpu_allocator_;
  return shared_cpu_allocator;
}
// Place-only overload. Only CPU places are accepted here (enforced below);
// the shared cpu_allocator_ member is returned.
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::Place& place) {
  PADDLE_ENFORCE(platform::is_cpu_place(place), "You should pass CPUPlace");
  platform::TemporaryAllocator& shared_cpu_allocator = cpu_allocator_;
  return shared_cpu_allocator;
}
// Default constructor: leaves place_ default-initialized and creates the
// Eigen default device owned by this context.
CPUDeviceContext::CPUDeviceContext() {
  auto* device = new Eigen::DefaultDevice();
  eigen_device_.reset(device);
}
// Constructs a CPU context bound to `place` and creates the Eigen default
// device owned by this context.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  auto* device = new Eigen::DefaultDevice();
  eigen_device_.reset(device);
}
// Returns a non-owning pointer to the Eigen device created in the
// constructor; ownership stays with this context.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}
// Returns the place this context was constructed with, as a generic Place.
Place CPUDeviceContext::GetPlace() const {
  return Place(place_);
}
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice : public Eigen::StreamInterface {
public:
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
Eigen::initializeDeviceProp();
}
~EigenCudaStreamDevice() override {}
void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
stream_ = cuda_stream;
place_ = place;
device_prop_ = &Eigen::m_deviceProperties[place.device];
}
const cudaStream_t& stream() const override { return *stream_; }
const cudaDeviceProp& deviceProperties() const override {
return *device_prop_;
}
void* allocate(size_t num_bytes) const override {
if (UNLIKELY(num_bytes == 0)) {
return nullptr;
}
auto buf = paddle::memory::Alloc(place_, num_bytes);
void* retv = buf->ptr();
{
std::lock_guard<std::mutex> lock(mtx_);
allocations_.emplace(retv, std::move(buf));
}
return retv;
}
void deallocate(void* buffer) const override {
if (LIKELY(buffer)) {
std::lock_guard<std::mutex> lock(mtx_);
allocations_.erase(buffer);
}
}
void* scratchpad() const override {
if (scratch_ == NULL) {
scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
}
return scratch_;
}
unsigned int* semaphore() const override {
if (semaphore_ == NULL) {
char* scratch =
static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
semaphore_ = reinterpret_cast<unsigned int*>(scratch);
PADDLE_ENFORCE(
cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
}
return semaphore_;
}
private:
CUDAPlace place_;
const cudaStream_t* stream_; // not owned;
const cudaDeviceProp* device_prop_; // not owned;
mutable void* scratch_;
mutable unsigned int* semaphore_;
mutable std::mutex mtx_; // to protect allocations_
mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
// Creates a cuDNN handle for `place` and attaches it to `*stream`.
// `stream` is stored as a pointer and dereferenced here and later in
// ReallocateWorkspace. The workspace starts out empty (nullptr).
// Note: cudaSetDevice changes the calling thread's current device and it is
// not switched back within this constructor.
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), stream_(stream), place_(place) {
PADDLE_ENFORCE(cudaSetDevice(place_.device));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
// Destroys the cuDNN handle created in the constructor.
// NOTE(review): PADDLE_ENFORCE appears to raise on failure; if so, a failing
// cudnnDestroy would throw out of a destructor -- confirm this is acceptable.
CudnnHolder::~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
// Grows the cuDNN workspace to at least `required_workspace_len` bytes.
// Nothing happens when the current workspace is already large enough.
// Otherwise any existing workspace is released only after the stream has been
// synchronized, and a new one is allocated on place_.
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
  const bool needs_growth = required_workspace_len > WorkspaceSize();
  if (needs_growth) {
    if (workspace_ != nullptr) {
      // Work queued on the stream may still be using the current workspace.
      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
      workspace_.reset();
    }
    workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
  }
}
// Builds the CUDA device context for `place`.
// In order: queries device properties, creates the CUDA stream, wires up the
// Eigen stream/device wrappers on that stream, creates the cuBLAS handle(s),
// logs driver/runtime/cuDNN versions, warns on version mismatches, and
// finally creates the stream callback manager. The cuDNN holder is left null
// here and created on demand by cudnn_holder().
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
// NOTE(review): presumably makes place_.device current for the rest of this
// constructor -- confirm against CUDADeviceGuard.
CUDADeviceGuard guard(place_.device);
compute_capability_ = GetCUDAComputeCapability(place_.device);
multi_process_ = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
// The Eigen wrappers keep a pointer to stream_, so the stream must exist
// before they are created.
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
// The tensor-core handle is only created when TensorCoreAvailable() holds
// and the code is compiled against CUDA >= 9.0; otherwise it stays null,
// which is what tensor_core_available() reports.
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
driver_version_ = GetCUDADriverVersion(place_.device);
runtime_version_ = GetCUDARuntimeVersion(place_.device);
// Printed as <v / 1000>.<(v % 100) / 10>.
// NOTE(review): assumes CUDA's 1000 * major + 10 * minor encoding -- confirm.
LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
<< ", CUDA Capability: " << compute_capability_
<< ", Driver API Version: " << driver_version_ / 1000
<< "." << (driver_version_ % 100) / 10
<< ", Runtime API Version: "
<< runtime_version_ / 1000 << "."
<< (runtime_version_ % 100) / 10;
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
<< (cudnn_dso_ver % 1000) / 100 << ".";
{
// Check CUDA/cuDNN version compatibility. Versions are reduced to
// major * 10 + minor before being compared.
// NOTE(review): local_cuda_version is derived from driver_version_, although
// the warning below describes it as the "CUDA runtime version".
auto local_cuda_version =
(driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
auto compile_cuda_version =
(CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
if (local_cuda_version < compile_cuda_version) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDA "
<< compile_cuda_version / 10 << "." << compile_cuda_version % 10
<< ", but CUDA runtime version in your machine is "
<< local_cuda_version / 10 << "." << local_cuda_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
// The cuDNN comparison only runs when the cuDNN library is available.
if (dynload::HasCUDNN()) {
auto local_cudnn_version = cudnn_dso_ver / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
}
}
callback_manager_.reset(new StreamCallbackManager(stream_));
}
// Tears the context down: selects the device, drains pending work and stream
// callbacks, releases the cuBLAS handles and Eigen wrappers, and only then
// destroys the stream they were created on. On non-Windows builds a non-null
// NCCL communicator is destroyed last.
// NOTE(review): the constructor uses CUDADeviceGuard, while this calls
// SetDeviceId directly and does not visibly switch the device back -- confirm
// that is intended.
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
#if !defined(_WIN32)
if (nccl_comm_) {
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
}
#endif
}
// Returns the CUDA place this context was constructed with, as a generic
// Place.
Place CUDADeviceContext::GetPlace() const {
  return Place(place_);
}
void CUDADeviceContext::Wait() const {
auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([this]() {
cudaError_t e_sync = cudaStreamSynchronize(stream_);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno:" << e_sync;
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
});
}
// Returns the compute capability value cached by the constructor.
int CUDADeviceContext::GetComputeCapability() const {
  const int capability = compute_capability_;
  return capability;
}
// Returns the multiprocessor count times the maximum threads per
// multiprocessor, both cached by the constructor.
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  const auto total_threads = multi_process_ * max_threads_per_mp_;
  return total_threads;
}
// Returns a non-owning pointer to the Eigen GPU device created in the
// constructor; ownership stays with this context.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  Eigen::GpuDevice* const device = eigen_device_.get();
  return device;
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
// Returns the cuDNN holder, creating it on the first call (guarded by
// std::call_once). When dynload::HasCUDNN() is false no holder is created and
// nullptr is returned.
CudnnHolder* CUDADeviceContext::cudnn_holder() const {
  std::call_once(init_cudnn_, [this]() {
    if (!dynload::HasCUDNN()) {
      return;
    }
    cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
  });
  return cudnn_holder_.get();
}
// Returns the cuDNN handle owned by the lazily-created holder.
// NOTE(review): cudnn_holder() yields nullptr when cuDNN is unavailable and
// the result is dereferenced here without a check.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  CudnnHolder* const holder = cudnn_holder();
  return holder->cudnn_handle();
}
// Wraps the lazily-created cuDNN holder in a CudnnWorkspaceHandle.
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  CudnnHolder* const holder = cudnn_holder();
  return CudnnWorkspaceHandle(holder);
}
// Returns the CUDA stream created by the constructor. The stream remains
// owned by this context and is destroyed in the destructor.
cudaStream_t CUDADeviceContext::stream() const {
  const cudaStream_t owned_stream = stream_;
  return owned_stream;
}
// Default constructor: leaves place_ default-initialized and creates the
// Eigen default device owned by this context.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  auto* device = new Eigen::DefaultDevice();
  eigen_device_.reset(device);
}
// Constructs a pinned-memory context bound to `place` and creates the Eigen
// default device owned by this context.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  auto* device = new Eigen::DefaultDevice();
  eigen_device_.reset(device);
}
// Returns a non-owning pointer to the Eigen device created in the
// constructor; ownership stays with this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}
// Returns the pinned place this context was constructed with, as a generic
// Place.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return Place(place_);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
// Constructs an MKLDNN-enabled CPU context for `place`: initializes the
// CPUDeviceContext base, creates the mkldnn CPU engine with index 0, and
// allocates an empty blob map together with the mutex that guards it.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0) {
  auto* blob_map = new BlobMap();
  p_blobmap_.reset(blob_map);
  auto* blob_mutex = new std::mutex();
  p_mutex_.reset(blob_mutex);
}
namespace {
// Per-thread MKLDNN session id; starts at kMKLDNNSessionID_Default.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
// Per-thread string describing the current input shape.
// - Fixed-shape execution keeps the default empty string.
// - Dynamic-shape execution sets it to a user-specific value.
thread_local std::string cur_input_shape_str;
// Per-thread limit on how many distinct input shapes are cached for MKLDNN.
// The default of 1 corresponds to a fixed input shape (no dynamic shapes).
thread_local int cur_input_shape_cache_capacity = 1;
}  // namespace
// Sets the calling thread's current MKLDNN session id.
void set_cur_mkldnn_session_id(size_t sid) {
  cur_mkldnn_session_id = sid;
}
// Returns the calling thread's current MKLDNN session id.
size_t get_cur_mkldnn_session_id() {
  const size_t session_id = cur_mkldnn_session_id;
  return session_id;
}
// Sets the calling thread's current input-shape string, which SetBlob/GetBlob
// use as the second-level cache key.
// The argument is taken by value, so it is moved into the thread-local
// instead of being copied a second time.
void set_cur_input_shape_str(std::string input_shape_str) {
  cur_input_shape_str = std::move(input_shape_str);
}
// Sets the calling thread's limit on the number of cached input shapes; the
// value is consulted by SetBlob in cache-clearing mode.
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
  const int new_capacity = input_shape_cache_capacity;
  cur_input_shape_cache_capacity = new_capacity;
}
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(cur_mkldnn_session_id);
if (map_it == pMap->end()) {
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<< cur_mkldnn_session_id;
}
return map_it->second->size();
}
void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;
int sid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
(*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else {
sBlob = map_it->second;
}
// Find KeyBlob for current input shape
auto key_it = sBlob->find(cur_input_shape_str);
if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((sid == kMKLDNNSessionID_CacheClearing) &&
(sBlob->size() ==
static_cast<size_t>(cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
}
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob;
} else {
pBlob = key_it->second;
}
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
(*pBlob)[name] = data;
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope
return;
}
// Looks up the blob stored under `name` for the calling thread's current
// session id and input-shape string.
// Returns nullptr when the session, the input shape, or the blob name is not
// cached. All map accesses happen under p_mutex_.
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<ShapeBlob> sBlob = nullptr;
  std::shared_ptr<KeyBlob> pBlob = nullptr;

  int sid = platform::get_cur_mkldnn_session_id();

  std::lock_guard<std::mutex> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Log the session id under "sid=" and the missing shape string under its
    // own label; the shape string used to be printed as if it were the sid.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
#endif
} // namespace platform
} // namespace paddle