Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bc9a3b4
add pixel_shuffle op
shippingwang Mar 28, 2019
893aebd
add pixel_shuffle op, test=develop
shippingwang Mar 28, 2019
803818f
fix conflict, test=develop
shippingwang Mar 28, 2019
13b1171
rewrite code, test=develop
shippingwang Mar 28, 2019
401c0eb
delete useless comment, test=develop
shippingwang Mar 28, 2019
63fb779
Refine pixel_shuffle_op and unit testing
qingqing01 Mar 30, 2019
addcfbb
Merge pull request #1 from qingqing01/pixel_shuffle
qingqing01 Mar 30, 2019
21507bb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shippingwang Mar 31, 2019
e616149
Merge branch 'pixel_shuffle' of https://github.com/shippingwang/Paddl…
shippingwang Mar 31, 2019
a7f2567
refine code,test=develop
shippingwang Mar 31, 2019
cde9ac0
refine .cu,test=develop
shippingwang Mar 31, 2019
6568772
fix unittest,test=develop
shippingwang Apr 1, 2019
e5d7dea
Fix unit testing
qingqing01 Apr 1, 2019
8931630
Merge branch 'pixel_shuffle' into pixel_shuffle
qingqing01 Apr 1, 2019
a9f4871
Merge pull request #2 from qingqing01/pixel_shuffle
qingqing01 Apr 1, 2019
984113d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shippingwang Apr 1, 2019
6beb60f
Merge branch 'pixel_shuffle' of https://github.com/shippingwang/Paddl…
shippingwang Apr 1, 2019
7addcdb
resolve conflict, test=develop
shippingwang Apr 1, 2019
909e604
fix test, test=develop
shippingwang Apr 1, 2019
186d133
fix API, test=develop
shippingwang Apr 1, 2019
43da98f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
shippingwang Apr 2, 2019
5572262
fix test datatype bug,test=develop
shippingwang Apr 2, 2019
bb21572
polish comments,test=develop
shippingwang Apr 2, 2019
e5b42d5
test=develop
shippingwang Apr 2, 2019
9775492
add API,test=develop
shippingwang Apr 2, 2019
e9fd635
test=develop
shippingwang Apr 2, 2019
f403189
Add Pixel_Shuffle OP,test=develop
shippingwang Apr 2, 2019
288b1ea
support python3,test=develop
shippingwang Apr 2, 2019
cddc827
add include memory to travis CI bug,test=develop
shippingwang Apr 3, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/API.spec
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], vararg
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec'))
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', 'ad669cdf83e72a69ebc5ed79e36486de'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
Expand Down
134 changes: 134 additions & 0 deletions paddle/fluid/operators/pixel_shuffle_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/pixel_shuffle_op.h"

namespace paddle {
namespace operators {

class PixelShuffleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PixelShuffleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PixelShuffleOp should not be null.");

auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
"Upscale_factor should devide the number of channel");

auto output_dims = input_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
ctx->SetOutputDim("Out", output_dims);
}
};

class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
.AddCustomChecker([](const int& upscale_factor) {
PADDLE_ENFORCE_GE(upscale_factor, 1,
"upscale_factor should be larger than 0.");
});

AddComment(R"DOC(
Pixel Shuffle operator
This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(C, H \times r, W \times r)`.

This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.

Please refer to the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
by Shi et. al (2016) for more details.

)DOC");
}
};

class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("pixel_shuffle_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDesc>(op);
}
};

class PixelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
ops::PixelShuffleGradMaker);

REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);

REGISTER_OP_CPU_KERNEL(
pixel_shuffle,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
26 changes: 26 additions & 0 deletions paddle/fluid/operators/pixel_shuffle_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/pixel_shuffle_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
pixel_shuffle, ops::PixelShuffleOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);
82 changes: 82 additions & 0 deletions paddle/fluid/operators/pixel_shuffle_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class PixelShuffleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

int factor = ctx.Attr<int>("upscale_factor");

auto in_dims = in->dims();
auto o_dims = out->dims();

framework::Tensor t;
t.ShareDataWith(*in);
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});

std::vector<int> axis = {0, 1, 4, 2, 5, 3};

framework::Tensor o;
o.ShareDataWith(*out);
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});

math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
out->Resize(o_dims);
}
};

template <typename DeviceContext, typename T>
class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());

int factor = ctx.Attr<int>("upscale_factor");

auto do_dims = dout->dims();
auto dx_dims = dx->dims();

framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});

std::vector<int> axis = {0, 1, 3, 5, 2, 4};

framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});

math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
dx->Resize(dx_dims);
}
};

} // namespace operators
} // namespace paddle
60 changes: 60 additions & 0 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
'kldiv_loss',
'tree_conv',
'npair_loss',
'pixel_shuffle',
'fsp_matrix',
]

Expand Down Expand Up @@ -10923,6 +10924,65 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
return l2loss + celoss


def pixel_shuffle(x, upscale_factor):
"""

**Pixel Shuffle Layer**

This layer rearranges elements in a tensor of shape [N, C, H, W]
to a tensor of shape [N, C/r**2, H*r, W*r].
This is useful for implementing efficient sub-pixel convolution
with a stride of 1/r.
Please refer to the paper: `Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_ .
by Shi et. al (2016) for more details.

.. code-block:: text

Given a 4-D tensor with the shape:
x.shape = [1, 9, 4, 4]
Given upscale_factor:
upscale_factor= 3
output shape is:
[1, 1, 12, 12]

Args:

x(Variable): The input tensor variable.
upscale_factor(int): factor to increase spatial resolution

Returns:

Out(Variable): the pixel shuffle result is a tensor variable with the same shape and the same type as the input.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the -> The

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


Raises:

ValueError: If the square of upscale_factor cannot divide the channels of input.

Examples:

.. code-block:: python

input = fluid.layers.data(shape=[9,4,4])
output = fluid.layers.pixel_shuffle(x=input, upscale_factor=3)

"""

helper = LayerHelper("pixel_shuffle", **locals())

out = helper.create_variable_for_type_inference(dtype=x.dtype)

if not isinstance(upscale_factor, int):
raise TypeError("upscale factor must be int type")

helper.append_op(
type="pixel_shuffle",
inputs={"X": x},
outputs={"Out": out},
attrs={"upscale_factor": upscale_factor})
return out


def fsp_matrix(x, y):
"""

Expand Down
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,14 @@ def test_shuffle_channel(self):
self.assertIsNotNone(out)
print(str(program))

def test_pixel_shuffle(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[9, 4, 4], dtype="float32")
out = layers.pixel_shuffle(x, upscale_factor=3)
self.assertIsNotNone(out)
print(str(program))

def test_fsp(self):
program = Program()
with program_guard(program):
Expand Down
50 changes: 50 additions & 0 deletions python/paddle/fluid/tests/unittests/test_pixel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest


class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
n, c, h, w = 2, 9, 4, 4
up_factor = 3
shape = [n, c, h, w]
x = np.random.random(shape).astype("float32")
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
w)
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
npresult = np.reshape(x, new_shape)
# transpose to (num,output_channel,h,upscale_factor,w,upscale_factor)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npresult = np.reshape(npresult, oshape)

self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'upscale_factor': up_factor}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


if __name__ == '__main__':
unittest.main()