Skip to content

Commit 5c4a024

Browse files
committed
fix workspace limit in cudnn-8
1 parent c52fe48 commit 5c4a024

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

paddle/fluid/operators/conv_cudnn_helper.h

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include <memory>
2020
#include <string>
2121
#include <vector>
22+
2223
#include "paddle/fluid/framework/conv_search_cache.h"
2324
#include "paddle/fluid/framework/operator_kernel_configs.h"
2425
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
@@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
101102
return max_algos;
102103
}
103104

105+
template <typename PerfType, typename AlgoType>
106+
void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num,
107+
size_t workspace_byte, AlgoType* algo) {
108+
for (size_t i = 0; i < perf_num; ++i) {
109+
auto result = perf_results[i];
110+
if (result.status == CUDNN_STATUS_SUCCESS &&
111+
result.memory < workspace_byte) {
112+
*algo = result.algo;
113+
VLOG(3) << " algo: " << result.algo << ", time: " << result.time
114+
<< " ms, wksp = " << result.memory
115+
<< ", status = " << result.status;
116+
return;
117+
}
118+
}
119+
VLOG(3) << "Can not find alog that requires memory < "
120+
<< static_cast<double>(workspace_byte) / (1 << 20) << " MB";
121+
}
122+
104123
template <typename PerfType, typename AlgoType>
105124
void ChooseAlgo(const std::vector<PerfType>& perf_results,
106125
size_t workspace_byte, AlgoType* algo) {
@@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
219238

220239
if (workspace_size > workspace_size_limit) {
221240
#if CUDNN_VERSION >= 8000
222-
workspace_size_limit = workspace_size;
241+
// cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
242+
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
243+
kNUM_CUDNN_FWD_ALGS,
244+
workspace_size_limit, &algo);
223245
#else
224246
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
225247
"the workspace size request("
@@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
316338
size_t workspace_size = 0;
317339
bool has_got_workspace_size = true;
318340
algo_t algo;
319-
320341
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
321342
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
322343
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
@@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
362383
if (workspace_size > workspace_size_limit) {
363384
has_got_workspace_size = false;
364385
#if CUDNN_VERSION >= 8000
365-
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
366-
// version.
367-
workspace_size_limit = workspace_size;
386+
// cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
387+
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
388+
kNUM_CUDNN_BWD_DATA_ALGS,
389+
workspace_size_limit, &algo);
368390
#else
369391
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
370392
"the workspace size request("
@@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
493515
workspace_size = GetWorkspaceSize(args, algo);
494516
if (workspace_size > workspace_size_limit) {
495517
workspace_size = workspace_size_limit;
518+
#if CUDNN_VERSION >= 8000
519+
// cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
520+
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
521+
kNUM_CUDNN_BWD_FILTER_ALGS,
522+
workspace_size_limit, &algo);
523+
#else
524+
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
525+
"the workspace size request("
526+
<< workspace_size << ") exceeds the limit("
527+
<< workspace_size_limit << ")";
528+
PADDLE_ENFORCE_CUDA_SUCCESS(
529+
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
530+
args.handle, args.idesc.desc(), args.odesc.desc(),
531+
args.cdesc.desc(), args.wdesc.desc(),
532+
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
533+
workspace_size_limit, &algo));
534+
#endif
496535
}
497536
#else
498537
PADDLE_ENFORCE_CUDA_SUCCESS(

0 commit comments

Comments
 (0)