86 changes: 73 additions & 13 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
@@ -22,6 +22,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"

PD_DECLARE_bool(cinn_new_group_scheduler);
PD_DECLARE_bool(group_schedule_tiling_first);
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
@@ -93,9 +94,21 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
std::vector<ir::Expr> iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
VLOG(4) << "FLAGS_group_schedule_tiling_first = "
<< FLAGS_group_schedule_tiling_first;
std::vector<Var> axis_vars = cinn::common::GenDefaultAxis(axis_len);
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
VLOG(4) << "ast gen: tensor init_body is " << init_body;
for (int i = 0; i < shape.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
// If tiling first, we need to replace the reduce axis with 0, but leave
// the non-reduce axes untouched
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
continue;
}
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
continue;
}
@@ -105,29 +118,41 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
/*is_reduce = */ false));
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back());
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
}
}
VLOG(4) << "iter_value.size() and block_vars.size() is "
<< iter_values.size() << " " << block_vars.size();
init_body = ir::ScheduleBlockRealize::Make(
iter_values,
ir::ScheduleBlock::Make(
block_vars, {}, {}, reduce_init_name, init_body));

// For the remaining reduce axis, make reduce body
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
ir::Expr reduce_body =
ConvertReduceBody(tensor->body(), tensor, axis_exprs);

VLOG(4) << "ast gen: reduce body is " << reduce_body;

// create schedule block itervars, i0,i1...
std::vector<ir::Var> reduce_block_vars;
std::vector<ir::Expr> reduce_iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
for (int i = 0; i < shape.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
// If tiling first, we need to replace the reduce axis with 0, but leave
// the non-reduce axes untouched
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
continue;
}
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
continue;
}
@@ -136,12 +161,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
cinn::UniqName("i" + std::to_string(i)),
/*is_reduce = */ false));
reduce_axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
reduce_iter_values.push_back(Expr(0));
} else {
reduce_iter_values.push_back(axis_vars[i]);
}
}
VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body;
for (int i = 0; i < reduce_axis.size(); ++i) {
int count = shape.size() + i;
reduce_block_vars.push_back(
@@ -155,14 +181,43 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
}

int non_zero_axis_size = 0;
for (int i = 0; i < axis.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
continue;
if (FLAGS_group_schedule_tiling_first) {
std::vector<ir::Var> non_reduce_axis_vars = [&]() {
Contributor review comment: const std::vector<ir::Var> (i.e., const-qualify this local).

std::vector<ir::Var> res;
for (int i = 0; i < shape.size(); ++i) {
bool is_keep_dim = axis[i]->is_keepdim;
if (!is_keep_dim) {
res.push_back(axis[i]);
}
}
return res;
}();
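// [A sketch of the reviewer's const suggestion above; an illustration, not
// part of the merged diff.] The lambda-initialized local is never mutated
// after this point, so it could be bound as:
//   const std::vector<ir::Var> non_reduce_axis_vars = [&]() { ... }();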
for (int i = 0; i < non_reduce_axis_vars.size(); ++i) {
optim::ReplaceVarWithExpr(
&reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]);
++non_zero_axis_size;
}
optim::ReplaceVarWithExpr(
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
++non_zero_axis_size;
} else {
for (int i = 0; i < axis.size(); ++i) {
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
continue;
}
optim::ReplaceVarWithExpr(
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
++non_zero_axis_size;
}
}

VLOG(4) << "to replace : " << non_zero_axis_size << " "
<< reduce_block_vars.size();
for (auto i = 0; i < reduce_block_vars.size(); i++) {
Contributor review comment: if (VLOG_IS_ON(4)). Otherwise the loop also executes when the log level is below 4.

VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i];
}
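// [A sketch of the reviewer's VLOG_IS_ON suggestion above; an illustration,
// not part of the merged diff.] Guarding the loop skips it entirely when
// verbose logging is disabled:
//   if (VLOG_IS_ON(4)) {
//     for (size_t i = 0; i < reduce_block_vars.size(); ++i) {
//       VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i];
//     }
//   }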
for (auto i = 0; i < reduce_axis.size(); i++) {
VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i];
}
VLOG(4) << "before replace body: " << reduce_body;
for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) {
optim::ReplaceVarWithExpr(&reduce_body,
reduce_axis[i - non_zero_axis_size],
@@ -185,7 +240,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
// Put the two parts together
ir::Expr body = ir::Block::Make({init_body, reduce_body});
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
continue;
}
if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile &&
Contributor review comment: (!FLAGS_group_schedule_tiling_first || !FLAGS_cinn_bucket_compile) && shape[i] == Expr(1) (suggesting || in place of the first &&).

shape[i] == Expr(1)) {
continue;
}
ir::Var loop_var = axis[i];
@@ -210,7 +270,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]);
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
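Taken together, the ast_gen.cc changes route each tensor axis down one of three paths, depending on the two flags. A self-contained sketch of that decision (the Axis struct and BindAxis helper here are hypothetical stand-ins, not CINN's real API):

#include <iostream>
#include <vector>

// Minimal stand-ins for the real CINN types; illustration only.
struct Axis {
  bool is_keepdim;
  int extent;
};

// Mirrors the per-axis branching this diff adds to AstGen::Build: under
// tiling-first, keep-dim axes are bound to the constant 0; otherwise the
// pre-existing rule (new group scheduler + extent 1 -> 0) applies; every
// remaining axis is bound to a schedule-block iter var.
const char* BindAxis(const Axis& axis, bool tiling_first, bool new_scheduler) {
  if (tiling_first && axis.is_keepdim) return "replace with Expr(0)";
  if (!tiling_first && new_scheduler && axis.extent == 1)
    return "replace with Expr(0)";
  return "bind to a block iter var";
}

int main() {
  const std::vector<Axis> axes = {{/*is_keepdim=*/true, /*extent=*/1},
                                  {/*is_keepdim=*/false, /*extent=*/8}};
  for (const Axis& a : axes) {
    std::cout << BindAxis(a, /*tiling_first=*/true, /*new_scheduler=*/true)
              << "\n";
  }
  return 0;
}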
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/pe/reduction.cc
@@ -166,6 +166,14 @@ Tensor DoReduce(const Tensor& tensor,
int indice_cnt = 0;
int reduce_cnt = 0;

// Set keepdim flags of indices.
if (tensor->shape.size() == indices.size()) {
for (const auto& i : real_axes) {
VLOG(4) << "Set is_keepdim = true for var(" << i << ")";
indices[i].as_var_ref()->is_keepdim = true;
}
}

for (size_t i = 0; i < tensor->shape.size(); ++i) {
bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) !=
squeeze_axes.end();
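The rank check above is what makes the keep-dim detection work: only a keep_dim reduction yields an output index list with the same rank as the input tensor. A standalone sketch of the idea (a hypothetical KeepDimFlags helper with plain indices instead of CINN exprs):

#include <cstddef>
#include <vector>

// If a reduction keeps its reduced dimensions as size-1 dims, the output
// index list has the same rank as the input shape, and the positions listed
// in real_axes are exactly the keep-dim ones (mirrors the DoReduce change).
// e.g. KeepDimFlags(2, 2, {1}) -> {false, true}
std::vector<bool> KeepDimFlags(std::size_t input_rank,
                               std::size_t output_rank,
                               const std::vector<std::size_t>& real_axes) {
  std::vector<bool> flags(output_rank, false);
  if (input_rank == output_rank) {
    for (std::size_t a : real_axes) flags[a] = true;
  }
  return flags;
}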
5 changes: 4 additions & 1 deletion paddle/cinn/ir/ir.cc
@@ -218,11 +218,13 @@ Expr _Var_::Make(Expr lower_bound,
Expr upper_bound,
const std::string &name,
bool is_reduce_axis,
bool is_symbolic_constant) {
bool is_symbolic_constant,
bool is_keepdim) {
auto *n = make_shared<_Var_>();
n->lower_bound = lower_bound;
n->upper_bound = upper_bound;
n->is_reduce_axis = is_reduce_axis;
n->is_keepdim = is_keepdim;
n->is_symbolic_constant = is_symbolic_constant;
n->name = name;
n->set_type(lower_bound.type());
@@ -233,6 +235,7 @@ Expr _Var_::Copy() const {
auto *n = make_shared<_Var_>();
n->name = name;
n->is_reduce_axis = is_reduce_axis;
n->is_keepdim = is_keepdim;
n->lower_bound = lower_bound;
n->upper_bound = upper_bound;
n->set_type(type());
15 changes: 10 additions & 5 deletions paddle/cinn/ir/ir.h
@@ -381,6 +381,7 @@ struct _Var_ : public ExprNode<_Var_> {
std::string name;

bool is_reduce_axis{false};
bool is_keepdim{false};
bool is_symbolic_constant{false};
//! Lower bound and upper bound of a axis.
// @{
@@ -401,7 +402,8 @@ struct _Var_ : public ExprNode<_Var_> {
Expr upper_bound,
const std::string& name,
bool is_reduce,
bool is_symbolic_constant = false);
bool is_symbolic_constant = false,
bool is_keepdim = false);

void Verify() const override;

@@ -419,12 +421,14 @@ struct Var : public IrNodeRef {
Var(Expr lower_bound,
Expr upper_bound,
const std::string& name,
bool is_reduce = false)
: Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {}
bool is_reduce = false,
bool is_keepdim = false)
: Var(_Var_::Make(
lower_bound, upper_bound, name, is_reduce, false, is_keepdim)) {}
Var(int upper_bound, const std::string& name)
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {}
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false, false)) {}
Var(Expr upper_bound, const std::string& name)
: Var(_Var_::Make(Expr(0), upper_bound, name, false)) {}
: Var(_Var_::Make(Expr(0), upper_bound, name, false, false)) {}

operator Expr() { return Expr(get()); }
operator Expr() const {
@@ -977,6 +981,7 @@ struct ScheduleBlock : public ExprNode<ScheduleBlock> {
std::map<std::string, attr_t> attrs;
std::string name;
Expr body;
int32_t reduce_type{-1}; // 0 for warp reduce, 1 for block reduce

static Expr Make(const std::vector<Var>& iter_vars,
const std::vector<Expr>& read_buffers,
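With the widened signatures, callers can tag an axis as keep-dim at construction time. An illustrative use of the two entry points above (the bounds and the name "i0" are made up for the example):

#include "paddle/cinn/ir/ir.h"

using cinn::ir::Expr;
using cinn::ir::Var;
using cinn::ir::_Var_;

// Convenience constructor: positional is_reduce, then the new is_keepdim.
Var keepdim_axis(Expr(0), Expr(1), "i0",
                 /*is_reduce=*/false, /*is_keepdim=*/true);

// Node-level factory: note that is_symbolic_constant comes before is_keepdim.
Expr v = _Var_::Make(Expr(0), Expr(1), "i0",
                     /*is_reduce_axis=*/false,
                     /*is_symbolic_constant=*/false,
                     /*is_keepdim=*/true);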
7 changes: 7 additions & 0 deletions paddle/cinn/lang/compute.cc
@@ -187,6 +187,13 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
domain_without_reduce_axis,
op,
reduce_axis);
const auto set_keep_dim_for_tensor = [&]() {
Contributor review comment: SetKeepDimForTensor (suggesting this name for the lambda).

for (int i = 0; i < _axis.size(); ++i) {
const auto &axis_var = _axis.at(i);
tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim;
}
};
set_keep_dim_for_tensor();
return tensor;
}

1 change: 1 addition & 0 deletions paddle/cinn/pybind/ir/ir_api.cc
@@ -383,6 +383,7 @@ void BindIrIr(py::module *m) {
ir::Expr,
const std::string &,
bool,
bool,
bool>(&ir::_Var_::Make))
.def("copy", &ir::_Var_::Copy);

4 changes: 4 additions & 0 deletions paddle/cinn/runtime/flags.cc
@@ -69,6 +69,10 @@ PD_DEFINE_bool(cinn_bucket_compile,
BoolFromEnv("FLAGS_cinn_bucket_compile", false),
"Whether to enable bucket compile for dynamic shape.");

PD_DEFINE_bool(group_schedule_tiling_first,
BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
"Whether to enable new group scheduler tiling first strategy.");

PD_DEFINE_bool(cinn_use_common_subexpression_elimination,
BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination",
false),
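Because its default comes from BoolFromEnv, the new strategy can be toggled per run without recompiling. A usage sketch (treating the flag as a gflags-style mutable global is an assumption based on the PD_DECLARE_bool usage in ast_gen.cc):

// Enable tiling-first scheduling via the environment before the process
// starts:
//   export FLAGS_group_schedule_tiling_first=true

// Or flip the global directly in a translation unit that declares the flag,
// e.g. in a test fixture:
PD_DECLARE_bool(group_schedule_tiling_first);

void EnableTilingFirstForTest() {
  FLAGS_group_schedule_tiling_first = true;  // read by AstGen::Build
}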