Skip to content

Commit 0044880

Browse files
committed
fix masked_select infer shape
1 parent 481ee79 commit 0044880

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

paddle/fluid/operators/masked_select_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ class MaskedSelectOp : public framework::OperatorWithKernel {
2626
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect");
2727
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
2828
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect");
29-
framework::DDim output_dims(ctx->GetInputDim("X"));
30-
ctx->SetOutputDim("Y", output_dims);
29+
30+
// output will only be a 1-D Tensor
31+
ctx->SetOutputDim("Y", framework::make_ddim({-1}));
3132
ctx->ShareLoD("X", /*->*/ "Y");
3233
}
3334

0 commit comments

Comments
 (0)