|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/operators/math/pooling.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | +namespace math { |
| 20 | + |
| 21 | +template <typename T> |
| 22 | +class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { |
| 23 | + public: |
| 24 | + void operator()(const platform::DeviceContext& context, |
| 25 | + const framework::Tensor& input, framework::Tensor& output, |
| 26 | + framework::Tensor& mask, std::vector<int>& ksize, |
| 27 | + std::vector<int>& strides, std::vector<int>& paddings) { |
| 28 | + const int batch_size = input.dims()[0]; |
| 29 | + |
| 30 | + const int input_height = input.dims()[2]; |
| 31 | + const int input_width = input.dims()[3]; |
| 32 | + const int output_channels = output.dims()[1]; |
| 33 | + const int output_height = output.dims()[2]; |
| 34 | + const int output_width = output.dims()[3]; |
| 35 | + const int ksize_height = ksize[0]; |
| 36 | + const int ksize_width = ksize[1]; |
| 37 | + const int stride_height = strides[0]; |
| 38 | + const int stride_width = strides[1]; |
| 39 | + const int padding_height = paddings[0]; |
| 40 | + const int padding_width = paddings[1]; |
| 41 | + |
| 42 | + const int input_stride = input_height * input_width; |
| 43 | + const int output_stride = output_height * output_width; |
| 44 | + |
| 45 | + const T* input_data = input.data<T>(); |
| 46 | + T* output_data = output.mutable_data<T>(context.GetPlace()); |
| 47 | + |
| 48 | + T* mask_data = mask.mutable_data<T>(context.GetPlace()); |
| 49 | + |
| 50 | + for (int i = 0; i < batch_size; i++) { |
| 51 | + for (int c = 0; c < output_channels; ++c) { |
| 52 | + for (int ph = 0; ph < output_height; ++ph) { |
| 53 | + int hstart = ph * stride_height - padding_height; |
| 54 | + int hend = std::min(hstart + ksize_height, input_height); |
| 55 | + hstart = std::max(hstart, 0); |
| 56 | + for (int pw = 0; pw < output_width; ++pw) { |
| 57 | + int wstart = pw * stride_width - padding_width; |
| 58 | + int wend = std::min(wstart + ksize_width, input_width); |
| 59 | + wstart = std::max(wstart, 0); |
| 60 | + |
| 61 | + T ele = static_cast<T>(-FLT_MAX); |
| 62 | + int index = -1; |
| 63 | + for (int h = hstart; h < hend; ++h) { |
| 64 | + for (int w = wstart; w < wend; ++w) { |
| 65 | + if (ele < input_data[h * input_width + w]) { |
| 66 | + ele = input_data[h * input_width + w]; |
| 67 | + index = h * input_width + w; |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | + output_data[ph * output_width + pw] = ele; |
| 72 | + mask_data[ph * output_width + pw] = index; |
| 73 | + } |
| 74 | + } |
| 75 | + // offset |
| 76 | + input_data += input_stride; |
| 77 | + output_data += output_stride; |
| 78 | + mask_data += output_stride; |
| 79 | + } |
| 80 | + } |
| 81 | + } |
| 82 | +}; |
| 83 | + |
| 84 | +template <typename T> |
| 85 | +class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { |
| 86 | + public: |
| 87 | + void operator()(const platform::DeviceContext& context, |
| 88 | + framework::Tensor& input_grad, |
| 89 | + const framework::Tensor& output_grad, |
| 90 | + const framework::Tensor& mask, std::vector<int>& ksize, |
| 91 | + std::vector<int>& strides, std::vector<int>& paddings) { |
| 92 | + const int batch_size = input_grad.dims()[0]; |
| 93 | + const int input_height = input_grad.dims()[2]; |
| 94 | + const int input_width = input_grad.dims()[3]; |
| 95 | + const int output_channels = output_grad.dims()[1]; |
| 96 | + const int output_height = output_grad.dims()[2]; |
| 97 | + const int output_width = output_grad.dims()[3]; |
| 98 | + const int input_stride = input_height * input_width; |
| 99 | + const int output_stride = output_height * output_width; |
| 100 | + |
| 101 | + const T* mask_data = mask.data<T>(); |
| 102 | + const T* output_grad_data = output_grad.data<T>(); |
| 103 | + T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); |
| 104 | + |
| 105 | + for (size_t n = 0; n < batch_size; ++n) { |
| 106 | + for (size_t c = 0; c < output_channels; ++c) { |
| 107 | + for (size_t ph = 0; ph < output_height; ++ph) { |
| 108 | + for (size_t pw = 0; pw < output_width; ++pw) { |
| 109 | + const size_t output_idx = ph * output_width + pw; |
| 110 | + const size_t input_idx = static_cast<size_t>(mask_data[output_idx]); |
| 111 | + |
| 112 | + input_grad_data[input_idx] += output_grad_data[output_idx]; |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + // offset |
| 117 | + input_grad_data += input_stride; |
| 118 | + output_grad_data += output_stride; |
| 119 | + mask_data += output_stride; |
| 120 | + } |
| 121 | + } |
| 122 | +}; |
| 123 | + |
| 124 | +template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>; |
| 125 | +template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>; |
| 126 | +template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>; |
| 127 | +template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>; |
| 128 | + |
| 129 | +template <typename T> |
| 130 | +class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { |
| 131 | + public: |
| 132 | + void operator()(const platform::DeviceContext& context, |
| 133 | + const framework::Tensor& input, framework::Tensor& output, |
| 134 | + framework::Tensor& mask, std::vector<int>& ksize, |
| 135 | + std::vector<int>& strides, std::vector<int>& paddings) { |
| 136 | + const int batch_size = input.dims()[0]; |
| 137 | + const int input_depth = input.dims()[2]; |
| 138 | + const int input_height = input.dims()[3]; |
| 139 | + const int input_width = input.dims()[4]; |
| 140 | + const int output_channels = output.dims()[1]; |
| 141 | + const int output_depth = output.dims()[2]; |
| 142 | + const int output_height = output.dims()[3]; |
| 143 | + const int output_width = output.dims()[4]; |
| 144 | + const int ksize_depth = ksize[0]; |
| 145 | + const int ksize_height = ksize[1]; |
| 146 | + const int ksize_width = ksize[2]; |
| 147 | + const int stride_depth = strides[0]; |
| 148 | + const int stride_height = strides[1]; |
| 149 | + const int stride_width = strides[2]; |
| 150 | + const int padding_depth = paddings[0]; |
| 151 | + const int padding_height = paddings[1]; |
| 152 | + const int padding_width = paddings[2]; |
| 153 | + const int input_stride = input_depth * input_height * input_width; |
| 154 | + const int output_stride = output_depth * output_height * output_width; |
| 155 | + const T* input_data = input.data<T>(); |
| 156 | + T* output_data = output.mutable_data<T>(context.GetPlace()); |
| 157 | + T* mask_data = mask.mutable_data<T>(context.GetPlace()); |
| 158 | + |
| 159 | + for (int i = 0; i < batch_size; i++) { |
| 160 | + for (int c = 0; c < output_channels; ++c) { |
| 161 | + for (int pd = 0; pd < output_depth; ++pd) { |
| 162 | + int dstart = pd * stride_depth - padding_depth; |
| 163 | + int dend = std::min(dstart + ksize_depth, input_depth); |
| 164 | + dstart = std::max(dstart, 0); |
| 165 | + for (int ph = 0; ph < output_height; ++ph) { |
| 166 | + int hstart = ph * stride_height - padding_height; |
| 167 | + int hend = std::min(hstart + ksize_height, input_height); |
| 168 | + hstart = std::max(hstart, 0); |
| 169 | + for (int pw = 0; pw < output_width; ++pw) { |
| 170 | + int wstart = pw * stride_width - padding_width; |
| 171 | + int wend = std::min(wstart + ksize_width, input_width); |
| 172 | + wstart = std::max(wstart, 0); |
| 173 | + int output_idx = (pd * output_height + ph) * output_width + pw; |
| 174 | + T ele = static_cast<T>(-FLT_MAX); |
| 175 | + int index = -1; |
| 176 | + for (int d = dstart; d < dend; ++d) { |
| 177 | + for (int h = hstart; h < hend; ++h) { |
| 178 | + for (int w = wstart; w < wend; ++w) { |
| 179 | + if (ele < |
| 180 | + input_data[(d * input_height + h) * input_width + w]) { |
| 181 | + index = (d * input_height + h) * input_width + w; |
| 182 | + ele = |
| 183 | + input_data[(d * input_height + h) * input_width + w]; |
| 184 | + } |
| 185 | + } |
| 186 | + } |
| 187 | + } |
| 188 | + output_data[output_idx] = ele; |
| 189 | + mask_data[output_idx] = index; |
| 190 | + } |
| 191 | + } |
| 192 | + } |
| 193 | + // offset |
| 194 | + input_data += input_stride; |
| 195 | + output_data += output_stride; |
| 196 | + mask_data += output_stride; |
| 197 | + } |
| 198 | + } |
| 199 | + } |
| 200 | +}; |
| 201 | + |
| 202 | +template <typename T> |
| 203 | +class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { |
| 204 | + public: |
| 205 | + void operator()(const platform::DeviceContext& context, |
| 206 | + framework::Tensor& input_grad, |
| 207 | + const framework::Tensor& output_grad, |
| 208 | + const framework::Tensor& mask, std::vector<int>& ksize, |
| 209 | + std::vector<int>& strides, std::vector<int>& paddings) { |
| 210 | + const int batch_size = input_grad.dims()[0]; |
| 211 | + const int input_depth = input_grad.dims()[2]; |
| 212 | + const int input_height = input_grad.dims()[3]; |
| 213 | + const int input_width = input_grad.dims()[4]; |
| 214 | + const int output_channels = output_grad.dims()[1]; |
| 215 | + const int output_depth = output_grad.dims()[2]; |
| 216 | + const int output_height = output_grad.dims()[3]; |
| 217 | + const int output_width = output_grad.dims()[4]; |
| 218 | + const int input_stride = input_depth * input_height * input_width; |
| 219 | + const int output_stride = output_depth * output_height * output_width; |
| 220 | + |
| 221 | + const T* mask_data = mask.data<T>(); |
| 222 | + const T* output_grad_data = output_grad.data<T>(); |
| 223 | + T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); |
| 224 | + |
| 225 | + for (size_t n = 0; n < batch_size; ++n) { |
| 226 | + for (size_t c = 0; c < output_channels; ++c) { |
| 227 | + for (size_t pd = 0; pd < output_depth; ++pd) { |
| 228 | + for (size_t ph = 0; ph < output_height; ++ph) { |
| 229 | + for (size_t pw = 0; pw < output_width; ++pw) { |
| 230 | + const size_t output_idx = |
| 231 | + (pd * output_height + ph) * output_width + pw; |
| 232 | + const size_t input_idx = |
| 233 | + static_cast<size_t>(mask_data[output_idx]); |
| 234 | + |
| 235 | + input_grad_data[input_idx] += output_grad_data[output_idx]; |
| 236 | + } |
| 237 | + } |
| 238 | + } |
| 239 | + // offset |
| 240 | + input_grad_data += input_stride; |
| 241 | + output_grad_data += output_stride; |
| 242 | + mask_data += output_stride; |
| 243 | + } |
| 244 | + } |
| 245 | + } |
| 246 | +}; |
| 247 | + |
| 248 | +template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>; |
| 249 | +template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>; |
| 250 | +template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>; |
| 251 | +template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>; |
| 252 | + |
| 253 | +} // namespace math |
| 254 | +} // namespace operators |
| 255 | +} // namespace paddle |
0 commit comments