Skip to content

Commit 7cf3538

Browse files
committed
Merge pull request #4159 from flx42/cudnn_v5_support
Add cuDNN v5 support, drop cuDNN v3 support
2 parents bb0c1a5 + 1c3af70 commit 7cf3538

File tree

14 files changed

+97
-23
lines changed

14 files changed

+97
-23
lines changed

docker/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ docker_files: standalone_files
2222

2323
standalone_files: standalone/cpu/Dockerfile standalone/gpu/Dockerfile
2424

25-
FROM_GPU = "nvidia/cuda:7.5-cudnn4-devel-ubuntu14.04"
25+
FROM_GPU = "nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04"
2626
FROM_CPU = "ubuntu:14.04"
2727
GPU_CMAKE_ARGS = -DUSE_CUDNN=1
2828
CPU_CMAKE_ARGS = -DCPU_ONLY=1

docker/standalone/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM nvidia/cuda:7.5-cudnn4-devel-ubuntu14.04
1+
FROM nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04
22
33

44
RUN apt-get update && apt-get install -y --no-install-recommends \

docs/installation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ Optional dependencies:
4040

4141
* [OpenCV](http://opencv.org/) >= 2.4 including 3.0
4242
* IO libraries: `lmdb`, `leveldb` (note: leveldb requires `snappy`)
43-
* cuDNN for GPU acceleration (v4)
43+
* cuDNN for GPU acceleration (v5)
4444

4545
Pycaffe and Matcaffe interfaces have their own natural needs.
4646

4747
* For Python Caffe: `Python 2.7` or `Python 3.3+`, `numpy (>= 1.7)`, boost-provided `boost.python`
4848
* For MATLAB Caffe: MATLAB with the `mex` compiler.
4949

50-
**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic. The current version is cuDNN v4; older versions are supported in older Caffe.
50+
**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic. The current version is cuDNN v5; older versions are supported in older Caffe.
5151

5252
**CPU-only Caffe**: for cold-brewed CPU-only Caffe uncomment the `CPU_ONLY := 1` flag in `Makefile.config` to configure and build Caffe without CUDA. This is helpful for cloud or cluster deployment.
5353

include/caffe/layers/cudnn_relu_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
3737
cudnnHandle_t handle_;
3838
cudnnTensorDescriptor_t bottom_desc_;
3939
cudnnTensorDescriptor_t top_desc_;
40+
cudnnActivationDescriptor_t activ_desc_;
4041
};
4142
#endif
4243

include/caffe/layers/cudnn_sigmoid_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
3737
cudnnHandle_t handle_;
3838
cudnnTensorDescriptor_t bottom_desc_;
3939
cudnnTensorDescriptor_t top_desc_;
40+
cudnnActivationDescriptor_t activ_desc_;
4041
};
4142
#endif
4243

include/caffe/layers/cudnn_tanh_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
3737
cudnnHandle_t handle_;
3838
cudnnTensorDescriptor_t bottom_desc_;
3939
cudnnTensorDescriptor_t top_desc_;
40+
cudnnActivationDescriptor_t activ_desc_;
4041
};
4142
#endif
4243

include/caffe/util/cudnn.hpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,13 @@ template <typename Dtype>
9191
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
9292
int n, int c, int h, int w) {
9393
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
94+
#if CUDNN_VERSION_MIN(5, 0, 0)
9495
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
95-
n, c, h, w));
96+
CUDNN_TENSOR_NCHW, n, c, h, w));
97+
#else
98+
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
99+
CUDNN_TENSOR_NCHW, n, c, h, w));
100+
#endif
96101
}
97102

98103
template <typename Dtype>
@@ -123,8 +128,21 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
123128
LOG(FATAL) << "Unknown pooling method.";
124129
}
125130
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
126-
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
127-
pad_h, pad_w, stride_h, stride_w));
131+
#if CUDNN_VERSION_MIN(5, 0, 0)
132+
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
133+
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
134+
#else
135+
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
136+
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
137+
#endif
138+
}
139+
140+
template <typename Dtype>
141+
inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
142+
cudnnActivationMode_t mode) {
143+
CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
144+
CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
145+
CUDNN_PROPAGATE_NAN, Dtype(0)));
128146
}
129147

130148
} // namespace cudnn

src/caffe/layers/cudnn_conv_layer.cu

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,11 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
3030
// Bias.
3131
if (this->bias_term_) {
3232
const Dtype* bias_data = this->blobs_[1]->gpu_data();
33-
#if CUDNN_VERSION_MIN(4, 0, 0)
3433
CUDNN_CHECK(cudnnAddTensor(handle_[g],
3534
cudnn::dataType<Dtype>::one,
3635
bias_desc_, bias_data + bias_offset_ * g,
3736
cudnn::dataType<Dtype>::one,
3837
top_descs_[i], top_data + top_offset_ * g));
39-
#else
40-
CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
41-
cudnn::dataType<Dtype>::one,
42-
bias_desc_, bias_data + bias_offset_ * g,
43-
cudnn::dataType<Dtype>::one,
44-
top_descs_[i], top_data + top_offset_ * g));
45-
#endif
4638
}
4739
}
4840

@@ -82,7 +74,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
8274
// Gradient w.r.t. weights.
8375
if (this->param_propagate_down_[0]) {
8476
const Dtype* bottom_data = bottom[i]->gpu_data();
85-
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
77+
CUDNN_CHECK(cudnnConvolutionBackwardFilter(
8678
handle_[1*this->group_ + g],
8779
cudnn::dataType<Dtype>::one,
8880
bottom_descs_[i], bottom_data + bottom_offset_ * g,
@@ -100,7 +92,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
10092
weight = this->blobs_[0]->gpu_data();
10193
}
10294
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
103-
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
95+
CUDNN_CHECK(cudnnConvolutionBackwardData(
10496
handle_[2*this->group_ + g],
10597
cudnn::dataType<Dtype>::one,
10698
filter_desc_, weight + this->weight_offset_ * g,

src/caffe/layers/cudnn_relu_layer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1313
CUDNN_CHECK(cudnnCreate(&handle_));
1414
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
1515
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
16+
cudnn::createActivationDescriptor<Dtype>(&activ_desc_, CUDNN_ACTIVATION_RELU);
1617
handles_setup_ = true;
1718
}
1819

src/caffe/layers/cudnn_relu_layer.cu

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
1515

1616
const Dtype* bottom_data = bottom[0]->gpu_data();
1717
Dtype* top_data = top[0]->mutable_gpu_data();
18+
#if CUDNN_VERSION_MIN(5, 0, 0)
1819
CUDNN_CHECK(cudnnActivationForward(this->handle_,
19-
CUDNN_ACTIVATION_RELU,
20+
activ_desc_,
2021
cudnn::dataType<Dtype>::one,
2122
this->bottom_desc_, bottom_data,
2223
cudnn::dataType<Dtype>::zero,
2324
this->top_desc_, top_data));
25+
#else
26+
CUDNN_CHECK(cudnnActivationForward_v4(this->handle_,
27+
activ_desc_,
28+
cudnn::dataType<Dtype>::one,
29+
this->bottom_desc_, bottom_data,
30+
cudnn::dataType<Dtype>::zero,
31+
this->top_desc_, top_data));
32+
#endif
2433
}
2534

2635
template <typename Dtype>
@@ -40,13 +49,23 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
4049
const Dtype* top_diff = top[0]->gpu_diff();
4150
const Dtype* bottom_data = bottom[0]->gpu_data();
4251
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
52+
#if CUDNN_VERSION_MIN(5, 0, 0)
4353
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
44-
CUDNN_ACTIVATION_RELU,
54+
activ_desc_,
4555
cudnn::dataType<Dtype>::one,
4656
this->top_desc_, top_data, this->top_desc_, top_diff,
4757
this->bottom_desc_, bottom_data,
4858
cudnn::dataType<Dtype>::zero,
4959
this->bottom_desc_, bottom_diff));
60+
#else
61+
CUDNN_CHECK(cudnnActivationBackward_v4(this->handle_,
62+
activ_desc_,
63+
cudnn::dataType<Dtype>::one,
64+
this->top_desc_, top_data, this->top_desc_, top_diff,
65+
this->bottom_desc_, bottom_data,
66+
cudnn::dataType<Dtype>::zero,
67+
this->bottom_desc_, bottom_diff));
68+
#endif
5069
}
5170

5271
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);

0 commit comments

Comments
 (0)