Skip to content
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/unpooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Unpool2dMaxFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];

PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
Expand Down
47 changes: 13 additions & 34 deletions paddle/fluid/operators/math/unpooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
const int channels, T* output_data,
const int output_height,
const int output_width) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ENFORCE(out_index < out_c_stride,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥去掉这块的enforce呢?建议再check下数据数据检查

"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
output_data[out_offset + out_index] = input_data[i];
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_width / input_height) % channels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

异常数据注意enforce,否则安全扫描可能有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx,下一个PR中添加异常数据检查

int n = linearIndex / input_width / input_height / channels;
output_data += (n * channels + c) * output_height * output_width;
int maxind = indices_data[linearIndex];
output_data[maxind] = input_data[linearIndex];
}
}

template <typename T>
__global__ void KernelUnpool2dMaxGrad(
const int nthreads, const T* input_data, const int* indices_data,
const int input_height, const int input_width, const int channels,
const T* output_data, const T* output_grad, const int output_height,
const int output_width, T* input_grad) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_width / input_height) % channels;
int n = linearIndex / input_width / input_height / channels;
output_grad += (n * channels + c) * output_height * output_width;
int maxind = indices_data[linearIndex];
input_grad[linearIndex] = output_grad[maxind];
}
}
/*
Expand Down
23 changes: 17 additions & 6 deletions paddle/fluid/operators/unpool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddAttr<std::vector<int>>("output_size",
"(vector, optional). The shape of output.")
.SetDefault({0, 0});
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCHW");
AddComment(R"DOC(
Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, H_{out}, W_{out})$, where
Expand Down Expand Up @@ -93,6 +103,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"Unpool Intput(X) must be of 4-dimensional, but "
Expand All @@ -111,8 +123,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
} else {
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
output_shape.push_back(output_size[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
Expand Down Expand Up @@ -156,15 +167,15 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker,
REGISTER_OPERATOR(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker,
ops::UnpoolOpGradMaker<paddle::framework::OpDesc>,
ops::UnpoolOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(unpool_grad, ops::UnpoolOpGrad);
REGISTER_OPERATOR(unpool2d_grad, ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(
unpool, ops::UnpoolKernel<paddle::platform::CPUDeviceContext, float>,
unpool2d, ops::UnpoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnpoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
unpool_grad,
unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, double>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/unpool_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ limitations under the License. */

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unpool, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>,
unpool2d, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为改名,需要确认下是否之前有API使用了unpool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在全景图里做了搜索,确认没有api使用过unpool

ops::UnpoolKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
unpool_grad,
unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, double>);
Loading