2222#include " paddle/cinn/optim/replace_var_with_expr.h"
2323
2424PD_DECLARE_bool (cinn_new_group_scheduler);
25+ PD_DECLARE_bool (group_schedule_tiling_first);
2526PD_DECLARE_bool (cinn_bucket_compile);
2627
2728namespace cinn {
@@ -93,9 +94,21 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
9394 std::vector<ir::Expr> iter_values;
9495 // reduce body and reduce init schedule block should have different objects
9596 // for same axis so we re-create objects
97+ VLOG (4 ) << " FLAGS_group_schedule_tiling_first = "
98+ << FLAGS_group_schedule_tiling_first;
9699 std::vector<Var> axis_vars = cinn::common::GenDefaultAxis (axis_len);
100+ const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis ;
101+ VLOG (4 ) << " ast gen: tensor init_body is " << init_body;
97102 for (int i = 0 ; i < shape.size (); ++i) {
98- if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
103+ bool is_keep_dim = axis[i]->is_keepdim ;
104+ if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
105+ // With tiling first, a keep-dim axis (the reduce axis here) is replaced
106+ // by 0 and skipped; non-reduce axes are left for the code below.
107+ optim::ReplaceVarWithExpr (&init_body, axis[i], Expr (0 ));
108+ continue ;
109+ }
110+ if (!FLAGS_group_schedule_tiling_first &&
111+ FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
99112 optim::ReplaceVarWithExpr (&init_body, axis[i], Expr (0 ));
100113 continue ;
101114 }
@@ -105,29 +118,41 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
105118 /* is_reduce = */ false ));
106119 optim::ReplaceVarWithExpr (&init_body, axis[i], block_vars.back ());
107120 axis_vars[i]->is_reduce_axis = false ;
108- if (shape[i] == Expr (1 )) {
121+ if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr (1 )) {
109122 iter_values.push_back (Expr (0 ));
110123 } else {
111124 iter_values.push_back (axis_vars[i]);
112125 }
113126 }
127+ VLOG (4 ) << " iter_value.size() and block_vars.size() is "
128+ << iter_values.size () << " " << block_vars.size ();
114129 init_body = ir::ScheduleBlockRealize::Make (
115130 iter_values,
116131 ir::ScheduleBlock::Make (
117132 block_vars, {}, {}, reduce_init_name, init_body));
118133
119134 // For the remaining reduce axis, make reduce body
120- const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis ;
121135 ir::Expr reduce_body =
122136 ConvertReduceBody (tensor->body (), tensor, axis_exprs);
137+
138+ VLOG (4 ) << " ast gen: reduce body is " << reduce_body;
139+
123140 // create schedule block itervars, i0,i1...
124141 std::vector<ir::Var> reduce_block_vars;
125142 std::vector<ir::Expr> reduce_iter_values;
126143 // reduce body and reduce init schedule block should have different objects
127144 // for same axis so we re-create objects
128145 std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis (axis_len);
129146 for (int i = 0 ; i < shape.size (); ++i) {
130- if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
147+ bool is_keep_dim = axis[i]->is_keepdim ;
148+ if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
149+ // With tiling first, a keep-dim axis (the reduce axis here) is replaced
150+ // by 0 and skipped; non-reduce axes are left for the code below.
151+ optim::ReplaceVarWithExpr (&reduce_body, axis[i], Expr (0 ));
152+ continue ;
153+ }
154+ if (!FLAGS_group_schedule_tiling_first &&
155+ FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
131156 optim::ReplaceVarWithExpr (&reduce_body, axis[i], Expr (0 ));
132157 continue ;
133158 }
@@ -136,12 +161,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
136161 cinn::UniqName (" i" + std::to_string (i)),
137162 /* is_reduce = */ false ));
138163 reduce_axis_vars[i]->is_reduce_axis = false ;
139- if (shape[i] == Expr (1 )) {
164+ if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr (1 )) {
140165 reduce_iter_values.push_back (Expr (0 ));
141166 } else {
142167 reduce_iter_values.push_back (axis_vars[i]);
143168 }
144169 }
170+ VLOG (4 ) << " ast gen: reduce body is after replace 0" << reduce_body;
145171 for (int i = 0 ; i < reduce_axis.size (); ++i) {
146172 int count = shape.size () + i;
147173 reduce_block_vars.push_back (
@@ -155,14 +181,43 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
155181 }
156182
157183 int non_zero_axis_size = 0 ;
158- for (int i = 0 ; i < axis.size (); ++i) {
159- if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
160- continue ;
184+ if (FLAGS_group_schedule_tiling_first) {
185+ std::vector<ir::Var> non_reduce_axis_vars = [&]() {
186+ std::vector<ir::Var> res;
187+ for (int i = 0 ; i < shape.size (); ++i) {
188+ bool is_keep_dim = axis[i]->is_keepdim ;
189+ if (!is_keep_dim) {
190+ res.push_back (axis[i]);
191+ }
192+ }
193+ return res;
194+ }();
195+ for (int i = 0 ; i < non_reduce_axis_vars.size (); ++i) {
196+ optim::ReplaceVarWithExpr (
197+ &reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]);
198+ ++non_zero_axis_size;
161199 }
162- optim::ReplaceVarWithExpr (
163- &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
164- ++non_zero_axis_size;
200+ } else {
201+ for (int i = 0 ; i < axis.size (); ++i) {
202+ if (!FLAGS_group_schedule_tiling_first &&
203+ FLAGS_cinn_new_group_scheduler && shape[i] == Expr (1 )) {
204+ continue ;
205+ }
206+ optim::ReplaceVarWithExpr (
207+ &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
208+ ++non_zero_axis_size;
209+ }
210+ }
211+
212+ VLOG (4 ) << " to replace : " << non_zero_axis_size << " "
213+ << reduce_block_vars.size ();
214+ for (auto i = 0 ; i < reduce_block_vars.size (); i++) {
215+ VLOG (4 ) << " reduce_block_vars[" << i << " ] = " << reduce_block_vars[i];
216+ }
217+ for (auto i = 0 ; i < reduce_axis.size (); i++) {
218+ VLOG (4 ) << " reduce_axis[" << i << " ] = " << reduce_axis[i];
165219 }
220+ VLOG (4 ) << " before replace body: " << reduce_body;
166221 for (int i = non_zero_axis_size; i < reduce_block_vars.size (); ++i) {
167222 optim::ReplaceVarWithExpr (&reduce_body,
168223 reduce_axis[i - non_zero_axis_size],
@@ -185,7 +240,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
185240 // Put the two parts together
186241 ir::Expr body = ir::Block::Make ({init_body, reduce_body});
187242 for (int i = static_cast <int >(axis_len) - 1 ; i >= 0 ; --i) {
188- if (!FLAGS_cinn_bucket_compile && shape[i] == Expr (1 )) {
243+ bool is_keep_dim = axis[i]->is_keepdim ;
244+ if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
245+ continue ;
246+ }
247+ if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile &&
248+ shape[i] == Expr (1 )) {
189249 continue ;
190250 }
191251 ir::Var loop_var = axis[i];
@@ -210,7 +270,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
210270 Expr (0 ), shape[i], cinn::UniqName (" i" + std::to_string (i)), false ));
211271 optim::ReplaceVarWithExpr (&body, axis[i], block_vars[i]);
212272 axis_vars[i]->is_reduce_axis = false ;
213- if (shape[i] == Expr (1 )) {
273+ if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr (1 )) {
214274 iter_values.push_back (Expr (0 ));
215275 } else {
216276 iter_values.push_back (axis_vars[i]);
0 commit comments