@@ -28,40 +28,59 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
2828 " Output(Out) of PixelShuffleOp should not be null." ));
2929
3030 auto input_dims = ctx->GetInputDim (" X" );
31- PADDLE_ENFORCE_EQ (
32- input_dims. size (), 4 ,
33- platform::errors::InvalidArgument (
34- " Input should be a 4-D tensor of format [N, C, H, W], but got %u." ,
35- input_dims.size ()));
31+ PADDLE_ENFORCE_EQ (input_dims. size (), 4 ,
32+ platform::errors::InvalidArgument (
33+ " Input should be a 4-D tensor of format [N, C, H, W] "
34+ " or [N, H, W, C ], but got %u." ,
35+ input_dims.size ()));
3636
3737 auto upscale_factor = ctx->Attrs ().Get <int >(" upscale_factor" );
3838
39- PADDLE_ENFORCE_EQ (input_dims[1 ] % (upscale_factor * upscale_factor), 0 ,
40- platform::errors::InvalidArgument (
41- " The square of upscale_factor[%u] should divide the "
42- " number of channel[%u]" ,
43- input_dims[1 ], upscale_factor * upscale_factor));
44-
39+ const std::string data_format =
40+ ctx->Attrs ().Get <std::string>(" data_format" );
41+ const bool channel_last = (data_format == " NHWC" );
42+
43+ if (!channel_last) {
44+ PADDLE_ENFORCE_EQ (
45+ input_dims[1 ] % (upscale_factor * upscale_factor), 0 ,
46+ platform::errors::InvalidArgument (
47+ " The square of upscale_factor[%u] should divide the "
48+ " number of channel[%u]" ,
49+ input_dims[1 ], upscale_factor * upscale_factor));
50+ } else {
51+ PADDLE_ENFORCE_EQ (
52+ input_dims[3 ] % (upscale_factor * upscale_factor), 0 ,
53+ platform::errors::InvalidArgument (
54+ " The square of upscale_factor[%u] should divide the "
55+ " number of channel[%u]" ,
56+ input_dims[3 ], upscale_factor * upscale_factor));
57+ }
4558 auto output_dims = input_dims;
4659 output_dims[0 ] = input_dims[0 ];
47- output_dims[1 ] = input_dims[1 ] / (upscale_factor * upscale_factor);
48- output_dims[2 ] = input_dims[2 ] * upscale_factor;
49- output_dims[3 ] = input_dims[3 ] * upscale_factor;
60+ if (!channel_last) {
61+ output_dims[1 ] = input_dims[1 ] / (upscale_factor * upscale_factor);
62+ output_dims[2 ] = input_dims[2 ] * upscale_factor;
63+ output_dims[3 ] = input_dims[3 ] * upscale_factor;
64+ } else {
65+ output_dims[1 ] = input_dims[1 ] * upscale_factor;
66+ output_dims[2 ] = input_dims[2 ] * upscale_factor;
67+ output_dims[3 ] = input_dims[3 ] / (upscale_factor * upscale_factor);
68+ }
5069 ctx->SetOutputDim (" Out" , output_dims);
5170 }
5271};
5372
5473class PixelShuffleOpMaker : public framework ::OpProtoAndCheckerMaker {
5574 public:
5675 void Make () override {
57- AddInput (
58- " X " ,
59- " (Tensor, default Tensor<float>) , "
60- " the input feature data of PixelShuffleOp, the layout is [N C H W ]." );
61- AddOutput (
62- " Out " ,
63- " (Tensor, default Tensor<float>), the output of "
64- " PixelShuffleOp. The layout is [N,C/factor^2, H*factor,W*factor]." );
76+ AddInput (" X " ,
77+ " (Tensor, default Tensor<float>), "
78+ " the input feature data of PixelShuffleOp, the layout is [N, C , "
79+ " H, W] or [N, H, W, C ]." );
80+ AddOutput (" Out " ,
81+ " (Tensor, default Tensor<float>), the output of "
82+ " PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
83+ " W*factor] or [N, H*factor, W*factor, C/factor^2 ]." );
6584 AddAttr<int >(" upscale_factor" ,
6685 " the factor to increase spatial resolution by." )
6786 .SetDefault (1 )
@@ -70,6 +89,11 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
7089 platform::errors::InvalidArgument (
7190 " upscale_factor should be larger than 0." ));
7291 });
92+ AddAttr<std::string>(
93+ " data_format" ,
94+ " An optional string from: \" NHWC\" , \" NCHW\" . "
95+ " Defaults to \" NCHW\" . Specify the data format of the input data." )
96+ .SetDefault (" NCHW" );
7397
7498 AddComment (R"DOC(
7599 Pixel Shuffle operator
@@ -114,19 +138,30 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
114138 platform::errors::NotFound (" Output(X@Grad) should not be null" ));
115139
116140 auto do_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
117- PADDLE_ENFORCE_EQ (
118- do_dims. size (), 4 ,
119- platform::errors::InvalidArgument (
120- " Input should be a 4-D tensor of format [N, C, H, W], but got %u." ,
121- do_dims.size ()));
141+ PADDLE_ENFORCE_EQ (do_dims. size (), 4 ,
142+ platform::errors::InvalidArgument (
143+ " Input should be a 4-D tensor of format [N, C, H, W] "
144+ " or [N, H, W, C ], but got %u." ,
145+ do_dims.size ()));
122146
123147 auto upscale_factor = ctx->Attrs ().Get <int >(" upscale_factor" );
124148
149+ const std::string data_format =
150+ ctx->Attrs ().Get <std::string>(" data_format" );
151+ const bool channel_last = (data_format == " NHWC" );
152+
125153 auto dx_dims = do_dims;
126154 dx_dims[0 ] = do_dims[0 ];
127- dx_dims[1 ] = do_dims[1 ] * (upscale_factor * upscale_factor);
128- dx_dims[2 ] = do_dims[2 ] / upscale_factor;
129- dx_dims[3 ] = do_dims[3 ] / upscale_factor;
155+
156+ if (!channel_last) {
157+ dx_dims[1 ] = do_dims[1 ] * (upscale_factor * upscale_factor);
158+ dx_dims[2 ] = do_dims[2 ] / upscale_factor;
159+ dx_dims[3 ] = do_dims[3 ] / upscale_factor;
160+ } else {
161+ dx_dims[1 ] = do_dims[1 ] / upscale_factor;
162+ dx_dims[2 ] = do_dims[2 ] / upscale_factor;
163+ dx_dims[3 ] = do_dims[3 ] * (upscale_factor * upscale_factor);
164+ }
130165 ctx->SetOutputDim (framework::GradVarName (" X" ), dx_dims);
131166 }
132167};
0 commit comments