Skip to content

Commit a9f4871

Browse files
authored
Merge pull request #2 from qingqing01/pixel_shuffle
Fix unit testing
2 parents a7f2567 + 8931630 commit a9f4871

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

paddle/fluid/operators/pixel_shuffle_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ REGISTER_OP_CUDA_KERNEL(
2323
ops::PixelShuffleOpCUDAKernel<plat::CUDADeviceContext, double>);
2424
REGISTER_OP_CUDA_KERNEL(
2525
pixel_shuffle_grad,
26-
ops::PixelShuffleGradOpCUDAKernel<plat::CUDADeviceContext, float>,
27-
ops::PixelShuffleGradOpCUDAKernel<plat::CUDADeviceContext, double>);
26+
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
27+
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/pixel_shuffle_op.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
2929
int factor = ctx.Attr<int>("upscale_factor");
3030

3131
auto in_dims = in->dims();
32-
auto o_dims = in->dims();
32+
auto o_dims = out->dims();
3333

3434
framework::Tensor t;
3535
t.ShareDataWith(*in);
@@ -44,6 +44,7 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
4444
math::Transpose<DeviceContext, T, 6> trans;
4545
auto& dev_ctx = ctx.template device_context<DeviceContext>();
4646
trans(dev_ctx, t, &o, axis);
47+
out->Resize(o_dims);
4748
}
4849
};
4950

@@ -58,7 +59,7 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
5859
int factor = ctx.Attr<int>("upscale_factor");
5960

6061
auto do_dims = dout->dims();
61-
auto dx_dims = dout->dims();
62+
auto dx_dims = dx->dims();
6263

6364
framework::Tensor t;
6465
t.ShareDataWith(*dout);
@@ -73,6 +74,7 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
7374
math::Transpose<DeviceContext, T, 6> trans;
7475
auto& dev_ctx = ctx.template device_context<DeviceContext>();
7576
trans(dev_ctx, t, &o, axis);
77+
dx->Resize(dx_dims);
7678
}
7779
};
7880

python/paddle/fluid/tests/unittests/test_pixel_shuffle.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,19 @@ def setUp(self):
2525
n, c, h, w = 2, 9, 4, 4
2626
up_factor = 3
2727
shape = [n, c, h, w]
28-
x = np.random.random(self.shape).astype("float32")
28+
x = np.random.random(shape).astype("float32")
2929

3030
new_shape = (n, c / (up_factor * up_factor), up_factor, up_factor, h, w)
3131
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
3232
npresult = np.reshape(x, new_shape)
3333
# transpose to (num,output_channel,h,upscale_factor,w,upscale_factor)
3434
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
35-
npresult = np.reshape(npresult, (1, 1, 12, 12))
35+
oshape = [n, c / (up_factor * up_factor), h * up_factor, w * up_factor]
36+
npresult = np.reshape(npresult, oshape)
3637

3738
self.inputs = {'X': x}
3839
self.outputs = {'Out': npresult}
40+
self.attrs = {'upscale_factor': up_factor}
3941

4042
def test_check_output(self):
4143
self.check_output()

0 commit comments

Comments
 (0)