File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
paddle/fluid/framework/ir Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
5757 std::vector<int64_t > y_shape = matmul_in_y->Var ()->GetShape ();
5858 size_t x_rank = x_shape.size ();
5959 size_t y_rank = y_shape.size ();
60- flag = flag && x_rank == 2 && y_rank == 2 ;
60+ flag = flag && ( x_rank == 2 || x_rank == 3 ) && y_rank == 2 ;
6161
6262 std::vector<Node*>& next_ops = matmul_out->outputs ;
6363 flag = flag && next_ops.size () == 1 &&
@@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
6969 desc.SetInput (" X" , {matmul_in_x->Name ()});
7070 desc.SetInput (" Y" , {matmul_in_y->Name ()});
7171 desc.SetOutput (" Out" , {matmul_out->Name ()});
72- desc.SetAttr (" x_num_col_dims" , 1 );
72+ desc.SetAttr (" x_num_col_dims" , static_cast < int >(x_rank - 1 ) );
7373 desc.SetAttr (" y_num_col_dims" , 1 );
7474 if (matmul_op->Op ()->HasAttr (" enable_int8" )) {
7575 desc.SetAttr (" enable_int8" , matmul_op->Op ()->GetAttr (" enable_int8" ));
You can’t perform that action at this time.
0 commit comments