Skip to content

Commit bd29981

Browse files
authored
add split with variable in factors and rewrite vectorize,unroll,bind error handling mechanism (#60449)
1 parent 290bf41 commit bd29981

File tree

5 files changed

+102
-10
lines changed

5 files changed

+102
-10
lines changed

paddle/cinn/ir/schedule/impl/for_type.cc

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,37 @@ void DyScheduleImpl::Parallel(const Expr& loop) {
6363
}
6464

6565
void DyScheduleImpl::Vectorize(const Expr& loop, int factor) {
66+
CINN_IR_SCHEDULE_BEGIN();
67+
std::string primitive = "Vectorize";
68+
std::ostringstream os;
6669
CHECK_GT(factor, 0) << "vectorize factor should be more than 0";
67-
CHECK(loop.As<For>()->extent.is_constant())
68-
<< "The loop to be vectorized should be constant!\n";
70+
if (factor <= 0) {
71+
os << "vectorize factor should be more than 0\n";
72+
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
73+
}
74+
if (!loop.As<For>()->extent.is_constant()) {
75+
os << "The loop to be vectorized should be constant!\n";
76+
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
77+
}
6978
MutateForType(loop, ForType::Vectorized, factor);
79+
CINN_IR_SCHEDULE_END(this->err_msg_level_);
7080
}
7181

7282
void DyScheduleImpl::Unroll(const Expr& loop) {
73-
CHECK(loop.As<For>()->extent.is_constant())
74-
<< "The loop to be unrolled should be constant!\n";
83+
CINN_IR_SCHEDULE_BEGIN();
84+
std::string primitive = "Unroll";
85+
std::ostringstream os;
86+
if (!loop.As<For>()->extent.is_constant()) {
87+
os << "The loop to be unrolled should be constant!\n";
88+
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
89+
}
7590
MutateForType(loop, ForType::Unrolled);
91+
CINN_IR_SCHEDULE_END(this->err_msg_level_);
7692
}
7793

7894
void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
7995
#ifdef CINN_WITH_CUDA
96+
CINN_IR_SCHEDULE_BEGIN();
8097
std::string primitive = "Bind";
8198
std::ostringstream os;
8299

@@ -117,6 +134,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
117134
}
118135
MutateForType(loop, ForType::GPUThread, offset);
119136
}
137+
CINN_IR_SCHEDULE_END(this->err_msg_level_);
120138
#endif
121139
}
122140
} // namespace ir

paddle/cinn/ir/schedule/impl/ir_schedule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class DyScheduleImpl : public ScheduleBase {
4949
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
5050
Expr GetBlock(const std::string& block_name) const;
5151
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
52+
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
5253
std::vector<Expr> SamplePerfectTile(
5354
utils::LinearRandomEngine::StateType* rand_seed,
5455
const Expr& loop,
@@ -122,6 +123,7 @@ class StScheduleImpl : public ScheduleBase {
122123
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
123124
Expr GetBlock(const std::string& block_name) const;
124125
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
126+
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
125127
std::vector<Expr> SamplePerfectTile(
126128
utils::LinearRandomEngine::StateType* rand_seed,
127129
const Expr& loop,

paddle/cinn/ir/schedule/impl/loop_transformation.cc

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/cinn/common/macros.h"
1615
#include "paddle/cinn/ir/schedule/impl/ir_schedule.h"
1716

17+
#include "paddle/cinn/common/integer_set.h"
18+
#include "paddle/cinn/common/macros.h"
19+
1820
/** \brief A macro that guards the beginning of each implementation of schedule
1921
*/
2022
#define CINN_IR_SCHEDULE_BEGIN() try {
@@ -157,6 +159,63 @@ std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
157159
return splited_loops;
158160
}
159161

162+
// TODO(@LiuYang): now -1 can't exsit in factors,
163+
std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
164+
const std::vector<Expr>& factors) {
165+
CHECK(loop.As<ir::For>())
166+
<< "Expr param of Split must be For node! Please check.";
167+
auto* for_node = loop.As<ir::For>();
168+
CHECK(common::is_zero(for_node->min))
169+
<< "The For node must start with 0! Please check.";
170+
CHECK(!factors.empty())
171+
<< "The factors param of Split should not be empty! Please check.";
172+
CHECK(!loop.As<ir::For>()->extent.is_constant())
173+
<< "Can't Split a loop with constant extent but with variable in "
174+
"factors!";
175+
Expr tot_extent = for_node->extent;
176+
177+
VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, "
178+
<< tot_extent << ") to (" << cinn::utils::Join(factors, ", ")
179+
<< ") at loop:\n"
180+
<< loop;
181+
182+
std::vector<Expr> process_factors(factors);
183+
Expr prod_size(1);
184+
for (auto factor : factors) prod_size = prod_size * Expr(factor);
185+
cinn::common::SymbolicExprAnalyzer analyzer({});
186+
CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false))
187+
<< "Product of factors can't be proved to be equal to the extent of "
188+
"current for loop!";
189+
190+
std::vector<Var> new_loop_vars;
191+
Expr substitute_value(0);
192+
for (int i = 0; i < process_factors.size(); ++i) {
193+
Var temp_var(common::UniqName(for_node->loop_var->name));
194+
substitute_value = Expr(temp_var) + substitute_value * process_factors[i];
195+
new_loop_vars.push_back(temp_var);
196+
}
197+
substitute_value = cinn::common::AutoSimplify(substitute_value);
198+
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
199+
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
200+
std::vector<Expr> splited_loops;
201+
splited_loops.resize(process_factors.size());
202+
203+
for (int i = process_factors.size() - 1; i >= 0; i--) {
204+
if (!new_node.As<ir::Block>()) new_node = Block::Make({new_node});
205+
new_node = For::Make(new_loop_vars[i],
206+
Expr(0),
207+
process_factors[i],
208+
for_node->for_type(),
209+
for_node->device_api,
210+
new_node);
211+
splited_loops[i] = new_node;
212+
}
213+
214+
this->Replace(loop, new_node);
215+
VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0);
216+
return splited_loops;
217+
}
218+
160219
Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) {
161220
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
162221
std::vector<const ir::For*> for_nodes;
@@ -370,6 +429,12 @@ std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
370429
return splited_loops;
371430
}
372431

432+
std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
433+
const std::vector<Expr>& factors) {
434+
CHECK(false) << "Static shape schedule don't support Split with some "
435+
"variables in factors";
436+
}
437+
373438
Expr StScheduleImpl::Fuse(const std::vector<Expr>& loops) {
374439
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
375440
std::vector<const ir::For*> for_nodes;

paddle/cinn/ir/schedule/ir_schedule.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,11 +405,16 @@ std::vector<Expr> IRSchedule::Split(const std::string& block_name,
405405
std::vector<Expr> IRSchedule::Split(const Expr& loop,
406406
const std::vector<Expr>& factors) {
407407
std::vector<int> int_factors;
408-
std::transform(factors.begin(),
409-
factors.end(),
410-
std::back_inserter(int_factors),
411-
[](Expr x) { return x.as_int32(); });
412-
auto results = impl_->Split(loop, int_factors);
408+
std::vector<Expr> results;
409+
std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) {
410+
if (e.is_constant()) int_factors.push_back(e.as_int32());
411+
});
412+
if (int_factors.size() == factors.size()) {
413+
results = impl_->Split(loop, int_factors);
414+
} else {
415+
results = impl_->Split(loop, factors);
416+
}
417+
413418
trace_.Append(ScheduleDesc::Step(
414419
"Split",
415420
{{"loop", std::vector<Expr>({loop})}, {"factors", factors}},

paddle/cinn/ir/schedule/schedule_base.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ class ScheduleBase {
9797
virtual Expr GetBlock(const std::string& block_name) const = 0;
9898
virtual std::vector<Expr> Split(const Expr& loop,
9999
const std::vector<int>& factors) = 0;
100+
virtual std::vector<Expr> Split(const Expr& loop,
101+
const std::vector<Expr>& factors) = 0;
100102
virtual std::vector<Expr> SamplePerfectTile(
101103
utils::LinearRandomEngine::StateType* rand_seed,
102104
const Expr& loop,

0 commit comments

Comments
 (0)