Skip to content

Commit 1efe883

Browse files
author
cryoco
committed
map_matmul_to_mul_pass support 3dim
1 parent 309efda commit 1efe883

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
5757
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
5858
size_t x_rank = x_shape.size();
5959
size_t y_rank = y_shape.size();
60-
flag = flag && x_rank == 2 && y_rank == 2;
60+
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
6161

6262
std::vector<Node*>& next_ops = matmul_out->outputs;
6363
flag = flag && next_ops.size() == 1 &&
@@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
6969
desc.SetInput("X", {matmul_in_x->Name()});
7070
desc.SetInput("Y", {matmul_in_y->Name()});
7171
desc.SetOutput("Out", {matmul_out->Name()});
72-
desc.SetAttr("x_num_col_dims", 1);
72+
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
7373
desc.SetAttr("y_num_col_dims", 1);
7474
if (matmul_op->Op()->HasAttr("enable_int8")) {
7575
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));

0 commit comments

Comments
 (0)