diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index b531dd5fa7a..88fb4e50acd 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -65,6 +65,11 @@ inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) { CUDNN_CHECK(cudnnCreateTensorDescriptor(desc)); } +template +inline void createTensorDesc(cudnnTensorDescriptor_t* desc) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(desc)); +} + template inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w, @@ -73,6 +78,15 @@ inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, n, c, h, w, stride_n, stride_c, stride_h, stride_w)); } +template +inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc, + std::vector shape, + std::vector stride) { + CHECK_EQ(shape.size(), stride.size()) << "Dimensions of shape and stride don't match !"; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc, dataType::type, + shape.size(), shape.data(), stride.data())); +} + template inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w) { @@ -84,6 +98,16 @@ inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, stride_n, stride_c, stride_h, stride_w); } +template +inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc, + std::vector shape) { + std::vector stride(shape.size(), 1); + for(int i = stride.size()-2; i >= 0; --i) { + stride[i] = shape[i+1] * stride[i+1]; + } + setTensorNdDesc(desc, shape, stride); +} + template inline void createFilterDesc(cudnnFilterDescriptor_t* desc, int n, int c, int h, int w) { @@ -92,6 +116,14 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc, n, c, h, w)); } +template +inline void createNdFilterDesc(cudnnFilterDescriptor_t* desc, + std::vector shape) { + CUDNN_CHECK(cudnnCreateFilterDescriptor(desc)); + CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType::type, + shape.size(), shape.data())); +} + template inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) { CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv)); @@ -105,6 +137,21 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); } +template +inline void setNdConvolutionDesc(cudnnConvolutionDescriptor_t* conv, + cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, + std::vector pad, std::vector stride) { + int nbDims; + std::vector shape(pad.size()+2); + cudnnDataType_t cudnn_type; + cudnnGetFilterNdDescriptor(filter, shape.size(), &cudnn_type, &nbDims, shape.data()); + CHECK_EQ(nbDims, pad.size()+2) << "Dimensions of filters and pad don't match !"; + CHECK_EQ(nbDims, stride.size()+2) << "Dimensions of filters and stride don't match !"; + std::vector upscale(pad.size(), 1); + CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv, + pad.size(), pad.data(), stride.data(), upscale.data(), CUDNN_CROSS_CORRELATION)); +} + template inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc, PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode, @@ -124,6 +171,27 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc, pad_h, pad_w, stride_h, stride_w)); } +template +inline void createNdPoolingDesc(cudnnPoolingDescriptor_t* pool_desc, + PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode, + std::vector shape, std::vector pad, std::vector stride) { + CHECK_EQ(shape.size(), pad.size()) << "Dimensions of shape and pad don't match !"; + CHECK_EQ(shape.size(), stride.size()) << "Dimensions of shape and stride don't match !"; + switch (poolmethod) { + case PoolingParameter_PoolMethod_MAX: + *mode = CUDNN_POOLING_MAX; + break; + case PoolingParameter_PoolMethod_AVE: + *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } + CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc)); + CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode, shape.size(), + shape.data(), pad.data(), stride.data())); +} + } // namespace cudnn } // namespace caffe diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 211e3d9042d..b1c325b892f 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -254,6 +254,61 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { size_t workspaceSizeInBytes; void *workspace; }; + +template +class CudnnNdConvolutionLayer : public Layer { + public: + explicit CudnnNdConvolutionLayer(const LayerParameter& param) + : Layer(param), handles_setup_(false) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CudnnNdConvolutionLayer(); + + virtual inline const char* type() const { return "NdConvolution"; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + // Compute height_out_ and width_out_ from other parameters. + virtual void compute_output_shape(); + + vector kernel_shape_; + vector stride_shape_; + int num_; + int channels_; + vector pad_shape_; + vector input_shape_; + int group_; + int num_output_; + vector output_shape_; + bool bias_term_; + + int conv_out_spatial_dim_; + int kernel_dim_; + int output_offset_; + + Blob bias_multiplier_; + + bool handles_setup_; + cudnnHandle_t* handle_; + cudaStream_t* stream_; + vector bottom_descs_, top_descs_; + cudnnTensorDescriptor_t bias_desc_; + cudnnFilterDescriptor_t filter_desc_; + vector conv_descs_; + int bottom_offset_, top_offset_, weight_offset_, bias_offset_; + size_t workspaceSizeInBytes; + void *workspace; +}; #endif /** @@ -451,6 +506,50 @@ class CuDNNPoolingLayer : public PoolingLayer { cudnnPoolingDescriptor_t pooling_desc_; cudnnPoolingMode_t mode_; }; + +template +class CudnnNdPoolingLayer : public Layer { + public: + explicit CudnnNdPoolingLayer(const LayerParameter& param) + : Layer(param), handles_setup_(false) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CudnnNdPoolingLayer(); + + virtual inline const char* type() const { return "NdPooling"; } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void compute_output_shape(); + + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + vector kernel_shape_; + vector stride_shape_; + vector pad_shape_; + int channels_; + vector input_shape_; + vector pooled_shape_; + bool global_pooling_; + Blob rand_idx_; + Blob max_idx_; + + bool handles_setup_; + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; + cudnnPoolingDescriptor_t pooling_desc_; + cudnnPoolingMode_t mode_; +}; #endif /** diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 926c7d8ff78..730a31c8c62 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -40,6 +40,29 @@ shared_ptr > GetConvolutionLayer( REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer); +template +shared_ptr > GetNdConvolutionLayer( + const LayerParameter& param) { + ConvolutionParameter_Engine engine = param.convolution_param().engine(); + if (engine == ConvolutionParameter_Engine_DEFAULT) { + engine = ConvolutionParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = ConvolutionParameter_Engine_CUDNN; +#endif + } + if (engine == ConvolutionParameter_Engine_CAFFE) { + NOT_IMPLEMENTED; +#ifdef USE_CUDNN + } else if (engine == ConvolutionParameter_Engine_CUDNN) { + return shared_ptr >(new CudnnNdConvolutionLayer(param)); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(NdConvolution, GetNdConvolutionLayer); + // Get pooling layer according to engine. template shared_ptr > GetPoolingLayer(const LayerParameter& param) { @@ -70,6 +93,29 @@ shared_ptr > GetPoolingLayer(const LayerParameter& param) { REGISTER_LAYER_CREATOR(Pooling, GetPoolingLayer); +// Get pooling layer according to engine. +template +shared_ptr > GetNdPoolingLayer(const LayerParameter& param) { + PoolingParameter_Engine engine = param.pooling_param().engine(); + if (engine == PoolingParameter_Engine_DEFAULT) { + engine = PoolingParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = PoolingParameter_Engine_CUDNN; +#endif + } + if (engine == PoolingParameter_Engine_CAFFE) { + NOT_IMPLEMENTED; +#ifdef USE_CUDNN + } else if (engine == PoolingParameter_Engine_CUDNN) { + return shared_ptr >(new CudnnNdPoolingLayer(param)); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(NdPooling, GetNdPoolingLayer); + // Get relu layer according to engine. template shared_ptr > GetReLULayer(const LayerParameter& param) { diff --git a/src/caffe/layers/cudnn_ndconv_layer.cpp b/src/caffe/layers/cudnn_ndconv_layer.cpp new file mode 100644 index 00000000000..78dea63cb6d --- /dev/null +++ b/src/caffe/layers/cudnn_ndconv_layer.cpp @@ -0,0 +1,275 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +// Set to three for the benefit of the backward pass, which +// can use separate streams for calculating the gradient w.r.t. +// bias, filter weights, and bottom data for each group independently +#define CUDNN_STREAMS_PER_GROUP 3 + +/** + * TODO(dox) explain cuDNN interface + */ +template +void CudnnNdConvolutionLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + ConvolutionParameter conv_param = this->layer_param_.convolution_param(); + // Configure the kernel size, padding, stride, and inputs. + CHECK(conv_param.has_kernel_shape()) + << "Kernel shape is required."; + if(conv_param.has_pad_shape()) { + CHECK_EQ(conv_param.kernel_shape().dim_size(), conv_param.pad_shape().dim_size()) + << "Kernel and Pad shape don't match !"; + } + if(conv_param.has_stride_shape()) { + CHECK_EQ(conv_param.kernel_shape().dim_size(), conv_param.stride_shape().dim_size()) + << "Kernel and Stride shape don't match !"; + } + for(int i = 0; i < conv_param.kernel_shape().dim_size(); ++i) { + kernel_shape_.push_back(conv_param.kernel_shape().dim(i)); + CHECK_GT(kernel_shape_[i], 0) << "Filter dimensions cannot be zero."; + } + if(conv_param.has_pad_shape()) { + for(int i = 0; i < conv_param.kernel_shape().dim_size(); ++i) { + pad_shape_.push_back(conv_param.pad_shape().dim(i)); + } + } else { + pad_shape_ = std::vector(kernel_shape_.size(), 0); + } + if(conv_param.has_stride_shape()) { + for(int i = 0; i < conv_param.kernel_shape().dim_size(); ++i) { + stride_shape_.push_back(conv_param.stride_shape().dim(i)); + } + } else { + stride_shape_ = std::vector(kernel_shape_.size(), 1); + } + // Configure output channels and groups. + channels_ = bottom[0]->shape(1); + num_output_ = this->layer_param_.convolution_param().num_output(); + CHECK_GT(num_output_, 0); + group_ = this->layer_param_.convolution_param().group(); + CHECK_EQ(channels_ % group_, 0); + CHECK_EQ(num_output_ % group_, 0) + << "Number of output should be multiples of group."; + + // Handle the parameters: weights and biases. + // - blobs_[0] holds the filter weights + // - blobs_[1] holds the biases (optional) + bias_term_ = this->layer_param_.convolution_param().bias_term(); + + vector weight_shape(kernel_shape_); + weight_shape.insert(weight_shape.begin(), channels_ / group_); + weight_shape.insert(weight_shape.begin(), num_output_); + + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + if (bias_term_) { + this->blobs_.resize(2); + } else { + this->blobs_.resize(1); + } + // Initialize and fill the weights: + // output channels x input channels per-group x kernel height x kernel width + this->blobs_[0].reset(new Blob(weight_shape)); + shared_ptr > weight_filler(GetFiller( + this->layer_param_.convolution_param().weight_filler())); + weight_filler->Fill(this->blobs_[0].get()); + // If necessary, initialize and fill the biases. + if (bias_term_) { + vector bias_shape(1, num_output_); + this->blobs_[1].reset(new Blob(bias_shape)); + shared_ptr > bias_filler(GetFiller( + this->layer_param_.convolution_param().bias_filler())); + bias_filler->Fill(this->blobs_[1].get()); + } + } + + // Propagate gradients to the parameters (as directed by backward pass). + this->param_propagate_down_.resize(this->blobs_.size(), true); + + // Initialize CUDA streams and cuDNN. + stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + workspaceSizeInBytes = 0; + workspace = NULL; + + for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { + CUDA_CHECK(cudaStreamCreate(&stream_[g])); + CUDNN_CHECK(cudnnCreate(&handle_[g])); + CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g])); + } + + // Set the indexing parameters. + weight_shape[0] /= group_; + weight_offset_ = 1; + for(int i = 0; i < weight_shape.size(); ++i) { + weight_offset_ *= weight_shape[i]; + } + bias_offset_ = weight_shape[0]; + + // Create filter descriptor. + cudnn::createNdFilterDesc(&filter_desc_, weight_shape); + + // Create tensor descriptor(s) for data and corresponding convolution(s). + for (int i = 0; i < bottom.size(); i++) { + cudnnTensorDescriptor_t bottom_desc; + cudnn::createTensorDesc(&bottom_desc); + bottom_descs_.push_back(bottom_desc); + cudnnTensorDescriptor_t top_desc; + cudnn::createTensorDesc(&top_desc); + top_descs_.push_back(top_desc); + cudnnConvolutionDescriptor_t conv_desc; + cudnn::createConvolutionDesc(&conv_desc); + conv_descs_.push_back(conv_desc); + } + + // Tensor descriptor for bias. + if (this->bias_term_) { + cudnn::createTensorDesc(&bias_desc_); + } + + handles_setup_ = true; +} + +template +void CudnnNdConvolutionLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + num_ = bottom[0]->shape(0); + CHECK_EQ(bottom[0]->shape(1), channels_) << "Input size incompatible with convolution kernel."; + input_shape_ = bottom[0]->shape(); + // TODO: generalize to handle inputs of different shapes. + for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { + CHECK_EQ(num_, bottom[bottom_id]->shape(0)) + << "Inputs must have same num."; + CHECK_EQ(channels_, bottom[bottom_id]->shape(1)) + << "Inputs must have same channels."; + for(int i = 0; i < bottom[0]->num_axes(); ++i) { + CHECK_EQ(input_shape_[i], bottom[bottom_id]->shape(i)) << "Inputs must have same shape."; + } + } + // Shape the tops. + compute_output_shape(); + for (int top_id = 0; top_id < top.size(); ++top_id) { + top[top_id]->Reshape(output_shape_); + } + + conv_out_spatial_dim_ = 1; + for(int i = 2; i < output_shape_.size(); ++i) { + conv_out_spatial_dim_ *= output_shape_[i]; + } + + kernel_dim_ = channels_; + for(int i = 0; i < kernel_shape_.size(); ++i) { + kernel_dim_ *= kernel_shape_[i]; + } + weight_offset_ = num_output_ * kernel_dim_ / group_ / group_; + output_offset_ = num_output_ * conv_out_spatial_dim_ / group_; + // Set up the all ones "bias multiplier" for adding biases by BLAS + if (bias_term_) { + vector bias_multiplier_shape(1, conv_out_spatial_dim_); + bias_multiplier_.Reshape(bias_multiplier_shape); + caffe_set(bias_multiplier_.count(), Dtype(1), + bias_multiplier_.mutable_cpu_data()); + } + + bottom_offset_ = 1; + for(int i = 1; i < input_shape_.size(); ++i) { + bottom_offset_ *= input_shape_[i]; + } + bottom_offset_ /= group_; + top_offset_ = 1; + for(int i = 1; i < output_shape_.size(); ++i) { + top_offset_ *= output_shape_[i]; + } + top_offset_ /= group_; + + vector bottom_tensor_shape(input_shape_); + bottom_tensor_shape[1] /= group_; + vector bottom_tensor_stride(input_shape_.size(), 1); + for(int i = input_shape_.size()-2; i >= 0; --i) { + bottom_tensor_stride[i] = input_shape_[i+1] * bottom_tensor_stride[i+1]; + } + vector top_tensor_shape(output_shape_); + top_tensor_shape[1] /= group_; + vector top_tensor_stride(output_shape_.size(), 1); + for(int i = output_shape_.size()-2; i >= 0; --i) { + top_tensor_stride[i] = output_shape_[i+1] * top_tensor_stride[i+1]; + } + + for (int i = 0; i < bottom.size(); i++) { + cudnn::setTensorNdDesc(&bottom_descs_[i], + bottom_tensor_shape, bottom_tensor_stride); + cudnn::setTensorNdDesc(&top_descs_[i], + top_tensor_shape, top_tensor_stride); + cudnn::setNdConvolutionDesc(&conv_descs_[i], bottom_descs_[i], + filter_desc_, pad_shape_, stride_shape_); + } + + // Tensor descriptor for bias. + if (this->bias_term_) { + vector bias_shape(input_shape_.size(), 1); + bias_shape[1] = this->num_output_ / this->group_; + cudnn::setTensorNdDesc(&bias_desc_, bias_shape); + } +} + +template +void CudnnNdConvolutionLayer::compute_output_shape() { + output_shape_.clear(); + output_shape_.push_back(num_); + output_shape_.push_back(num_output_); + + for(int i = 2; i < input_shape_.size(); ++i) { + int dim = (input_shape_[i] + 2*pad_shape_[i-2] - kernel_shape_[i-2]) / stride_shape_[i-2] + 1; + if(dim > 1){ + output_shape_.push_back(dim); + } + } +} + +template +void CudnnNdConvolutionLayer::Forward_cpu(const vector* >& bottom, const vector* >& top) { + NOT_IMPLEMENTED; +} + +template +void CudnnNdConvolutionLayer::Backward_cpu(const vector* >& bottom, const vector& propagate_down, const vector* >& top) { + NOT_IMPLEMENTED; +} + +template +CudnnNdConvolutionLayer::~CudnnNdConvolutionLayer() { + // Check that handles have been setup before destroying. + if (!handles_setup_) { return; } + + for (int i = 0; i < bottom_descs_.size(); i++) { + cudnnDestroyTensorDescriptor(bottom_descs_[i]); + cudnnDestroyTensorDescriptor(top_descs_[i]); + cudnnDestroyConvolutionDescriptor(conv_descs_[i]); + } + if (this->bias_term_) { + cudnnDestroyTensorDescriptor(bias_desc_); + } + cudnnDestroyFilterDescriptor(filter_desc_); + + for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { + cudaStreamDestroy(stream_[g]); + cudnnDestroy(handle_[g]); + } + + delete [] stream_; + delete [] handle_; +} + +INSTANTIATE_CLASS(CudnnNdConvolutionLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_ndconv_layer.cu b/src/caffe/layers/cudnn_ndconv_layer.cu new file mode 100644 index 00000000000..cfd2177834f --- /dev/null +++ b/src/caffe/layers/cudnn_ndconv_layer.cu @@ -0,0 +1,161 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +__global__ void sync_ndconv_groups() { } + +template +void CudnnNdConvolutionLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* top_data = top[i]->mutable_gpu_data(); + const Dtype* weight = this->blobs_[0]->gpu_data(); + + size_t workspace_limit_bytes = this->channels_*sizeof(int); + for(int j = 0; j < this->kernel_shape_.size(); ++j) { + workspace_limit_bytes *= kernel_shape_[j]; + } + ++workspace_limit_bytes; + + // Forward through cuDNN in parallel over groups. + for (int g = 0; g < this->group_; g++) { + cudnnConvolutionFwdAlgo_t algo; + + // pick the convolution algorithm + // TODO(shelhamer) this should be done during reshape + // TODO(shelhamer) the choice of automatic or manual algorithm picking + // should be exposed in proto + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_bytes, // memoryLimitInBytes, + &algo)); + + // get minimum size of the workspace needed for the desired algorithm + size_t workspaceSizeInBytes_temp = 0; + + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + algo, + &workspaceSizeInBytes_temp)); + + if (workspaceSizeInBytes_temp > workspaceSizeInBytes) { + workspaceSizeInBytes = workspaceSizeInBytes_temp; + // free the existing workspace and allocate a new (larger) one + cudaFree(this->workspace); + cudaError_t err = cudaMalloc(&(this->workspace), workspaceSizeInBytes); + if (err != cudaSuccess) { + // force zero memory path + algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + workspace = NULL; + workspaceSizeInBytes = 0; + } + } + + // Filters. + CUDNN_CHECK(cudnnConvolutionForward(handle_[g], + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + filter_desc_, weight + weight_offset_ * g, + conv_descs_[i], + algo, workspace, workspaceSizeInBytes, + cudnn::dataType::zero, + top_descs_[i], top_data + top_offset_ * g)); + + // Bias. + if (this->bias_term_) { + const Dtype* bias_data = this->blobs_[1]->gpu_data(); + CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C, + cudnn::dataType::one, + bias_desc_, bias_data + bias_offset_ * g, + cudnn::dataType::one, + top_descs_[i], top_data + top_offset_ * g)); + } + } + + // Synchronize the work across groups, each of which went into its own + // stream, by launching an empty kernel into the default (null) stream. + // NOLINT_NEXT_LINE(whitespace/operators) + sync_ndconv_groups<<<1, 1>>>(); + } +} + +template +void CudnnNdConvolutionLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = NULL; + Dtype* weight_diff = NULL; + if (this->param_propagate_down_[0]) { + weight = this->blobs_[0]->gpu_data(); + weight_diff = this->blobs_[0]->mutable_gpu_diff(); + } + Dtype* bias_diff = NULL; + if (this->bias_term_ && this->param_propagate_down_[1]) { + bias_diff = this->blobs_[1]->mutable_gpu_diff(); + } + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + // Backward through cuDNN in parallel over groups and gradients. + for (int g = 0; g < this->group_; g++) { + // Gradient w.r.t. bias. + if (this->bias_term_ && this->param_propagate_down_[1]) { + CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g], + cudnn::dataType::one, + top_descs_[i], top_diff + top_offset_ * g, + cudnn::dataType::one, + bias_desc_, bias_diff + bias_offset_ * g)); + } + + // Gradient w.r.t. weights. + if (this->param_propagate_down_[0]) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g], + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::one, + filter_desc_, weight_diff + weight_offset_ * g)); + } + + // Gradient w.r.t. bottom data. + if (propagate_down[i]) { + if (weight == NULL) { + weight = this->blobs_[0]->gpu_data(); + } + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], + cudnn::dataType::one, + filter_desc_, weight + weight_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::zero, + bottom_descs_[i], bottom_diff + bottom_offset_ * g)); + } + } + + // Synchronize the work across groups, each of which went into its own + // stream, by launching an empty kernel into the default (null) stream. + // NOLINT_NEXT_LINE(whitespace/operators) + sync_ndconv_groups<<<1, 1>>>(); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(CudnnNdConvolutionLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_ndpooling_layer.cpp b/src/caffe/layers/cudnn_ndpooling_layer.cpp new file mode 100644 index 00000000000..c41e1039b6c --- /dev/null +++ b/src/caffe/layers/cudnn_ndpooling_layer.cpp @@ -0,0 +1,132 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CudnnNdPoolingLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + PoolingParameter pool_param = this->layer_param_.pooling_param(); + CHECK(pool_param.has_kernel_shape()) + << "Kernel shape is required."; + if(pool_param.has_pad_shape()) { + CHECK_EQ(pool_param.kernel_shape().dim_size(), pool_param.pad_shape().dim_size()) + << "Kernel and Pad shape don't match !"; + } + if(pool_param.has_stride_shape()) { + CHECK_EQ(pool_param.kernel_shape().dim_size(), pool_param.stride_shape().dim_size()) + << "Kernel and Stride shape don't match !"; + } + global_pooling_ = pool_param.global_pooling(); + + if(global_pooling_) { + kernel_shape_ = vector(bottom[0]->shape().begin()+2, bottom[0]->shape().end()); + } else { + for(int i = 0; i < pool_param.kernel_shape().dim_size(); ++i) { + kernel_shape_.push_back(pool_param.kernel_shape().dim(i)); + CHECK_GT(kernel_shape_[i], 0) << "Filter dimensions cannot be zero."; + } + } + if(pool_param.has_pad_shape()) { + for(int i = 0; i < pool_param.kernel_shape().dim_size(); ++i) { + pad_shape_.push_back(pool_param.pad_shape().dim(i)); + } + } else { + pad_shape_ = std::vector(kernel_shape_.size(), 0); + } + if(pool_param.has_stride_shape()) { + for(int i = 0; i < pool_param.kernel_shape().dim_size(); ++i) { + stride_shape_.push_back(pool_param.stride_shape().dim(i)); + } + } else { + stride_shape_ = std::vector(kernel_shape_.size(), 1); + } + + CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensorDesc(&bottom_desc_); + cudnn::createTensorDesc(&top_desc_); + cudnn::createNdPoolingDesc(&pooling_desc_, + this->layer_param_.pooling_param().pool(), &mode_, + kernel_shape_, pad_shape_, stride_shape_); + handles_setup_ = true; +} + +template +void CudnnNdPoolingLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + + channels_ = bottom[0]->shape(1); + input_shape_ = bottom[0]->shape(); + if(global_pooling_) { + kernel_shape_ = vector(bottom[0]->shape().begin()+2, bottom[0]->shape().end()); + } + + compute_output_shape(); + top[0]->Reshape(pooled_shape_); + + // If max pooling, we will initialize the vector index part. + if (this->layer_param_.pooling_param().pool() == + PoolingParameter_PoolMethod_MAX && top.size() == 1) { + max_idx_.Reshape(pooled_shape_); + } + // If stochastic pooling, we will initialize the random index part. + if (this->layer_param_.pooling_param().pool() == + PoolingParameter_PoolMethod_STOCHASTIC) { + rand_idx_.Reshape(pooled_shape_); + } + + cudnn::setTensorNdDesc(&bottom_desc_, input_shape_); + cudnn::setTensorNdDesc(&top_desc_, pooled_shape_); +} + +template +void CudnnNdPoolingLayer::compute_output_shape() { + pooled_shape_ = std::vector(input_shape_.begin(), input_shape_.begin()+2); + for(int i = 2; i < input_shape_.size(); ++i) { + int dim = (input_shape_[i] + 2 * pad_shape_[i-2] - kernel_shape_[i-2]) / stride_shape_[i-2] + 1; + + if(pad_shape_[i-2] > 0) { + if ((dim - 1) * stride_shape_[i-2] >= input_shape_[i] + pad_shape_[i-2]) { + --dim; + } + CHECK_LT((dim - 1) * stride_shape_[i-2], input_shape_[i] + pad_shape_[i-2]); + } + + if(dim > 1) { + pooled_shape_.push_back(dim); + } + + } +} + +template +void CudnnNdPoolingLayer::Forward_cpu(const vector* >& bottom, const vector* >& top) { + NOT_IMPLEMENTED; +} + +template +void CudnnNdPoolingLayer::Backward_cpu(const vector* >& bottom, const vector& propagate_down, const vector* >& top) { + NOT_IMPLEMENTED; +} + +template +CudnnNdPoolingLayer::~CudnnNdPoolingLayer() { + // Check that handles have been setup before destroying. + if (!handles_setup_) { return; } + + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + cudnnDestroyPoolingDescriptor(pooling_desc_); + cudnnDestroy(handle_); +} + +INSTANTIATE_CLASS(CudnnNdPoolingLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_ndpooling_layer.cu b/src/caffe/layers/cudnn_ndpooling_layer.cu new file mode 100644 index 00000000000..cd4c6ce8843 --- /dev/null +++ b/src/caffe/layers/cudnn_ndpooling_layer.cu @@ -0,0 +1,45 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CudnnNdPoolingLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_, + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data)); +} + +template +void CudnnNdPoolingLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_, + cudnn::dataType::one, + top_desc_, top_data, top_desc_, top_diff, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + bottom_desc_, bottom_diff)); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CudnnNdPoolingLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index aa299f8660b..d0a8561e571 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -476,13 +476,16 @@ message ConvolutionParameter { optional uint32 pad = 3 [default = 0]; // The padding size (equal in Y, X) optional uint32 pad_h = 9 [default = 0]; // The padding height optional uint32 pad_w = 10 [default = 0]; // The padding width + optional BlobShape pad_shape = 15; optional uint32 kernel_size = 4; // The kernel size (square) optional uint32 kernel_h = 11; // The kernel height optional uint32 kernel_w = 12; // The kernel width + optional BlobShape kernel_shape = 16; optional uint32 group = 5 [default = 1]; // The group size for group conv optional uint32 stride = 6 [default = 1]; // The stride (equal in Y, X) optional uint32 stride_h = 13; // The stride height optional uint32 stride_w = 14; // The stride width + optional BlobShape stride_shape = 17; optional FillerParameter weight_filler = 7; // The filler for the weight optional FillerParameter bias_filler = 8; // The filler for the bias enum Engine { @@ -490,7 +493,7 @@ message ConvolutionParameter { CAFFE = 1; CUDNN = 2; } - optional Engine engine = 15 [default = DEFAULT]; + optional Engine engine = 18 [default = DEFAULT]; } message DataParameter { @@ -729,21 +732,24 @@ message PoolingParameter { optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) optional uint32 pad_h = 9 [default = 0]; // The padding height optional uint32 pad_w = 10 [default = 0]; // The padding width + optional BlobShape pad_shape = 13; optional uint32 kernel_size = 2; // The kernel size (square) optional uint32 kernel_h = 5; // The kernel height optional uint32 kernel_w = 6; // The kernel width + optional BlobShape kernel_shape = 11; optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) optional uint32 stride_h = 7; // The stride height optional uint32 stride_w = 8; // The stride width + optional BlobShape stride_shape = 12; enum Engine { DEFAULT = 0; CAFFE = 1; CUDNN = 2; } - optional Engine engine = 11 [default = DEFAULT]; + optional Engine engine = 14 [default = DEFAULT]; // If global_pooling then it will pool over the size of the bottom by doing // kernel_h = bottom->height and kernel_w = bottom->width - optional bool global_pooling = 12 [default = false]; + optional bool global_pooling = 15 [default = false]; } message PowerParameter {