-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[CINN] uniform all the 0 and reduce deleted axis #61608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
65dc09e
d22fbfa
4ca0b25
1c6c65f
a84a8c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| #include "paddle/cinn/optim/replace_var_with_expr.h" | ||
|
|
||
| PD_DECLARE_bool(cinn_new_group_scheduler); | ||
| PD_DECLARE_bool(group_schedule_tiling_first); | ||
| PD_DECLARE_bool(cinn_bucket_compile); | ||
|
|
||
| namespace cinn { | ||
|
|
@@ -93,9 +94,21 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| std::vector<ir::Expr> iter_values; | ||
| // reduce body and reduce init schedule block should have different objects | ||
| // for same axis so we re-create objects | ||
| VLOG(4) << "FLAGS_group_schedule_tiling_first = " | ||
| << FLAGS_group_schedule_tiling_first; | ||
| std::vector<Var> axis_vars = cinn::common::GenDefaultAxis(axis_len); | ||
| const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis; | ||
| VLOG(4) << "ast gen: tensor init_body is " << init_body; | ||
| for (int i = 0; i < shape.size(); ++i) { | ||
| if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| bool is_keep_dim = axis[i]->is_keepdim; | ||
| if (FLAGS_group_schedule_tiling_first && is_keep_dim) { | ||
| // if tiling first, we need to replace the reduce axis with 0, but don't | ||
| // deal with the non-reduce axis | ||
| optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); | ||
| continue; | ||
| } | ||
| if (!FLAGS_group_schedule_tiling_first && | ||
| FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); | ||
| continue; | ||
| } | ||
|
|
@@ -105,29 +118,41 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| /*is_reduce = */ false)); | ||
| optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back()); | ||
| axis_vars[i]->is_reduce_axis = false; | ||
| if (shape[i] == Expr(1)) { | ||
| if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { | ||
| iter_values.push_back(Expr(0)); | ||
| } else { | ||
| iter_values.push_back(axis_vars[i]); | ||
| } | ||
| } | ||
| VLOG(4) << "iter_value.size() and block_vars.size() is " | ||
| << iter_values.size() << " " << block_vars.size(); | ||
| init_body = ir::ScheduleBlockRealize::Make( | ||
| iter_values, | ||
| ir::ScheduleBlock::Make( | ||
| block_vars, {}, {}, reduce_init_name, init_body)); | ||
|
|
||
| // For the remaining reduce axis, make reduce body | ||
| const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis; | ||
| ir::Expr reduce_body = | ||
| ConvertReduceBody(tensor->body(), tensor, axis_exprs); | ||
|
|
||
| VLOG(4) << "ast gen: reduce body is " << reduce_body; | ||
|
|
||
| // create schedule block itervars, i0,i1... | ||
| std::vector<ir::Var> reduce_block_vars; | ||
| std::vector<ir::Expr> reduce_iter_values; | ||
| // reduce body and reduce init schedule block should have different objects | ||
| // for same axis so we re-create objects | ||
| std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); | ||
| for (int i = 0; i < shape.size(); ++i) { | ||
| if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| bool is_keep_dim = axis[i]->is_keepdim; | ||
| if (FLAGS_group_schedule_tiling_first && is_keep_dim) { | ||
| // if tiling first, we need to replace the reduce axis with 0, but don't | ||
| // deal with the non-reduce axis | ||
| optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); | ||
| continue; | ||
| } | ||
| if (!FLAGS_group_schedule_tiling_first && | ||
| FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); | ||
| continue; | ||
| } | ||
|
|
@@ -136,12 +161,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| cinn::UniqName("i" + std::to_string(i)), | ||
| /*is_reduce = */ false)); | ||
| reduce_axis_vars[i]->is_reduce_axis = false; | ||
| if (shape[i] == Expr(1)) { | ||
| if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { | ||
| reduce_iter_values.push_back(Expr(0)); | ||
| } else { | ||
| reduce_iter_values.push_back(axis_vars[i]); | ||
| } | ||
| } | ||
| VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body; | ||
| for (int i = 0; i < reduce_axis.size(); ++i) { | ||
| int count = shape.size() + i; | ||
| reduce_block_vars.push_back( | ||
|
|
@@ -155,14 +181,43 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| } | ||
|
|
||
| int non_zero_axis_size = 0; | ||
| for (int i = 0; i < axis.size(); ++i) { | ||
| if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| continue; | ||
| if (FLAGS_group_schedule_tiling_first) { | ||
| std::vector<ir::Var> non_reduce_axis_vars = [&]() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| std::vector<ir::Var> res; | ||
| for (int i = 0; i < shape.size(); ++i) { | ||
| bool is_keep_dim = axis[i]->is_keepdim; | ||
| if (!is_keep_dim) { | ||
| res.push_back(axis[i]); | ||
| } | ||
| } | ||
| return res; | ||
| }(); | ||
| for (int i = 0; i < non_reduce_axis_vars.size(); ++i) { | ||
| optim::ReplaceVarWithExpr( | ||
| &reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]); | ||
| ++non_zero_axis_size; | ||
| } | ||
| optim::ReplaceVarWithExpr( | ||
| &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); | ||
| ++non_zero_axis_size; | ||
| } else { | ||
| for (int i = 0; i < axis.size(); ++i) { | ||
| if (!FLAGS_group_schedule_tiling_first && | ||
| FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { | ||
| continue; | ||
| } | ||
| optim::ReplaceVarWithExpr( | ||
| &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); | ||
| ++non_zero_axis_size; | ||
| } | ||
| } | ||
|
|
||
| VLOG(4) << "to replace : " << non_zero_axis_size << " " | ||
| << reduce_block_vars.size(); | ||
| for (auto i = 0; i < reduce_block_vars.size(); i++) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change: `if (VLOG_IS_ON(4))`
||
| VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i]; | ||
| } | ||
| for (auto i = 0; i < reduce_axis.size(); i++) { | ||
| VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i]; | ||
| } | ||
| VLOG(4) << "before replace body: " << reduce_body; | ||
| for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) { | ||
| optim::ReplaceVarWithExpr(&reduce_body, | ||
| reduce_axis[i - non_zero_axis_size], | ||
|
|
@@ -185,7 +240,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| // Put the two parts together | ||
| ir::Expr body = ir::Block::Make({init_body, reduce_body}); | ||
| for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) { | ||
| if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) { | ||
| bool is_keep_dim = axis[i]->is_keepdim; | ||
| if (FLAGS_group_schedule_tiling_first && is_keep_dim) { | ||
| continue; | ||
| } | ||
| if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile && | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change: `(!FLAGS_group_schedule_tiling_first || !FLAGS_cinn_bucket_compile) && shape[i] == Expr(1)`
||
| shape[i] == Expr(1)) { | ||
| continue; | ||
| } | ||
| ir::Var loop_var = axis[i]; | ||
|
|
@@ -210,7 +270,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { | |
| Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false)); | ||
| optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]); | ||
| axis_vars[i]->is_reduce_axis = false; | ||
| if (shape[i] == Expr(1)) { | ||
| if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { | ||
| iter_values.push_back(Expr(0)); | ||
| } else { | ||
| iter_values.push_back(axis_vars[i]); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,6 +187,13 @@ ir::Tensor Compute(const std::vector<Expr> &domain, | |
| domain_without_reduce_axis, | ||
| op, | ||
| reduce_axis); | ||
| const auto set_keep_dim_for_tensor = [&]() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested rename: `SetKeepDimForTensor`
||
| for (int i = 0; i < _axis.size(); ++i) { | ||
| const auto &axis_var = _axis.at(i); | ||
| tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim; | ||
| } | ||
| }; | ||
| set_keep_dim_for_tensor(); | ||
| return tensor; | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.