Skip to content

Commit 6fccb8f

Browse files
authored
[CINN] Unify handling of the 0-replaced and reduce-deleted axes (#61608)
* Unify handling of the 0-replaced and reduce-deleted axes * Remove one shape dimension for keepdim cases. * Fix issues raised in code review * Fix some errors in the 0-d format
1 parent 8b4219b commit 6fccb8f

File tree

7 files changed

+107
-19
lines changed

7 files changed

+107
-19
lines changed

paddle/cinn/ast_gen_ius/ast_gen.cc

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/cinn/optim/replace_var_with_expr.h"
2323

2424
PD_DECLARE_bool(cinn_new_group_scheduler);
25+
PD_DECLARE_bool(group_schedule_tiling_first);
2526
PD_DECLARE_bool(cinn_bucket_compile);
2627

2728
namespace cinn {
@@ -93,9 +94,21 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
9394
std::vector<ir::Expr> iter_values;
9495
// reduce body and reduce init schedule block should have different objects
9596
// for same axis so we re-create objects
97+
VLOG(4) << "FLAGS_group_schedule_tiling_first = "
98+
<< FLAGS_group_schedule_tiling_first;
9699
std::vector<Var> axis_vars = cinn::common::GenDefaultAxis(axis_len);
100+
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
101+
VLOG(4) << "ast gen: tensor init_body is " << init_body;
97102
for (int i = 0; i < shape.size(); ++i) {
98-
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
103+
bool is_keep_dim = axis[i]->is_keepdim;
104+
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
105+
// if tiling first, we need to replace the reduce axis with 0, but don't
106+
// deal with the non-reduce axis
107+
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
108+
continue;
109+
}
110+
if (!FLAGS_group_schedule_tiling_first &&
111+
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
99112
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
100113
continue;
101114
}
@@ -105,29 +118,41 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
105118
/*is_reduce = */ false));
106119
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back());
107120
axis_vars[i]->is_reduce_axis = false;
108-
if (shape[i] == Expr(1)) {
121+
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
109122
iter_values.push_back(Expr(0));
110123
} else {
111124
iter_values.push_back(axis_vars[i]);
112125
}
113126
}
127+
VLOG(4) << "iter_value.size() and block_vars.size() is "
128+
<< iter_values.size() << " " << block_vars.size();
114129
init_body = ir::ScheduleBlockRealize::Make(
115130
iter_values,
116131
ir::ScheduleBlock::Make(
117132
block_vars, {}, {}, reduce_init_name, init_body));
118133

119134
// For the remaining reduce axis, make reduce body
120-
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
121135
ir::Expr reduce_body =
122136
ConvertReduceBody(tensor->body(), tensor, axis_exprs);
137+
138+
VLOG(4) << "ast gen: reduce body is " << reduce_body;
139+
123140
// create schedule block itervars, i0,i1...
124141
std::vector<ir::Var> reduce_block_vars;
125142
std::vector<ir::Expr> reduce_iter_values;
126143
// reduce body and reduce init schedule block should have different objects
127144
// for same axis so we re-create objects
128145
std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
129146
for (int i = 0; i < shape.size(); ++i) {
130-
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
147+
bool is_keep_dim = axis[i]->is_keepdim;
148+
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
149+
// if tiling first, we need to replace the reduce axis with 0, but don't
150+
// deal with the non-reduce axis
151+
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
152+
continue;
153+
}
154+
if (!FLAGS_group_schedule_tiling_first &&
155+
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
131156
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
132157
continue;
133158
}
@@ -136,12 +161,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
136161
cinn::UniqName("i" + std::to_string(i)),
137162
/*is_reduce = */ false));
138163
reduce_axis_vars[i]->is_reduce_axis = false;
139-
if (shape[i] == Expr(1)) {
164+
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
140165
reduce_iter_values.push_back(Expr(0));
141166
} else {
142167
reduce_iter_values.push_back(axis_vars[i]);
143168
}
144169
}
170+
VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body;
145171
for (int i = 0; i < reduce_axis.size(); ++i) {
146172
int count = shape.size() + i;
147173
reduce_block_vars.push_back(
@@ -155,14 +181,43 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
155181
}
156182

157183
int non_zero_axis_size = 0;
158-
for (int i = 0; i < axis.size(); ++i) {
159-
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
160-
continue;
184+
if (FLAGS_group_schedule_tiling_first) {
185+
std::vector<ir::Var> non_reduce_axis_vars = [&]() {
186+
std::vector<ir::Var> res;
187+
for (int i = 0; i < shape.size(); ++i) {
188+
bool is_keep_dim = axis[i]->is_keepdim;
189+
if (!is_keep_dim) {
190+
res.push_back(axis[i]);
191+
}
192+
}
193+
return res;
194+
}();
195+
for (int i = 0; i < non_reduce_axis_vars.size(); ++i) {
196+
optim::ReplaceVarWithExpr(
197+
&reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]);
198+
++non_zero_axis_size;
161199
}
162-
optim::ReplaceVarWithExpr(
163-
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
164-
++non_zero_axis_size;
200+
} else {
201+
for (int i = 0; i < axis.size(); ++i) {
202+
if (!FLAGS_group_schedule_tiling_first &&
203+
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
204+
continue;
205+
}
206+
optim::ReplaceVarWithExpr(
207+
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
208+
++non_zero_axis_size;
209+
}
210+
}
211+
212+
VLOG(4) << "to replace : " << non_zero_axis_size << " "
213+
<< reduce_block_vars.size();
214+
for (auto i = 0; i < reduce_block_vars.size(); i++) {
215+
VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i];
216+
}
217+
for (auto i = 0; i < reduce_axis.size(); i++) {
218+
VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i];
165219
}
220+
VLOG(4) << "before replace body: " << reduce_body;
166221
for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) {
167222
optim::ReplaceVarWithExpr(&reduce_body,
168223
reduce_axis[i - non_zero_axis_size],
@@ -185,7 +240,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
185240
// Put the two parts together
186241
ir::Expr body = ir::Block::Make({init_body, reduce_body});
187242
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
188-
if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) {
243+
bool is_keep_dim = axis[i]->is_keepdim;
244+
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
245+
continue;
246+
}
247+
if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile &&
248+
shape[i] == Expr(1)) {
189249
continue;
190250
}
191251
ir::Var loop_var = axis[i];
@@ -210,7 +270,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
210270
Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
211271
optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]);
212272
axis_vars[i]->is_reduce_axis = false;
213-
if (shape[i] == Expr(1)) {
273+
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
214274
iter_values.push_back(Expr(0));
215275
} else {
216276
iter_values.push_back(axis_vars[i]);

paddle/cinn/hlir/pe/reduction.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ Tensor DoReduce(const Tensor& tensor,
166166
int indice_cnt = 0;
167167
int reduce_cnt = 0;
168168

169+
// Set keepdim flags of indices.
170+
if (tensor->shape.size() == indices.size()) {
171+
for (const auto& i : real_axes) {
172+
VLOG(4) << "Set is_keepdim = true for var(" << i << ")";
173+
indices[i].as_var_ref()->is_keepdim = true;
174+
}
175+
}
176+
169177
for (size_t i = 0; i < tensor->shape.size(); ++i) {
170178
bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) !=
171179
squeeze_axes.end();

paddle/cinn/ir/ir.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,13 @@ Expr _Var_::Make(Expr lower_bound,
218218
Expr upper_bound,
219219
const std::string &name,
220220
bool is_reduce_axis,
221-
bool is_symbolic_constant) {
221+
bool is_symbolic_constant,
222+
bool is_keepdim) {
222223
auto *n = make_shared<_Var_>();
223224
n->lower_bound = lower_bound;
224225
n->upper_bound = upper_bound;
225226
n->is_reduce_axis = is_reduce_axis;
227+
n->is_keepdim = is_keepdim;
226228
n->is_symbolic_constant = is_symbolic_constant;
227229
n->name = name;
228230
n->set_type(lower_bound.type());
@@ -233,6 +235,7 @@ Expr _Var_::Copy() const {
233235
auto *n = make_shared<_Var_>();
234236
n->name = name;
235237
n->is_reduce_axis = is_reduce_axis;
238+
n->is_keepdim = is_keepdim;
236239
n->lower_bound = lower_bound;
237240
n->upper_bound = upper_bound;
238241
n->set_type(type());

paddle/cinn/ir/ir.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ struct _Var_ : public ExprNode<_Var_> {
381381
std::string name;
382382

383383
bool is_reduce_axis{false};
384+
bool is_keepdim{false};
384385
bool is_symbolic_constant{false};
385386
//! Lower bound and upper bound of a axis.
386387
// @{
@@ -401,7 +402,8 @@ struct _Var_ : public ExprNode<_Var_> {
401402
Expr upper_bound,
402403
const std::string& name,
403404
bool is_reduce,
404-
bool is_symbolic_constant = false);
405+
bool is_symbolic_constant = false,
406+
bool is_keepdim = false);
405407

406408
void Verify() const override;
407409

@@ -419,12 +421,14 @@ struct Var : public IrNodeRef {
419421
Var(Expr lower_bound,
420422
Expr upper_bound,
421423
const std::string& name,
422-
bool is_reduce = false)
423-
: Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {}
424+
bool is_reduce = false,
425+
bool is_keepdim = false)
426+
: Var(_Var_::Make(
427+
lower_bound, upper_bound, name, is_reduce, false, is_keepdim)) {}
424428
Var(int upper_bound, const std::string& name)
425-
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {}
429+
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false, false)) {}
426430
Var(Expr upper_bound, const std::string& name)
427-
: Var(_Var_::Make(Expr(0), upper_bound, name, false)) {}
431+
: Var(_Var_::Make(Expr(0), upper_bound, name, false, false)) {}
428432

429433
operator Expr() { return Expr(get()); }
430434
operator Expr() const {
@@ -977,6 +981,7 @@ struct ScheduleBlock : public ExprNode<ScheduleBlock> {
977981
std::map<std::string, attr_t> attrs;
978982
std::string name;
979983
Expr body;
984+
int32_t reduce_type{-1}; // 0 for warp reduce, 1 for block reduce
980985

981986
static Expr Make(const std::vector<Var>& iter_vars,
982987
const std::vector<Expr>& read_buffers,

paddle/cinn/lang/compute.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
187187
domain_without_reduce_axis,
188188
op,
189189
reduce_axis);
190+
const auto set_keep_dim_for_tensor = [&]() {
191+
for (int i = 0; i < _axis.size(); ++i) {
192+
const auto &axis_var = _axis.at(i);
193+
tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim;
194+
}
195+
};
196+
set_keep_dim_for_tensor();
190197
return tensor;
191198
}
192199

paddle/cinn/pybind/ir/ir_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ void BindIrIr(py::module *m) {
383383
ir::Expr,
384384
const std::string &,
385385
bool,
386+
bool,
386387
bool>(&ir::_Var_::Make))
387388
.def("copy", &ir::_Var_::Copy);
388389

paddle/cinn/runtime/flags.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ PD_DEFINE_bool(cinn_bucket_compile,
6969
BoolFromEnv("FLAGS_cinn_bucket_compile", false),
7070
"Whether to enable bucket compile for dynamic shape.");
7171

72+
PD_DEFINE_bool(group_schedule_tiling_first,
73+
BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
74+
"Whether to enable new group scheduler tiling first strategy.");
75+
7276
PD_DEFINE_bool(cinn_use_common_subexpression_elimination,
7377
BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination",
7478
false),

0 commit comments

Comments
 (0)