From 65dc09edb3075171de96777e6a22bb93a00776d0 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Sun, 4 Feb 2024 12:13:17 +0000 Subject: [PATCH 1/4] uniform all the 0 and reduce deleted axis --- paddle/cinn/ast_gen_ius/ast_gen.cc | 21 ++++++++++++++------- paddle/cinn/runtime/flags.cc | 4 ++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index 009158d3f9cce2..b62c594290e483 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -22,6 +22,7 @@ #include "paddle/cinn/optim/replace_var_with_expr.h" PD_DECLARE_bool(cinn_new_group_scheduler); +PD_DECLARE_bool(group_schedule_tiling_first); PD_DECLARE_bool(cinn_bucket_compile); namespace cinn { @@ -93,9 +94,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { std::vector iter_values; // reduce body and reduce init schedule block should have different objects // for same axis so we re-create objects + VLOG(4) << "XKXK: FLAGS_group_schedule_tiling_first = " + << FLAGS_group_schedule_tiling_first; std::vector axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); continue; } @@ -105,7 +109,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { /*is_reduce = */ false)); optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back()); axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { iter_values.push_back(Expr(0)); } else { iter_values.push_back(axis_vars[i]); @@ -127,7 +131,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // for same axis so we re-create objects std::vector reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); continue; } @@ -136,7 +141,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { cinn::UniqName("i" + std::to_string(i)), /*is_reduce = */ false)); reduce_axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { reduce_iter_values.push_back(Expr(0)); } else { reduce_iter_values.push_back(axis_vars[i]); @@ -156,7 +161,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { int non_zero_axis_size = 0; for (int i = 0; i < axis.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { continue; } optim::ReplaceVarWithExpr( @@ -185,7 +191,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // Put the two parts together ir::Expr body = ir::Block::Make({init_body, reduce_body}); for (int i = static_cast(axis_len) - 1; i >= 0; --i) { - if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile && + shape[i] == Expr(1)) { continue; } ir::Var loop_var = axis[i]; @@ -210,7 +217,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false)); optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]); axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { iter_values.push_back(Expr(0)); } else { iter_values.push_back(axis_vars[i]); diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 4db79158568f88..d679961d0fd719 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -69,6 +69,10 @@ PD_DEFINE_bool(cinn_bucket_compile, BoolFromEnv("FLAGS_cinn_bucket_compile", false), "Whether to enable bucket compile for dynamic shape."); +PD_DEFINE_bool(group_schedule_tiling_first, + BoolFromEnv("FLAGS_group_schedule_tiling_first", false), + "Whether to enable new group scheduler tiling first strategy."); + PD_DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), From d22fbfa9cbcdad503e47d166f65e1ceee58b949f Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 6 Feb 2024 14:07:49 +0000 Subject: [PATCH 2/4] remove one shape for keepdim cases. --- paddle/cinn/ast_gen_ius/ast_gen.cc | 56 ++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index b62c594290e483..db30434f5f65e6 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -94,10 +94,44 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { std::vector iter_values; // reduce body and reduce init schedule block should have different objects // for same axis so we re-create objects - VLOG(4) << "XKXK: FLAGS_group_schedule_tiling_first = " + VLOG(4) << "FLAGS_group_schedule_tiling_first = " << FLAGS_group_schedule_tiling_first; std::vector axis_vars = cinn::common::GenDefaultAxis(axis_len); + const std::vector& reduce_axis = tensor->reduce_axis; + const auto reduce_axis_position = [&reduce_axis, + &tensor]() -> std::vector { + VLOG(4) << "start calculus reduce_axis_position: "; + std::vector res; + auto fn_body = tensor->operation.ptr()->as()->body[0]; + if (fn_body.defined() && fn_body.As()) { + auto& reduce_body = + fn_body.As()->body; // reduce body is a tensor store. + auto& load_indices = reduce_body.As()->indices; + int position = -1; + for (auto& obj : load_indices) { + position += 1; + for (auto& reduce_var : reduce_axis) { + if (obj.as_var_ref() == reduce_var) { + res.push_back(position); + } + } + } + for (auto i : res) { + VLOG(4) << "reduce axis position is " << i; + } + return res; + } + }(); for (int i = 0; i < shape.size(); ++i) { + if (FLAGS_group_schedule_tiling_first && + std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end()) { + // if tiling first, we need to replace the reduce axis with 0, but don't + // deal with the non-reduce axis + optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); + continue; + } if (!FLAGS_group_schedule_tiling_first && FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); @@ -121,7 +155,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { block_vars, {}, {}, reduce_init_name, init_body)); // For the remaining reduce axis, make reduce body - const std::vector& reduce_axis = tensor->reduce_axis; ir::Expr reduce_body = ConvertReduceBody(tensor->body(), tensor, axis_exprs); // create schedule block itervars, i0,i1... @@ -131,6 +164,15 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // for same axis so we re-create objects std::vector reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { + if (FLAGS_group_schedule_tiling_first && + std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end()) { + // if tiling first, we need to replace the reduce axis with 0, but don't + // deal with the non-reduce axis + optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); + continue; + } if (!FLAGS_group_schedule_tiling_first && FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); @@ -169,6 +211,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); ++non_zero_axis_size; } + + if (FLAGS_group_schedule_tiling_first) { + non_zero_axis_size = axis.size() - reduce_axis.size(); + } for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) { optim::ReplaceVarWithExpr(&reduce_body, reduce_axis[i - non_zero_axis_size], @@ -191,6 +237,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // Put the two parts together ir::Expr body = ir::Block::Make({init_body, reduce_body}); for (int i = static_cast(axis_len) - 1; i >= 0; --i) { + if (FLAGS_group_schedule_tiling_first && + std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end()) { + continue; + } if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) { continue; From 4ca0b259999816a31f5db22693212e0a056f5c26 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 20 Feb 2024 00:08:52 +0800 Subject: [PATCH 3/4] fix by code review --- paddle/cinn/ast_gen_ius/ast_gen.cc | 90 ++++++++++++++++-------------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index db30434f5f65e6..4746681e93f453 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -97,36 +97,44 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { VLOG(4) << "FLAGS_group_schedule_tiling_first = " << FLAGS_group_schedule_tiling_first; std::vector axis_vars = cinn::common::GenDefaultAxis(axis_len); - const std::vector& reduce_axis = tensor->reduce_axis; - const auto reduce_axis_position = [&reduce_axis, - &tensor]() -> std::vector { + const std::vector& reduce_axes_vars = tensor->reduce_axis; + const auto reduce_axis_position = [&reduce_axes_vars, &tensor]() { VLOG(4) << "start calculus reduce_axis_position: "; std::vector res; - auto fn_body = tensor->operation.ptr()->as()->body[0]; - if (fn_body.defined() && fn_body.As()) { - auto& reduce_body = - fn_body.As()->body; // reduce body is a tensor store. - auto& load_indices = reduce_body.As()->indices; - int position = -1; - for (auto& obj : load_indices) { - position += 1; - for (auto& reduce_var : reduce_axis) { - if (obj.as_var_ref() == reduce_var) { - res.push_back(position); - } + const auto& fn_body = + tensor->operation.ptr()->as()->body[0]; + bool is_a_valid_reduce_op = fn_body.defined() && fn_body.As(); + if (!is_a_valid_reduce_op) { + PD_THROW( + "The reduce body is not a valid reduce op, please check the " + "input."); + } + const auto& reduce_body = + fn_body.As()->body; // reduce body is a tensor store. + const auto& load_indices = reduce_body.As()->indices; + int position = -1; + for (const auto& obj : load_indices) { + position += 1; + for (auto& reduce_var : reduce_axes_vars) { + if (obj.as_var_ref() == reduce_var) { + res.push_back(position); } } - for (auto i : res) { - VLOG(4) << "reduce axis position is " << i; - } - return res; } + VLOG(4) << "reduce axis position is " << [&] { + std::stringstream ss; + for (int i : res) { + ss << i << " "; + } + return ss.str(); + }(); + return res; }(); for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_group_schedule_tiling_first && - std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end()) { + bool reduce_axis_found = std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end(); + if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { // if tiling first, we need to replace the reduce axis with 0, but don't // deal with the non-reduce axis optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); @@ -164,10 +172,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // for same axis so we re-create objects std::vector reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_group_schedule_tiling_first && - std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end()) { + bool reduce_axis_found = std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end(); + if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { // if tiling first, we need to replace the reduce axis with 0, but don't // deal with the non-reduce axis optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); @@ -189,14 +197,14 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { reduce_iter_values.push_back(axis_vars[i]); } } - for (int i = 0; i < reduce_axis.size(); ++i) { + for (int i = 0; i < reduce_axes_vars.size(); ++i) { int count = shape.size() + i; reduce_block_vars.push_back( - Var(reduce_axis[i]->lower_bound, - reduce_axis[i]->upper_bound, + Var(reduce_axes_vars[i]->lower_bound, + reduce_axes_vars[i]->upper_bound, cinn::UniqName("i" + std::to_string(count)), /*is_reduce = */ true)); - ir::Var reduce_axis_var = reduce_axis[i]; + ir::Var reduce_axis_var = reduce_axes_vars[i]; reduce_axis_var->is_reduce_axis = true; reduce_iter_values.push_back(reduce_axis_var); } @@ -213,11 +221,11 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { } if (FLAGS_group_schedule_tiling_first) { - non_zero_axis_size = axis.size() - reduce_axis.size(); + non_zero_axis_size = axis.size() - reduce_axes_vars.size(); } for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) { optim::ReplaceVarWithExpr(&reduce_body, - reduce_axis[i - non_zero_axis_size], + reduce_axes_vars[i - non_zero_axis_size], reduce_block_vars[i]); } @@ -225,10 +233,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { reduce_iter_values, ir::ScheduleBlock::Make( reduce_block_vars, {}, {}, tensor->name, reduce_body)); - for (int i = static_cast(reduce_axis.size()) - 1; i >= 0; --i) { - reduce_body = ir::For::Make(reduce_axis[i], - reduce_axis[i]->lower_bound, - reduce_axis[i]->upper_bound, + for (int i = static_cast(reduce_axes_vars.size()) - 1; i >= 0; --i) { + reduce_body = ir::For::Make(reduce_axes_vars[i], + reduce_axes_vars[i]->lower_bound, + reduce_axes_vars[i]->upper_bound, ir::ForType::Serial, ir::DeviceAPI::Host, ir::Block::Make({reduce_body})); @@ -237,10 +245,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // Put the two parts together ir::Expr body = ir::Block::Make({init_body, reduce_body}); for (int i = static_cast(axis_len) - 1; i >= 0; --i) { - if (FLAGS_group_schedule_tiling_first && - std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end()) { + bool reduce_axis_found = std::find(reduce_axis_position.begin(), + reduce_axis_position.end(), + i) != reduce_axis_position.end(); + if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { continue; } if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile && From a84a8c5166ccdc0e24f19046453020fc3a6ffecc Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 1 Mar 2024 14:57:01 +0800 Subject: [PATCH 4/4] fix some error in 0d format --- paddle/cinn/ast_gen_ius/ast_gen.cc | 119 ++++++++++++++--------------- paddle/cinn/hlir/pe/reduction.cc | 8 ++ paddle/cinn/ir/ir.cc | 5 +- paddle/cinn/ir/ir.h | 15 ++-- paddle/cinn/lang/compute.cc | 7 ++ paddle/cinn/pybind/ir/ir_api.cc | 1 + 6 files changed, 86 insertions(+), 69 deletions(-) diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index 4746681e93f453..57b10fb7ca8849 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -97,44 +97,11 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { VLOG(4) << "FLAGS_group_schedule_tiling_first = " << FLAGS_group_schedule_tiling_first; std::vector axis_vars = cinn::common::GenDefaultAxis(axis_len); - const std::vector& reduce_axes_vars = tensor->reduce_axis; - const auto reduce_axis_position = [&reduce_axes_vars, &tensor]() { - VLOG(4) << "start calculus reduce_axis_position: "; - std::vector res; - const auto& fn_body = - tensor->operation.ptr()->as()->body[0]; - bool is_a_valid_reduce_op = fn_body.defined() && fn_body.As(); - if (!is_a_valid_reduce_op) { - PD_THROW( - "The reduce body is not a valid reduce op, please check the " - "input."); - } - const auto& reduce_body = - fn_body.As()->body; // reduce body is a tensor store. - const auto& load_indices = reduce_body.As()->indices; - int position = -1; - for (const auto& obj : load_indices) { - position += 1; - for (auto& reduce_var : reduce_axes_vars) { - if (obj.as_var_ref() == reduce_var) { - res.push_back(position); - } - } - } - VLOG(4) << "reduce axis position is " << [&] { - std::stringstream ss; - for (int i : res) { - ss << i << " "; - } - return ss.str(); - }(); - return res; - }(); + const std::vector& reduce_axis = tensor->reduce_axis; + VLOG(4) << "ast gen: tensor init_body is " << init_body; for (int i = 0; i < shape.size(); ++i) { - bool reduce_axis_found = std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end(); - if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { + bool is_keep_dim = axis[i]->is_keepdim; + if (FLAGS_group_schedule_tiling_first && is_keep_dim) { // if tiling first, we need to replace the reduce axis with 0, but don't // deal with the non-reduce axis optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); @@ -157,6 +124,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { iter_values.push_back(axis_vars[i]); } } + VLOG(4) << "iter_value.size() and block_vars.size() is " + << iter_values.size() << " " << block_vars.size(); init_body = ir::ScheduleBlockRealize::Make( iter_values, ir::ScheduleBlock::Make( @@ -165,6 +134,9 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // For the remaining reduce axis, make reduce body ir::Expr reduce_body = ConvertReduceBody(tensor->body(), tensor, axis_exprs); + + VLOG(4) << "ast gen: reduce body is " << reduce_body; + // create schedule block itervars, i0,i1... std::vector reduce_block_vars; std::vector reduce_iter_values; @@ -172,10 +144,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // for same axis so we re-create objects std::vector reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { - bool reduce_axis_found = std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end(); - if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { + bool is_keep_dim = axis[i]->is_keepdim; + if (FLAGS_group_schedule_tiling_first && is_keep_dim) { // if tiling first, we need to replace the reduce axis with 0, but don't // deal with the non-reduce axis optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); @@ -197,35 +167,60 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { reduce_iter_values.push_back(axis_vars[i]); } } - for (int i = 0; i < reduce_axes_vars.size(); ++i) { + VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body; + for (int i = 0; i < reduce_axis.size(); ++i) { int count = shape.size() + i; reduce_block_vars.push_back( - Var(reduce_axes_vars[i]->lower_bound, - reduce_axes_vars[i]->upper_bound, + Var(reduce_axis[i]->lower_bound, + reduce_axis[i]->upper_bound, cinn::UniqName("i" + std::to_string(count)), /*is_reduce = */ true)); - ir::Var reduce_axis_var = reduce_axes_vars[i]; + ir::Var reduce_axis_var = reduce_axis[i]; reduce_axis_var->is_reduce_axis = true; reduce_iter_values.push_back(reduce_axis_var); } int non_zero_axis_size = 0; - for (int i = 0; i < axis.size(); ++i) { - if (!FLAGS_group_schedule_tiling_first && - FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { - continue; + if (FLAGS_group_schedule_tiling_first) { + std::vector non_reduce_axis_vars = [&]() { + std::vector res; + for (int i = 0; i < shape.size(); ++i) { + bool is_keep_dim = axis[i]->is_keepdim; + if (!is_keep_dim) { + res.push_back(axis[i]); + } + } + return res; + }(); + for (int i = 0; i < non_reduce_axis_vars.size(); ++i) { + optim::ReplaceVarWithExpr( + &reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]); + ++non_zero_axis_size; + } + } else { + for (int i = 0; i < axis.size(); ++i) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + continue; + } + optim::ReplaceVarWithExpr( + &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); + ++non_zero_axis_size; } - optim::ReplaceVarWithExpr( - &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); - ++non_zero_axis_size; } - if (FLAGS_group_schedule_tiling_first) { - non_zero_axis_size = axis.size() - reduce_axes_vars.size(); + VLOG(4) << "to replace : " << non_zero_axis_size << " " + << reduce_block_vars.size(); + for (auto i = 0; i < reduce_block_vars.size(); i++) { + VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i]; + } + for (auto i = 0; i < reduce_axis.size(); i++) { + VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i]; } + VLOG(4) << "before replace body: " << reduce_body; for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) { optim::ReplaceVarWithExpr(&reduce_body, - reduce_axes_vars[i - non_zero_axis_size], + reduce_axis[i - non_zero_axis_size], reduce_block_vars[i]); } @@ -233,10 +228,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { reduce_iter_values, ir::ScheduleBlock::Make( reduce_block_vars, {}, {}, tensor->name, reduce_body)); - for (int i = static_cast(reduce_axes_vars.size()) - 1; i >= 0; --i) { - reduce_body = ir::For::Make(reduce_axes_vars[i], - reduce_axes_vars[i]->lower_bound, - reduce_axes_vars[i]->upper_bound, + for (int i = static_cast(reduce_axis.size()) - 1; i >= 0; --i) { + reduce_body = ir::For::Make(reduce_axis[i], + reduce_axis[i]->lower_bound, + reduce_axis[i]->upper_bound, ir::ForType::Serial, ir::DeviceAPI::Host, ir::Block::Make({reduce_body})); @@ -245,10 +240,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // Put the two parts together ir::Expr body = ir::Block::Make({init_body, reduce_body}); for (int i = static_cast(axis_len) - 1; i >= 0; --i) { - bool reduce_axis_found = std::find(reduce_axis_position.begin(), - reduce_axis_position.end(), - i) != reduce_axis_position.end(); - if (FLAGS_group_schedule_tiling_first && reduce_axis_found) { + bool is_keep_dim = axis[i]->is_keepdim; + if (FLAGS_group_schedule_tiling_first && is_keep_dim) { continue; } if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile && diff --git a/paddle/cinn/hlir/pe/reduction.cc b/paddle/cinn/hlir/pe/reduction.cc index 7e33a1475e48b3..605a1b3d6443fe 100644 --- a/paddle/cinn/hlir/pe/reduction.cc +++ b/paddle/cinn/hlir/pe/reduction.cc @@ -166,6 +166,14 @@ Tensor DoReduce(const Tensor& tensor, int indice_cnt = 0; int reduce_cnt = 0; + // Set keepdim flags of indices. + if (tensor->shape.size() == indices.size()) { + for (const auto& i : real_axes) { + VLOG(4) << "Set is_keepdim = true for var(" << i << ")"; + indices[i].as_var_ref()->is_keepdim = true; + } + } + for (size_t i = 0; i < tensor->shape.size(); ++i) { bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end(); diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index 2e194200d19937..f3c64790551cac 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -218,11 +218,13 @@ Expr _Var_::Make(Expr lower_bound, Expr upper_bound, const std::string &name, bool is_reduce_axis, - bool is_symbolic_constant) { + bool is_symbolic_constant, + bool is_keepdim) { auto *n = make_shared<_Var_>(); n->lower_bound = lower_bound; n->upper_bound = upper_bound; n->is_reduce_axis = is_reduce_axis; + n->is_keepdim = is_keepdim; n->is_symbolic_constant = is_symbolic_constant; n->name = name; n->set_type(lower_bound.type()); @@ -233,6 +235,7 @@ Expr _Var_::Copy() const { auto *n = make_shared<_Var_>(); n->name = name; n->is_reduce_axis = is_reduce_axis; + n->is_keepdim = is_keepdim; n->lower_bound = lower_bound; n->upper_bound = upper_bound; n->set_type(type()); diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index c02517f9836fc3..5a1f9f6a1f739f 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -381,6 +381,7 @@ struct _Var_ : public ExprNode<_Var_> { std::string name; bool is_reduce_axis{false}; + bool is_keepdim{false}; bool is_symbolic_constant{false}; //! Lower bound and upper bound of a axis. // @{ @@ -401,7 +402,8 @@ struct _Var_ : public ExprNode<_Var_> { Expr upper_bound, const std::string& name, bool is_reduce, - bool is_symbolic_constant = false); + bool is_symbolic_constant = false, + bool is_keepdim = false); void Verify() const override; @@ -419,12 +421,14 @@ struct Var : public IrNodeRef { Var(Expr lower_bound, Expr upper_bound, const std::string& name, - bool is_reduce = false) - : Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {} + bool is_reduce = false, + bool is_keepdim = false) + : Var(_Var_::Make( + lower_bound, upper_bound, name, is_reduce, false, is_keepdim)) {} Var(int upper_bound, const std::string& name) - : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {} + : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false, false)) {} Var(Expr upper_bound, const std::string& name) - : Var(_Var_::Make(Expr(0), upper_bound, name, false)) {} + : Var(_Var_::Make(Expr(0), upper_bound, name, false, false)) {} operator Expr() { return Expr(get()); } operator Expr() const { @@ -977,6 +981,7 @@ struct ScheduleBlock : public ExprNode { std::map attrs; std::string name; Expr body; + int32_t reduce_type{-1}; // 0 for warp reduce, 1 for block reduce static Expr Make(const std::vector& iter_vars, const std::vector& read_buffers, diff --git a/paddle/cinn/lang/compute.cc b/paddle/cinn/lang/compute.cc index 4828eaac64e13c..bd195fd26a6390 100644 --- a/paddle/cinn/lang/compute.cc +++ b/paddle/cinn/lang/compute.cc @@ -187,6 +187,13 @@ ir::Tensor Compute(const std::vector &domain, domain_without_reduce_axis, op, reduce_axis); + const auto set_keep_dim_for_tensor = [&]() { + for (int i = 0; i < _axis.size(); ++i) { + const auto &axis_var = _axis.at(i); + tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim; + } + }; + set_keep_dim_for_tensor(); return tensor; } diff --git a/paddle/cinn/pybind/ir/ir_api.cc b/paddle/cinn/pybind/ir/ir_api.cc index 56dff498dd7101..efebf1206a8674 100644 --- a/paddle/cinn/pybind/ir/ir_api.cc +++ b/paddle/cinn/pybind/ir/ir_api.cc @@ -383,6 +383,7 @@ void BindIrIr(py::module *m) { ir::Expr, const std::string &, bool, + bool, bool>(&ir::_Var_::Make)) .def("copy", &ir::_Var_::Copy);