@@ -208,7 +208,6 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
             kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
             cudnn_workspace_ptr, workspace_size, false));
       };
-      // if (!exhaustive_search && !deterministic) {
       workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
       algo = find_result.fwd_algo;
       VLOG(3) << "cuDNN forward algo " << algo;
@@ -244,15 +243,16 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc,
                                                          groups));
-    // Now only support NCHW
-    std::vector<int> bias_dim = {
-        1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
+
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize<int>(transformed_input.dims()));
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize<int>(transformed_output.dims()));
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize<int>(filter->dims()));
+    // Now only support NCHW
+    std::vector<int> bias_dim = {
+        1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
     cudnnTensorDescriptor_t cudnn_bias_desc =
         bias_desc.descriptor<T>(layout, bias_dim);
     cudnnActivationDescriptor_t cudnn_act_desc =
@@ -430,6 +430,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
   }
 };
 #endif
+
 }  // namespace operators
 }  // namespace paddle

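The reordering above only moves the bias_dim declaration next to the bias descriptor it feeds; the shape stays {1, C, 1, 1} because the fused kernel currently supports NCHW only, so the bias is broadcast over the batch and spatial dimensions. As a minimal standalone sketch of the same idea using the raw cuDNN API rather than Paddle's TensorDescriptor wrapper (the helper name MakeNchwBiasDesc and the hard-coded float data type are illustrative assumptions, not code from this PR):

// Illustrative sketch only: build a 4-D bias descriptor shaped {1, C, 1, 1}
// for an NCHW output, mirroring what bias_desc.descriptor<T>(layout, bias_dim)
// produces in the diff above.
#include <cudnn.h>

cudnnTensorDescriptor_t MakeNchwBiasDesc(int channels) {
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  // Only the channel extent varies; the bias is broadcast over N, H and W.
  cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                             /*n=*/1, /*c=*/channels, /*h=*/1, /*w=*/1);
  return desc;
}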