Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions paddle/fluid/operators/squeeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
"tensor's rank.");
}

auto out_dims = GetOutputShape(axes, x_dims);
auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
Expand All @@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
}

static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) {
const framework::DDim &in_dims,
bool is_runtime) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
Expand All @@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");

if (is_runtime) {
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
}

if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
Expand Down Expand Up @@ -104,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);

framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
Expand Down Expand Up @@ -224,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);

framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
Expand Down