@@ -19,6 +19,7 @@ limitations under the License. */
1919#include < memory>
2020#include < string>
2121#include < vector>
22+
2223#include " paddle/fluid/framework/conv_search_cache.h"
2324#include " paddle/fluid/framework/operator_kernel_configs.h"
2425#include " paddle/fluid/operators/conv_cudnn_op_cache.h"
@@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
101102 return max_algos;
102103}
103104
105+ template <typename PerfType, typename AlgoType>
106+ void ChooseAlgoByWorkspace (PerfType* perf_results, size_t perf_num,
107+ size_t workspace_byte, AlgoType* algo) {
108+ for (size_t i = 0 ; i < perf_num; ++i) {
109+ auto result = perf_results[i];
110+ if (result.status == CUDNN_STATUS_SUCCESS &&
111+ result.memory < workspace_byte) {
112+ *algo = result.algo ;
113+ VLOG (3 ) << " algo: " << result.algo << " , time: " << result.time
114+ << " ms, wksp = " << result.memory
115+ << " , status = " << result.status ;
116+ return ;
117+ }
118+ }
119+ VLOG (3 ) << " Can not find alog that requires memory < "
120+ << static_cast <double >(workspace_byte) / (1 << 20 ) << " MB" ;
121+ }
122+
104123template <typename PerfType, typename AlgoType>
105124void ChooseAlgo (const std::vector<PerfType>& perf_results,
106125 size_t workspace_byte, AlgoType* algo) {
@@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
219238
220239 if (workspace_size > workspace_size_limit) {
221240#if CUDNN_VERSION >= 8000
222- workspace_size_limit = workspace_size;
241+ // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
242+ ChooseAlgoByWorkspace<perf_t , algo_t >(perf_results.get (),
243+ kNUM_CUDNN_FWD_ALGS ,
244+ workspace_size_limit, &algo);
223245#else
224246 VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
225247 " the workspace size request("
@@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
316338 size_t workspace_size = 0 ;
317339 bool has_got_workspace_size = true ;
318340 algo_t algo;
319-
320341#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
321342 auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
322343 if (dev_ctx.GetComputeCapability () >= 70 && dtype == CUDNN_DATA_HALF) {
@@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
362383 if (workspace_size > workspace_size_limit) {
363384 has_got_workspace_size = false ;
364385#if CUDNN_VERSION >= 8000
365- // There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
366- // version.
367- workspace_size_limit = workspace_size;
386+ // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
387+ ChooseAlgoByWorkspace<perf_t , algo_t >(perf_results.get (),
388+ kNUM_CUDNN_BWD_DATA_ALGS ,
389+ workspace_size_limit, &algo);
368390#else
369391 VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
370392 " the workspace size request("
@@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
493515 workspace_size = GetWorkspaceSize (args, algo);
494516 if (workspace_size > workspace_size_limit) {
495517 workspace_size = workspace_size_limit;
518+ #if CUDNN_VERSION >= 8000
519+ // cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
520+ ChooseAlgoByWorkspace<perf_t , algo_t >(perf_results.get (),
521+ kNUM_CUDNN_BWD_FILTER_ALGS ,
522+ workspace_size_limit, &algo);
523+ #else
524+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
525+ " the workspace size request("
526+ << workspace_size << " ) exceeds the limit("
527+ << workspace_size_limit << " )" ;
528+ PADDLE_ENFORCE_CUDA_SUCCESS (
529+ platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
530+ args.handle , args.idesc .desc (), args.odesc .desc (),
531+ args.cdesc .desc (), args.wdesc .desc (),
532+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
533+ workspace_size_limit, &algo));
534+ #endif
496535 }
497536#else
498537 PADDLE_ENFORCE_CUDA_SUCCESS (
0 commit comments