Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0ad1284
CMake: link with ${HDF5_HL_LIBRARIES}
intelfx Jul 25, 2016
c62e06b
Fix search for Atlas on arch.
Jul 26, 2016
bc1a433
add cudnn interfaces for n-dimensional computation
Feb 18, 2016
e5c13a5
add support for nd convolution in cudnn
Feb 18, 2016
cc357bd
change interface of pool to support n-dimensions
Feb 19, 2016
12cb24f
fix 2D pooling on CPU and GPU
Feb 19, 2016
c1b0f38
remove some calls of Blob::LegacyShape() to support 3D
May 23, 2016
721553e
fix xavier filler to use new blob shape accessors
Feb 19, 2016
b2f3848
fix tests for new pooling parameter interface
Apr 12, 2016
7173035
add 3D cudnn convolution tests
Apr 13, 2016
c9de153
add 3D cudnn pooling tests
Apr 14, 2016
eb93d32
fix CUDNN_BAD_PARAM when using InnerProduct layer
Apr 28, 2016
919b6d7
change interface for cudnn v5
May 23, 2016
9e9e9ba
Merge pull request #4523 from delftrobotics/cmake-atlas
longjon Aug 4, 2016
6431477
Merge pull request #4516 from intelfx/BVLC-work
longjon Aug 4, 2016
61e0165
num in blob is deprecated
fyu Aug 7, 2016
375003a
Merge pull request #4559 from fyu/loss_reshape
jeffdonahue Aug 7, 2016
f86a099
add cudnn interfaces for n-dimensional computation
Feb 18, 2016
4f63ea5
add support for nd convolution in cudnn
Feb 18, 2016
5e1f04e
change interface of pool to support n-dimensions
Feb 19, 2016
2346c5e
fix 2D pooling on CPU and GPU
Feb 19, 2016
0dcb68a
remove some calls of Blob::LegacyShape() to support 3D
May 23, 2016
fb0f9f5
fix xavier filler to use new blob shape accessors
Feb 19, 2016
b8ca687
fix tests for new pooling parameter interface
Apr 12, 2016
c88f8fa
add 3D cudnn convolution tests
Apr 13, 2016
d0efc10
add 3D cudnn pooling tests
Apr 14, 2016
45562a0
fix CUDNN_BAD_PARAM when using InnerProduct layer
Apr 28, 2016
b506327
change interface for cudnn v5
May 23, 2016
fc39d7e
remove some calls of Blob::LegacyShape() to support 3D
Sep 12, 2016
857f47d
fix msra filler to use new blob shape accessors
Sep 12, 2016
334e76f
fix positive_unitball filler to use new blob shape accessors
Sep 12, 2016
efda84c
Merge branch 'nd-cudnn' of github.com:christianpayer/caffe into nd-cudnn
Nov 2, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ include(cmake/ProtoBuf.cmake)
# ---[ HDF5
find_package(HDF5 COMPONENTS HL REQUIRED)
include_directories(SYSTEM ${HDF5_INCLUDE_DIRS} ${HDF5_HL_INCLUDE_DIR})
list(APPEND Caffe_LINKER_LIBS ${HDF5_LIBRARIES})
list(APPEND Caffe_LINKER_LIBS ${HDF5_LIBRARIES} ${HDF5_HL_LIBRARIES})

# ---[ LMDB
if(USE_LMDB)
Expand Down
6 changes: 3 additions & 3 deletions cmake/Modules/FindAtlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ set(Atlas_LIB_SEARCH_PATHS
find_path(Atlas_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS})
find_path(Atlas_CLAPACK_INCLUDE_DIR NAMES clapack.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS})

find_library(Atlas_CBLAS_LIBRARY NAMES ptcblas_r ptcblas cblas_r cblas PATHS ${Atlas_LIB_SEARCH_PATHS})
find_library(Atlas_BLAS_LIBRARY NAMES atlas_r atlas PATHS ${Atlas_LIB_SEARCH_PATHS})
find_library(Atlas_LAPACK_LIBRARY NAMES alapack_r alapack lapack_atlas PATHS ${Atlas_LIB_SEARCH_PATHS})
find_library(Atlas_CBLAS_LIBRARY NAMES ptcblas_r ptcblas cblas_r cblas PATHS ${Atlas_LIB_SEARCH_PATHS})
find_library(Atlas_BLAS_LIBRARY NAMES atlas_r atlas PATHS ${Atlas_LIB_SEARCH_PATHS})
find_library(Atlas_LAPACK_LIBRARY NAMES lapack alapack_r alapack lapack_atlas PATHS ${Atlas_LIB_SEARCH_PATHS})

set(LOOKED_FOR
Atlas_CBLAS_INCLUDE_DIR
Expand Down
12 changes: 6 additions & 6 deletions include/caffe/filler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ class PositiveUnitballFiller : public Filler<Dtype> {
caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
// We expect the filler to not be called very frequently, so we will
// just use a simple implementation
int dim = blob->count() / blob->num();
int dim = blob->count() / blob->shape(0);
CHECK(dim);
for (int i = 0; i < blob->num(); ++i) {
for (int i = 0; i < blob->shape(0); ++i) {
Dtype sum = 0;
for (int j = 0; j < dim; ++j) {
sum += data[i * dim + j];
Expand Down Expand Up @@ -147,8 +147,8 @@ class XavierFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
int fan_out = blob->count() / blob->channels();
int fan_in = blob->count() / blob->shape(0);
int fan_out = blob->count() / blob->shape(1);
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
Expand Down Expand Up @@ -189,8 +189,8 @@ class MSRAFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
int fan_out = blob->count() / blob->channels();
int fan_in = blob->count() / blob->shape(0);
int fan_out = blob->count() / blob->shape(1);
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
Expand Down
15 changes: 10 additions & 5 deletions include/caffe/layers/pooling_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ class PoolingLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
int pad_h_, pad_w_;
/// @brief The spatial dimensions of a filter kernel.
std::vector<int> kernel_shape_;
/// @brief The spatial dimensions of the stride.
std::vector<int> stride_;
/// @brief The spatial dimensions of the padding.
std::vector<int> pad_;

int num_spatial_axes_;
int channels_;
int height_, width_;
int pooled_height_, pooled_width_;
std::vector<int> input_shape_;
std::vector<int> pooled_shape_;
bool global_pooling_;
Blob<Dtype> rand_idx_;
Blob<int> max_idx_;
Expand Down
110 changes: 107 additions & 3 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <cudnn.h>

#include <vector>

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

Expand Down Expand Up @@ -68,6 +70,11 @@ inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

/// @brief Creates a cuDNN tensor descriptor without configuring it.
///
/// Only the descriptor object is created here; no shape or stride is set.
/// The Dtype template argument is not used by the body.
///
/// @param desc  Location that receives the newly created descriptor.
template <typename Dtype>
inline void createTensorDesc(cudnnTensorDescriptor_t* desc) {
  CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
Expand All @@ -76,6 +83,24 @@ inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

/// @brief Configures a tensor descriptor from an explicit shape and stride.
///
/// Both vectors are taken by value because they are padded locally before
/// being handed to cuDNN.
///
/// @param desc    Descriptor to configure.
/// @param shape   Extent of every tensor axis.
/// @param stride  Element stride of every tensor axis; must have as many
///                entries as @p shape.
template <typename Dtype>
inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc,
    std::vector<int> shape,
    std::vector<int> stride) {
  CHECK_EQ(shape.size(), stride.size())
      << "Dimensions of shape and stride don't match !";
  // Tensors with fewer than 4 dimensions trigger CUDNN_STATUS_BAD_PARAM in
  // CUDNN v4, so pad shape and stride with trailing 1s up to rank 4.
  // TODO: check the CUDNN documentation; this is probably fixed in newer
  // versions.
  const std::vector<int>::size_type min_rank = 4;
  if (shape.size() < min_rank) {
    shape.resize(min_rank, 1);
    stride.resize(min_rank, 1);
  }
  const int rank = static_cast<int>(shape.size());
  CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc, dataType<Dtype>::type,
      rank, shape.data(), stride.data()));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w) {
Expand All @@ -87,6 +112,17 @@ inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
stride_n, stride_c, stride_h, stride_w);
}

/// @brief Configures a tensor descriptor for a densely packed tensor.
///
/// The stride of the last axis is 1 and every other axis gets the product of
/// the extents of all axes after it; the result is forwarded to the
/// shape-and-stride overload of setTensorNdDesc.
///
/// @param desc   Descriptor to configure.
/// @param shape  Extent of every tensor axis.
template <typename Dtype>
inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc,
    std::vector<int> shape) {
  const int num_axes = static_cast<int>(shape.size());
  std::vector<int> packed_stride(shape.size(), 1);
  // Walk from the innermost axis outwards, accumulating the running product.
  for (int axis = num_axes - 1; axis > 0; --axis) {
    packed_stride[axis - 1] = shape[axis] * packed_stride[axis];
  }
  setTensorNdDesc<Dtype>(desc, shape, packed_stride);
}

template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
Expand All @@ -100,6 +136,19 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
#endif
}

/// @brief Creates a filter descriptor and sets its N-dimensional shape.
///
/// From cuDNN 5.0.0 on, the setter additionally takes a tensor format, for
/// which CUDNN_TENSOR_NCHW is passed; older versions take no format argument.
///
/// @param desc   Location that receives the newly created descriptor.
/// @param shape  Extent of every filter axis.
template <typename Dtype>
inline void createNdFilterDesc(cudnnFilterDescriptor_t* desc,
    std::vector<int> shape) {
  CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
  const int num_dims = static_cast<int>(shape.size());
#if CUDNN_VERSION_MIN(5, 0, 0)
  CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType<Dtype>::type,
      CUDNN_TENSOR_NCHW, num_dims, shape.data()));
#else
  CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType<Dtype>::type,
      num_dims, shape.data()));
#endif
}

template <typename Dtype>
inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
Expand All @@ -113,6 +162,31 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}

/// @brief Configures an N-dimensional cross-correlation descriptor.
///
/// The filter descriptor is queried for its data type and number of
/// dimensions. @p pad and @p stride must each have exactly two entries fewer
/// than the filter has dimensions. The convolution is set up with an upscale
/// factor of 1 along every axis and uses the filter's data type.
///
/// @param conv    Convolution descriptor to configure.
/// @param bottom  Input tensor descriptor; not read by this function.
/// @param filter  Filter descriptor the convolution will be used with.
/// @param pad     Padding for each convolved axis.
/// @param stride  Stride for each convolved axis.
template <typename Dtype>
inline void setNdConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
    cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
    std::vector<int> pad, std::vector<int> stride) {
  // Initialised so a failed query can never leave an indeterminate value.
  int nbDims = 0;
  std::vector<int> shape(pad.size() + 2);
  cudnnDataType_t cudnn_type;
  // The query status is checked: if it were ignored, a failure would leave
  // nbDims and cudnn_type unset and they would be used below regardless.
#if CUDNN_VERSION_MIN(5, 0, 0)
  cudnnTensorFormat_t tensor_format;
  CUDNN_CHECK(cudnnGetFilterNdDescriptor(filter,
      shape.size(), &cudnn_type, &tensor_format, &nbDims, shape.data()));
#else
  CUDNN_CHECK(cudnnGetFilterNdDescriptor(filter,
      shape.size(), &cudnn_type, &nbDims, shape.data()));
#endif
  CHECK_EQ(nbDims, pad.size() + 2)
      << "Dimensions of filters and pad don't match !";
  CHECK_EQ(nbDims, stride.size() + 2)
      << "Dimensions of filters and stride don't match !";
  std::vector<int> upscale(pad.size(), 1);
  CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
      pad.size(), pad.data(), stride.data(), upscale.data(),
      CUDNN_CROSS_CORRELATION, cudnn_type));
}

template <typename Dtype>
inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
Expand All @@ -130,10 +204,10 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
#endif
}

Expand All @@ -145,6 +219,36 @@ inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
CUDNN_PROPAGATE_NAN, Dtype(0)));
}

/// @brief Creates and configures an N-dimensional pooling descriptor.
///
/// MAX maps to CUDNN_POOLING_MAX and AVE maps to
/// CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; any other method is reported
/// through LOG(FATAL). From cuDNN 5.0.0 on, CUDNN_PROPAGATE_NAN is passed to
/// the setter as well.
///
/// @param pool_desc   Location that receives the newly created descriptor.
/// @param poolmethod  Pooling method requested in the layer parameters.
/// @param mode        Receives the cuDNN pooling mode that was selected.
/// @param shape       Pooling window extent per axis.
/// @param pad         Padding per axis; same length as @p shape.
/// @param stride      Stride per axis; same length as @p shape.
template <typename Dtype>
inline void createNdPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
    PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
    std::vector<int> shape, std::vector<int> pad,
    std::vector<int> stride) {
  CHECK_EQ(shape.size(), pad.size())
      << "Dimensions of shape and pad don't match !";
  CHECK_EQ(shape.size(), stride.size())
      << "Dimensions of shape and stride don't match !";

  // Translate the Caffe pooling method into its cuDNN counterpart.
  if (poolmethod == PoolingParameter_PoolMethod_MAX) {
    *mode = CUDNN_POOLING_MAX;
  } else if (poolmethod == PoolingParameter_PoolMethod_AVE) {
    *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
  } else {
    LOG(FATAL) << "Unknown pooling method.";
  }

  CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
  const int num_axes = static_cast<int>(shape.size());
#if CUDNN_VERSION_MIN(5, 0, 0)
  CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode,
      CUDNN_PROPAGATE_NAN, num_axes, shape.data(), pad.data(),
      stride.data()));
#else
  CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode, num_axes,
      shape.data(), pad.data(), stride.data()));
#endif
}

} // namespace cudnn

} // namespace caffe
Expand Down
79 changes: 41 additions & 38 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,21 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
bias_offset_ = (this->num_output_ / this->group_);

// Create filter descriptor.
const int* kernel_shape_data = this->kernel_shape_.cpu_data();
const int kernel_h = kernel_shape_data[0];
const int kernel_w = kernel_shape_data[1];
cudnn::createFilterDesc<Dtype>(&filter_desc_,
this->num_output_ / this->group_, this->channels_ / this->group_,
kernel_h, kernel_w);
std::vector<int> kernel_shape;
kernel_shape.push_back(this->num_output_ / this->group_);
kernel_shape.push_back(this->channels_ / this->group_);
for (unsigned int i = 0; i < this->num_spatial_axes_; ++i)
kernel_shape.push_back(this->kernel_shape_.cpu_data()[i]);

cudnn::createNdFilterDesc<Dtype>(&filter_desc_, kernel_shape);

// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
cudnn::createTensorDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
cudnnTensorDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
cudnn::createTensorDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
cudnn::createConvolutionDesc<Dtype>(&conv_desc);
Expand All @@ -81,7 +82,7 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(

// Tensor descriptor for bias.
if (this->bias_term_) {
cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
cudnn::createTensorDesc<Dtype>(&bias_desc_);
}

handles_setup_ = true;
Expand All @@ -91,41 +92,42 @@ template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
ConvolutionLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(2, this->num_spatial_axes_)
<< "CuDNNConvolution input must have 2 spatial axes "
<< "(e.g., height and width). "
<< "Use 'engine: CAFFE' for general ND convolution.";

bottom_offset_ = this->bottom_dim_ / this->group_;
top_offset_ = this->top_dim_ / this->group_;
const int height = bottom[0]->shape(this->channel_axis_ + 1);
const int width = bottom[0]->shape(this->channel_axis_ + 2);
const int height_out = top[0]->shape(this->channel_axis_ + 1);
const int width_out = top[0]->shape(this->channel_axis_ + 2);
const int* pad_data = this->pad_.cpu_data();
const int pad_h = pad_data[0];
const int pad_w = pad_data[1];
const int* stride_data = this->stride_.cpu_data();
const int stride_h = stride_data[0];
const int stride_w = stride_data[1];

std::vector<int> bottom_tensor_shape(bottom[0]->shape());
bottom_tensor_shape[1] /= this->group_;
std::vector<int> bottom_tensor_stride(bottom[0]->shape().size(), 1);
for (int i = bottom[0]->shape().size() - 2; i >= 0; --i) {
bottom_tensor_stride[i] =
bottom[0]->shape(i + 1) * bottom_tensor_stride[i + 1];
}

std::vector<int> top_tensor_shape(top[0]->shape());
top_tensor_shape[1] /= this->group_;
std::vector<int> top_tensor_stride(top[0]->shape().size(), 1);
for (int i = top[0]->shape().size() - 2; i >= 0; --i) {
top_tensor_stride[i] = top[0]->shape(i + 1) * top_tensor_stride[i + 1];
}

std::vector<int> pad, stride;
for (unsigned int i = 0; i < this->num_spatial_axes_; ++i) {
pad.push_back(this->pad_.cpu_data()[i]);
stride.push_back(this->stride_.cpu_data()[i]);
}

// Specify workspace limit for kernels directly until we have a
// planning strategy and a rewrite of Caffe's GPU memory management
size_t workspace_limit_bytes = 8*1024*1024;

for (int i = 0; i < bottom.size(); i++) {
cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
this->num_,
this->channels_ / this->group_, height, width,
this->channels_ * height * width,
height * width, width, 1);
cudnn::setTensor4dDesc<Dtype>(&top_descs_[i],
this->num_,
this->num_output_ / this->group_, height_out, width_out,
this->num_output_ * this->out_spatial_dim_,
this->out_spatial_dim_, width_out, 1);
cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
filter_desc_, pad_h, pad_w,
stride_h, stride_w);
cudnn::setTensorNdDesc<Dtype>(&bottom_descs_[i],
bottom_tensor_shape, bottom_tensor_stride);
cudnn::setTensorNdDesc<Dtype>(&top_descs_[i],
top_tensor_shape, top_tensor_stride);
cudnn::setNdConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
filter_desc_, pad, stride);

// choose forward and backward algorithms + workspace(s)
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0],
Expand Down Expand Up @@ -226,8 +228,9 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(

// Tensor descriptor for bias.
if (this->bias_term_) {
cudnn::setTensor4dDesc<Dtype>(&bias_desc_,
1, this->num_output_ / this->group_, 1, 1);
vector<int> bias_shape(bottom[0]->shape().size(), 1);
bias_shape[1] = this->num_output_ / this->group_;
cudnn::setTensorNdDesc<Dtype>(&bias_desc_, bias_shape);
}
}

Expand Down
15 changes: 6 additions & 9 deletions src/caffe/layers/cudnn_pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,20 @@ void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingLayer<Dtype>::LayerSetUp(bottom, top);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
cudnn::createNdPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_,
this->stride_h_, this->stride_w_);
this->kernel_shape_, this->pad_, this->stride_);
handles_setup_ = true;
}

template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingLayer<Dtype>::Reshape(bottom, top);
cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
this->channels_, this->height_, this->width_);
cudnn::setTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
this->channels_, this->pooled_height_, this->pooled_width_);
cudnn::setTensorNdDesc<Dtype>(&bottom_desc_, this->input_shape_);
cudnn::setTensorNdDesc<Dtype>(&top_desc_, this->pooled_shape_);
}

template <typename Dtype>
Expand Down
Loading