Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions include/caffe/layers/cudnn_ndconv_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/conv_layer.hpp"

namespace caffe {


template <typename Dtype>
class CudnnNdConvolutionLayer : public Layer<Dtype> {
public:
explicit CudnnNdConvolutionLayer(const LayerParameter& param)
: Layer<Dtype>(param), handles_setup_(false) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual ~CudnnNdConvolutionLayer();

virtual inline const char* type() const { return "NdConvolution"; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

// Compute height_out_ and width_out_ from other parameters.
virtual void compute_output_shape();

vector<int> kernel_shape_;
vector<int> stride_shape_;
int num_;
int channels_;
vector<int> pad_shape_;
vector<int> input_shape_;
int group_;
int num_output_;
vector<int> output_shape_;
bool bias_term_;

int conv_out_spatial_dim_;
int kernel_dim_;
int output_offset_;

Blob<Dtype> bias_multiplier_;

bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;
vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
size_t workspaceSizeInBytes;
void *workspace;
};

} // namespace caffe

48 changes: 48 additions & 0 deletions include/caffe/layers/cudnn_pooling_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,54 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);



bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
};

template <typename Dtype>
class CudnnNdPoolingLayer : public Layer<Dtype> {
public:
explicit CudnnNdPoolingLayer(const LayerParameter& param)
: Layer<Dtype>(param), handles_setup_(false) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual ~CudnnNdPoolingLayer();

virtual inline const char* type() const { return "NdPooling"; }
virtual inline int ExactNumBottomBlobs() const { return 1; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual void compute_output_shape();

virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

vector<int> kernel_shape_;
vector<int> stride_shape_;
vector<int> pad_shape_;
int channels_;
vector<int> input_shape_;
vector<int> pooled_shape_;
bool global_pooling_;
Blob<Dtype> rand_idx_;
Blob<int> max_idx_;



bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;
Expand Down
130 changes: 103 additions & 27 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifdef USE_CUDNN

#include <cudnn.h>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
Expand All @@ -19,28 +20,28 @@

inline const char* cudnnGetErrorString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
}
return "Unknown cudnn status";
}
Expand Down Expand Up @@ -68,12 +69,27 @@ inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void createTensorDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

template <typename Dtype>
inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc,
std::vector<int> shape,
std::vector<int> stride) {
CHECK_EQ(shape.size(), stride.size()) <<
"Dimensions of shape and stride don't match !";
CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc, dataType<Dtype>::type,
shape.size(), shape.data(), stride.data()));
}

template <typename Dtype>
Expand All @@ -84,15 +100,33 @@ inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
setTensor4dDesc<Dtype>(desc, n, c, h, w,
stride_n, stride_c, stride_h, stride_w);
stride_n, stride_c, stride_h, stride_w);
}

template <typename Dtype>
inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc,
std::vector<int> shape) {
std::vector<int> stride(shape.size(), 1);
for (int i = stride.size()-2; i >= 0; --i) {
stride[i] = shape[i+1] * stride[i+1];
}
setTensorNdDesc<Dtype>(desc, shape, stride);
}

template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
n, c, h, w));
}

template <typename Dtype>
inline void createNdFilterDesc(cudnnFilterDescriptor_t* desc,
std::vector<int> shape) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType<Dtype>::type,
shape.size(), shape.data()));
}

template <typename Dtype>
Expand All @@ -105,7 +139,26 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}

template <typename Dtype>
inline void setNdConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
std::vector<int> pad, std::vector<int> stride) {
int nbDims;
std::vector<int> shape(pad.size()+2);
cudnnDataType_t cudnn_type;
cudnnGetFilterNdDescriptor(filter, shape.size(), &cudnn_type, &nbDims,
shape.data());
CHECK_EQ(nbDims, pad.size()+2) <<
"Dimensions of filters and pad don't match !";
CHECK_EQ(nbDims, stride.size()+2) <<
"Dimensions of filters and stride don't match !";
std::vector<int> upscale(pad.size(), 1);
CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function takes an additional input arg in CUDNN v4 (or possibly earlier version too)
Something like following will take care of this:

#if CUDNN_VERSION >= 4000
  CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
              pad.size(), pad.data(), stride.data(), upscale.data(),
              CUDNN_CROSS_CORRELATION, cudnn_type));
#else
   CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
               pad.size(), pad.data(), stride.data(), upscale.data(),
               CUDNN_CROSS_CORRELATION));
#endif

pad.size(), pad.data(), stride.data(), upscale.data(),
CUDNN_CROSS_CORRELATION));
}

template <typename Dtype>
Expand All @@ -124,7 +177,30 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
pad_h, pad_w, stride_h, stride_w));
}

template <typename Dtype>
inline void createNdPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
std::vector<int> shape, std::vector<int> pad, std::vector<int> stride) {
CHECK_EQ(shape.size(), pad.size()) <<
"Dimensions of shape and pad don't match !";
CHECK_EQ(shape.size(), stride.size()) <<
"Dimensions of shape and stride don't match !";
switch (poolmethod) {
case PoolingParameter_PoolMethod_MAX:
*mode = CUDNN_POOLING_MAX;
break;
case PoolingParameter_PoolMethod_AVE:
*mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode, shape.size(),
shape.data(), pad.data(), stride.data()));
}

} // namespace cudnn
Expand Down
57 changes: 56 additions & 1 deletion src/caffe/layer_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "caffe/layers/cudnn_conv_layer.hpp"
#include "caffe/layers/cudnn_lcn_layer.hpp"
#include "caffe/layers/cudnn_lrn_layer.hpp"
#include "caffe/layers/cudnn_ndconv_layer.hpp"
#include "caffe/layers/cudnn_pooling_layer.hpp"
#include "caffe/layers/cudnn_relu_layer.hpp"
#include "caffe/layers/cudnn_sigmoid_layer.hpp"
Expand All @@ -36,7 +37,7 @@ namespace caffe {
// Get convolution layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetConvolutionLayer(
const LayerParameter& param) {
const LayerParameter& param) {
ConvolutionParameter conv_param = param.convolution_param();
ConvolutionParameter_Engine engine = conv_param.engine();
#ifdef USE_CUDNN
Expand Down Expand Up @@ -72,6 +73,33 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(

REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);

// Get NdConvolutionLayer when cudnn

#ifdef USE_CUDNN
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetNdConvolutionLayer(
const LayerParameter& param) {
ConvolutionParameter_Engine engine = param.convolution_param().engine();
if (engine == ConvolutionParameter_Engine_DEFAULT) {
engine = ConvolutionParameter_Engine_CAFFE;
// #ifdef USE_CUDNN
engine = ConvolutionParameter_Engine_CUDNN;
// #endif
}
if (engine == ConvolutionParameter_Engine_CAFFE) {
NOT_IMPLEMENTED;
// #ifdef USE_CUDNN
} else if (engine == ConvolutionParameter_Engine_CUDNN) {
return shared_ptr<Layer<Dtype> >(new CudnnNdConvolutionLayer<Dtype>(param));
// #endif
} else {
LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
}
}

REGISTER_LAYER_CREATOR(NdConvolution, GetNdConvolutionLayer);
#endif

// Get pooling layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
Expand Down Expand Up @@ -137,6 +165,33 @@ shared_ptr<Layer<Dtype> > GetLRNLayer(const LayerParameter& param) {

REGISTER_LAYER_CREATOR(LRN, GetLRNLayer);


// Get NdPooling layer according to engine.
#ifdef USE_CUDNN
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetNdPoolingLayer(const LayerParameter& param) {
PoolingParameter_Engine engine = param.pooling_param().engine();
if (engine == PoolingParameter_Engine_DEFAULT) {
engine = PoolingParameter_Engine_CAFFE;
// #ifdef USE_CUDNN
engine = PoolingParameter_Engine_CUDNN;
// #endif
}
if (engine == PoolingParameter_Engine_CAFFE) {
NOT_IMPLEMENTED;
// #ifdef USE_CUDNN
} else if (engine == PoolingParameter_Engine_CUDNN) {
return shared_ptr<Layer<Dtype> >(new CudnnNdPoolingLayer<Dtype>(param));
// #endif
} else {
LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
}
}

REGISTER_LAYER_CREATOR(NdPooling, GetNdPoolingLayer);
#endif


// Get relu layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetReLULayer(const LayerParameter& param) {
Expand Down
1 change: 0 additions & 1 deletion src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#ifdef USE_CUDNN
#include <vector>

#include "caffe/layers/cudnn_conv_layer.hpp"

namespace caffe {
Expand Down
Loading