Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ee5f699

Browse files
DickJC123 and eric-haibin-lin
authored and committed
CudnnFind() usage improvements (#12804)
* Add mx.context.gpu_memory_info() to python api for flexible tests. * Add test_gluon_gpu.py:test_large_models to show cudnnFind headroom issue. * Output model sizes tried by test_gluon_gpu.py:test_large_models. * Fix perl interface to MXGetGPUMemoryInformation. * Increase difficulty of test_gluon_gpu.py:test_large_models. * Forgot a file in fix for perl. * Modify test to pass on no-cudnn CI runner. * Mutex algo reg updates, serialize cudnnFind calls. * Fix for cudnnFind memory headroom issue. * Fix cpplint. * Respond to reviewers comments. * Guard against improper MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE values. * Fix potentially unassigned var.
1 parent fef9b5c commit ee5f699

File tree

12 files changed

+707
-491
lines changed

12 files changed

+707
-491
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,4 @@ List of Contributors
187187
* [LuckyPigeon](https://github.com/LuckyPigeon)
188188
* [Anton Chernov](https://github.com/lebeg)
189189
* [Denisa Roberts](https://github.com/D-Roberts)
190+
* [Dick Carter](https://github.com/DickJC123)

docs/faq/env_var.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
6767
* MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF
6868
- Values: Int ```(default=24)```
6969
- The cutoff threshold that decides the rounding strategy. Let's denote the threshold as T. If the memory size is smaller than `2 ** T` (by default, it's 2 ** 24 = 16MB), it rounds to the smallest `2 ** n` that is larger than the requested memory size; if the memory size is larger than `2 ** T`, it rounds to the next k * 2 ** T.
70+
* MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE
71+
- Values: Int ```(default=2097152)```
72+
- When using the naive pool type, memory allocations larger than this threshold are rounded up to a multiple of this value.
73+
- The default was chosen to minimize global memory fragmentation within the GPU driver. Set this to 1 to disable.
7074

7175
## Engine Type
7276

include/mxnet/base.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,11 @@ struct Context {
225225
/*!
226226
* \brief get the free and total available memory on a GPU
227227
* \param dev the GPU number to query
228-
* \param free_mem pointer to the integer holding free GPU memory
229-
* \param total_mem pointer to the integer holding total GPU memory
228+
* \param free_mem pointer to the uint64_t holding free GPU memory
229+
* \param total_mem pointer to the uint64_t holding total GPU memory
230230
* \return No return value
231231
*/
232-
inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
232+
inline static void GetGPUMemoryInformation(int dev, uint64_t *free, uint64_t *total);
233233
/*!
234234
* Create a pinned CPU context.
235235
* \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() {
334334
#endif
335335
}
336336

337-
inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
338-
int *total_mem) {
337+
inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
338+
uint64_t *total_mem) {
339339
#if MXNET_USE_CUDA
340340

341341
size_t memF, memT;
@@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
354354
e = cudaSetDevice(curDevice);
355355
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
356356

357-
*free_mem = static_cast<int>(memF);
358-
*total_mem = static_cast<int>(memT);
357+
*free_mem = static_cast<uint64_t>(memF);
358+
*total_mem = static_cast<uint64_t>(memT);
359359

360360
#else
361361
LOG(FATAL)

include/mxnet/c_api.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,13 +441,23 @@ MXNET_DLL int MXGetGPUCount(int* out);
441441

442442
/*!
443443
* \brief get the free and total available memory on a GPU
444+
* Note: Deprecated, use MXGetGPUMemoryInformation64 instead.
444445
* \param dev the GPU number to query
445446
* \param free_mem pointer to the integer holding free GPU memory
446447
* \param total_mem pointer to the integer holding total GPU memory
447448
* \return 0 when success, -1 when failure happens
448449
*/
449450
MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);
450451

452+
/*!
453+
* \brief get the free and total available memory on a GPU
454+
* \param dev the GPU number to query
455+
* \param free_mem pointer to the uint64_t holding free GPU memory
456+
* \param total_mem pointer to the uint64_t holding total GPU memory
457+
* \return 0 when success, -1 when failure happens
458+
*/
459+
MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem);
460+
451461
/*!
452462
* \brief get the MXNet library version as an integer
453463
* \param pointer to the integer holding the version number

perl-package/AI-MXNetCAPI/mxnet.i

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,13 +344,23 @@ int MXGetGPUCount(int* out);
344344

345345
/*!
346346
* \brief get the free and total available memory on a GPU
347+
* Note: deprecated, use MXGetGPUMemoryInformation64().
347348
* \param dev the GPU number to query
348349
* \param free_mem pointer to the integer holding free GPU memory
349350
* \param total_mem pointer to the integer holding total GPU memory
350351
* \return 0 when success, -1 when failure happens
351352
*/
352353
int MXGetGPUMemoryInformation(int dev, int *out, int *out);
353354

355+
/*!
356+
* \brief get the free and total available memory on a GPU
357+
* \param dev the GPU number to query
358+
* \param free_mem pointer to the uint64_t holding free GPU memory
359+
* \param total_mem pointer to the uint64_t holding total GPU memory
360+
* \return 0 when success, -1 when failure happens
361+
*/
362+
int MXGetGPUMemoryInformation64(int dev, uint64_t *out, uint64_t *out);
363+
354364

355365
//-------------------------------------
356366
// Part 1: NDArray creation and deletion

python/mxnet/context.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,30 @@ def num_gpus():
258258
check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
259259
return count.value
260260

261+
def gpu_memory_info(device_id=0):
262+
"""Query CUDA for the free and total bytes of GPU global memory.
263+
264+
Parameters
265+
----------
266+
device_id : int, optional
267+
The device id of the GPU device.
268+
269+
Raises
270+
------
271+
Will raise an exception on any CUDA error.
272+
273+
Returns
274+
-------
275+
(free, total) : (int, int)
276+
The free and total GPU memory in bytes.
277+
278+
"""
279+
free = ctypes.c_uint64()
280+
total = ctypes.c_uint64()
281+
dev_id = ctypes.c_int(device_id)
282+
check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free), ctypes.byref(total)))
283+
return (free.value, total.value)
284+
261285
def current_context():
262286
"""Returns the current context.
263287

src/c_api/c_api.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,18 @@ int MXGetGPUCount(int* out) {
122122
API_END();
123123
}
124124

125+
// Deprecated: use MXGetGPUMemoryInformation64() instead.
125126
int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
127+
API_BEGIN();
128+
uint64_t free_mem64 = 0UL;
129+
uint64_t total_mem64 = 0UL;
130+
Context::GetGPUMemoryInformation(dev, &free_mem64, &total_mem64);
131+
*free_mem = static_cast<int>(free_mem64);
132+
*total_mem = static_cast<int>(total_mem64);
133+
API_END();
134+
}
135+
136+
int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem) {
126137
API_BEGIN();
127138
Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
128139
API_END();

src/operator/nn/cudnn/cudnn_algoreg-inl.h

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include <mutex>
3131
#include <string>
3232
#include <vector>
33+
#include <functional>
34+
#include <utility>
3335
#include "../../../common/cuda_utils.h"
3436
#include "../convolution-inl.h"
3537
#include "../deconvolution-inl.h"
@@ -65,7 +67,11 @@ class CuDNNAlgo {
6567
template<typename ParamType>
6668
class CuDNNAlgoReg {
6769
public:
68-
bool Find(const ParamType &param,
70+
using AlgoSetter_t = std::function<void(CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *,
71+
CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *,
72+
CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *)>;
73+
74+
void FindOrElseRegister(const ParamType &param,
6975
const std::vector<TShape> &in_shape,
7076
const std::vector<TShape> &out_shape,
7177
cudnnDataType_t cudnn_data_type,
@@ -75,7 +81,8 @@ class CuDNNAlgoReg {
7581
bool add_to_weight,
7682
CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
7783
CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
78-
CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
84+
CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt,
85+
const AlgoSetter_t &algo_setter) {
7986
CHECK(in_shape.size() == 2 || in_shape.size() == 3);
8087
ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
8188
cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
@@ -85,45 +92,28 @@ class CuDNNAlgoReg {
8592
*fwd = i->second.fwd;
8693
*bwd = i->second.bwd;
8794
*flt = i->second.flt;
88-
return true;
89-
}
90-
return false;
91-
}
92-
93-
void Register(const ParamType &param,
94-
const std::vector<TShape> &in_shape,
95-
const std::vector<TShape> &out_shape,
96-
cudnnDataType_t cudnn_data_type,
97-
cudnnDataType_t cudnn_forward_compute_type,
98-
cudnnDataType_t cudnn_backward_compute_type,
99-
int sm_arch,
100-
bool add_to_weight,
101-
const CuDNNAlgo<cudnnConvolutionFwdAlgo_t> &fwd,
102-
const CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> &bwd,
103-
const CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> &flt) {
104-
CHECK(in_shape.size() == 2 || in_shape.size() == 3);
105-
ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
106-
cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
107-
std::lock_guard<std::mutex> guard(lock_);
108-
if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
109-
LOG(INFO) << "Running performance tests to find the best convolution "
110-
"algorithm, "
111-
"this can take a while... (setting env variable "
112-
"MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
113-
if (reg_.size() >= 1000) {
114-
// Many people are very concerned about this warning, so change the warning once.
115-
if (!is_warning_autotune_) {
116-
LOG(INFO)
117-
<< "If you see this message in the middle of training, you are "
118-
"probably using bucketing. Consider setting env variable "
119-
"MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning.";
120-
is_warning_autotune_ = true;
95+
} else {
96+
if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
97+
LOG(INFO) << "Running performance tests to find the best convolution "
98+
"algorithm, "
99+
"this can take a while... (setting env variable "
100+
"MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
101+
if (reg_.size() >= 1000) {
102+
// Many people are very concerned about this warning, so change the warning once.
103+
if (!is_warning_autotune_) {
104+
LOG(INFO)
105+
<< "If you see this message in the middle of training, you are "
106+
"probably using bucketing. Consider setting env variable "
107+
"MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning.";
108+
is_warning_autotune_ = true;
109+
}
121110
}
122111
}
112+
// Call provided function to determine the algos- likely uses cudnnFind() or cudnnGet()
113+
algo_setter(fwd, bwd, flt);
114+
// Save result so future lookups hit in this registry
115+
reg_.insert(std::pair<ParamKey, CudnnAlgorithms>(key, CudnnAlgorithms{*fwd, *bwd, *flt}));
123116
}
124-
reg_[key].fwd = fwd;
125-
reg_[key].bwd = bwd;
126-
reg_[key].flt = flt;
127117
}
128118

129119
static CuDNNAlgoReg *Get();

0 commit comments

Comments
 (0)