@@ -128,15 +128,20 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
128128 int n, int c, int h, int w) {
129129 CUDNN_CHECK (cudnnCreateFilterDescriptor (desc));
130130 CUDNN_CHECK (cudnnSetFilter4dDescriptor (*desc, dataType<Dtype>::type,
131- n, c, h, w));
131+ n, c, h, w));
132132}
133133
134134template <typename Dtype>
135135inline void createNdFilterDesc (cudnnFilterDescriptor_t* desc,
136136 std::vector<int > shape) {
137137 CUDNN_CHECK (cudnnCreateFilterDescriptor (desc));
138+ #if CUDNN_VERSION_MIN(5, 0, 0)
138139 CUDNN_CHECK (cudnnSetFilterNdDescriptor (*desc, dataType<Dtype>::type,
139- shape.size (), shape.data ()));
140+ CUDNN_TENSOR_NCHW, shape.size (), shape.data ()));
141+ #else
142+ CUDNN_CHECK (cudnnSetFilterNdDescriptor (*desc, dataType<Dtype>::type,
143+ shape.size (), shape.data ()));
144+ #endif
140145}
141146
142147template <typename Dtype>
@@ -149,7 +154,7 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
149154 cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
150155 int pad_h, int pad_w, int stride_h, int stride_w) {
151156 CUDNN_CHECK (cudnnSetConvolution2dDescriptor (*conv,
152- pad_h, pad_w, stride_h, stride_w, 1 , 1 , CUDNN_CROSS_CORRELATION));
157+ pad_h, pad_w, stride_h, stride_w, 1 , 1 , CUDNN_CROSS_CORRELATION));
153158}
154159
155160template <typename Dtype>
@@ -159,16 +164,22 @@ inline void setNdConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
159164 int nbDims;
160165 std::vector<int > shape (pad.size () + 2 );
161166 cudnnDataType_t cudnn_type;
167+ #if CUDNN_VERSION_MIN(5, 0, 0)
168+ cudnnTensorFormat_t tensor_format;
169+ cudnnGetFilterNdDescriptor (filter,
170+ shape.size (), &cudnn_type, &tensor_format, &nbDims, shape.data ());
171+ #else
162172 cudnnGetFilterNdDescriptor (filter,
163173 shape.size (), &cudnn_type, &nbDims, shape.data ());
174+ #endif
164175 CHECK_EQ (nbDims, pad.size () + 2 )
165176 << " Dimensions of filters and pad don't match !" ;
166177 CHECK_EQ (nbDims, stride.size () + 2 )
167178 << " Dimensions of filters and stride don't match !" ;
168179 std::vector<int > upscale (pad.size (), 1 );
169180 CUDNN_CHECK (cudnnSetConvolutionNdDescriptor (*conv,
170- pad.size (), pad.data (), stride.data (), upscale.data (),
171- CUDNN_CROSS_CORRELATION, cudnn_type));
181+ pad.size (), pad.data (), stride.data (), upscale.data (),
182+ CUDNN_CROSS_CORRELATION, cudnn_type));
172183}
173184
174185template <typename Dtype>
@@ -186,8 +197,13 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
186197 LOG (FATAL) << " Unknown pooling method." ;
187198 }
188199 CUDNN_CHECK (cudnnCreatePoolingDescriptor (pool_desc));
200+ #if CUDNN_VERSION_MIN(5, 0, 0)
201+ CUDNN_CHECK (cudnnSetPooling2dDescriptor (*pool_desc, *mode,
202+ CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
203+ #else
189204 CUDNN_CHECK (cudnnSetPooling2dDescriptor (*pool_desc, *mode, h, w,
190205 pad_h, pad_w, stride_h, stride_w));
206+ #endif
191207}
192208
193209template <typename Dtype>
@@ -210,8 +226,14 @@ inline void createNdPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
210226 LOG (FATAL) << " Unknown pooling method." ;
211227 }
212228 CUDNN_CHECK (cudnnCreatePoolingDescriptor (pool_desc));
229+ #if CUDNN_VERSION_MIN(5, 0, 0)
230+ CUDNN_CHECK (cudnnSetPoolingNdDescriptor (*pool_desc, *mode,
231+ CUDNN_PROPAGATE_NAN, shape.size (), shape.data (), pad.data (),
232+ stride.data ()));
233+ #else
213234 CUDNN_CHECK (cudnnSetPoolingNdDescriptor (*pool_desc, *mode, shape.size (),
214235 shape.data (), pad.data (), stride.data ()));
236+ #endif
215237}
216238
217239} // namespace cudnn
0 commit comments