Skip to content

Commit 189b08d

Browse files
Make infer shape of pad2d support for input with negative dims in compile time. (#18695)
test=develop
1 parent c457a69 commit 189b08d

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

paddle/fluid/operators/pad2d_op.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -495,13 +495,21 @@ class Pad2dOp : public framework::OperatorWithKernel {
495495
PADDLE_ENFORCE_EQ(paddings.size(), 4,
496496
"Size of paddings should be equal to 4.");
497497
if (data_format == "NCHW") {
498-
out_dims[1] = x_dim[1];
499-
out_dims[2] = x_dim[2] + paddings[0] + paddings[1]; // height
500-
out_dims[3] = x_dim[3] + paddings[2] + paddings[3]; // width
501-
} else { // NHWC
502-
out_dims[3] = x_dim[3];
503-
out_dims[1] = x_dim[1] + paddings[0] + paddings[1];
504-
out_dims[2] = x_dim[2] + paddings[2] + paddings[3];
498+
out_dims[1] = x_dim[1]; // channel
499+
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
500+
? x_dim[2]
501+
: (x_dim[2] + paddings[0] + paddings[1]); // height
502+
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
503+
? x_dim[3]
504+
: (x_dim[3] + paddings[2] + paddings[3]); // width
505+
} else { // NHWC
506+
out_dims[3] = x_dim[3]; // channel
507+
out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
508+
? x_dim[1]
509+
: (x_dim[1] + paddings[0] + paddings[1]); // height
510+
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
511+
? x_dim[2]
512+
: (x_dim[2] + paddings[2] + paddings[3]); // width
505513
}
506514
}
507515

0 commit comments

Comments
 (0)