paddle/cinn/hlir/framework/pir/op_lowering_impl.cc (213 changes: 11 additions & 202 deletions)
@@ -699,208 +699,17 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
     const GroupPtr& group,
     const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
     const std::unordered_map<std::string, ir::Tensor>& tmp_tensor_info) {
-  if (FLAGS_cinn_new_group_scheduler) {
-    VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
-    std::unordered_set<std::string> output_tensor_names;
-    std::transform(
-        group->output_ops.begin(),
-        group->output_ops.end(),
-        std::inserter(output_tensor_names, output_tensor_names.begin()),
-        [&](::pir::Operation* op) { return ValueName(op->result(0)); });
-    std::unique_ptr<ir::GroupScheduler> group_scheduler =
-        ir::GroupScheduler::Make(
-            &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
-    group_scheduler->Schedule();
-    return ir_sch.GetModule().GetExprs().at(0);
-  }
-  // topological order.
-  auto ops_set = group->OpSet();
-  auto v_consumers = BuildVirtualConsumer(group);
-  auto ops_in_order = BFSTopologicalOrderWithPriority(group, v_consumers);
-  // find reducer.
-  std::unordered_set<::pir::Operation*> ops_inline;
-  auto greducer = FindGlobalReducer(ops_in_order);
-
-  // do schedule
-  for (auto op : ops_in_order) {
-    VLOG(4) << "Try FUSION " << op->name();
-    std::string op_name = CompatibleInfo::OpName(*op);
-    auto op_kind = CompatibleInfo::OpKind(*op);
-    // consumers.
-    auto consumers = GetConsumersInSet(op, ops_set);
-    auto* reducer = greducer ? FindNearestReducer(op, ops_set) : greducer;
-    if (!reducer && greducer) {
-      reducer = v_consumers.count(op) ? v_consumers.find(op)->second : reducer;
-      if (reducer &&
-          CompatibleInfo::OpKind(*reducer) != framework::kReduction) {
-        reducer = nullptr;
-      }
-    }
-
-    auto masters = GetMasters(op, name_gene_, ops_inline, ops_set);
-    // TODO(Aurelius84): support inline later.
-    if (CanbeInline(
-            op, reducer, name_gene_, consumers, masters, group, ops_set) &&
-        false) {
-      VLOG(3) << "Before compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto block = ir_sch.GetBlock(ValueName(op->result(0)));
-      ir::ComputeInlineChecker checker(ir_sch, block);
-      if (!checker.Check()) {
-        checker.BuildDataDependency();
-        continue;
-      }
-
-      // if exist global reduce node.
-      if (greducer) {
-        auto loops = ir_sch.GetLoops(ValueName(op->result(0)));
-        if (op_kind == framework::kElementWise) {
-          ir_sch.FlattenLoops(loops, true);
-        } else {
-          ir_sch.FlattenLoops(loops, false);
-        }
-      }
-
-      ir_sch.ComputeInline(block);
-      ops_inline.insert(op);
-      VLOG(3) << "After compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      continue;
-    }
-    // find master to computeat.
-    auto master = GetMasterToComputeAt(
-        op, name_gene_, ops_in_order, ops_inline, ops_set, v_consumers);
-    std::string op_out_name = ValueName(op->result(0));
-    // assign to reducer/master loop.
-    if (reducer) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into vertical link reducer "
-              << CompatibleInfo::OpName(*reducer) << ", ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is vertical with reduce, loop assign reducer.
-      LoopAssignReduce(ir_sch,
-                       op,
-                       reducer,
-                       name_gene_,
-                       this->target_,
-                       tensor_map,
-                       tmp_tensor_info);
-    } else if (greducer) {
-      auto greducer_out_shape = CompatibleInfo::ValueShape(greducer->result(0));
-      auto op_out_shape = CompatibleInfo::ValueShape(op->result(0));
-      if (CompatibleInfo::ShapeProduct(greducer_out_shape) !=
-          CompatibleInfo::ShapeProduct(op_out_shape)) {
-        LoopAssignReduce(ir_sch,
-                         op,
-                         greducer,
-                         name_gene_,
-                         this->target_,
-                         tensor_map,
-                         tmp_tensor_info);
-      }
-    } else if (master) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into horizontal link reducer, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is horizontal with reduce or node is reduce, loop assign
-      // master.
-      auto loops = ir_sch.GetLoops(op_out_name);
-      ir_sch.Fuse(loops);
-
-      if (master && op_kind != framework::kReduction) {
-        auto master_loops = ir_sch.GetLoops(ValueName(master->result(0)));
-        std::vector<int> splits;
-        for (auto loop : master_loops) {
-          splits.push_back(loop.As<ir::For>()->extent.as_int32());
-        }
-        loops = ir_sch.GetLoops(op_out_name);
-        ir_sch.Split(loops[0], splits);
-      }
-    }
-    VLOG(3) << "Before loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-    // do loop fuse.
-    LoopComputeAt(ir_sch,
-                  op,
-                  master ? master : ops_in_order.front(),
-                  name_gene_,
-                  group,
-                  tensor_map,
-                  tmp_tensor_info);
-    VLOG(3) << "After loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-  }
-
-  // do vectorize
-  auto all_blocks = ir_sch.GetAllBlocks();
-  VLOG(4) << "Size of blocks: " << all_blocks.size();
-  VLOG(4) << "Op Pattern : " << group->op_pattern_kind;
-
-  // only support first block?
-  auto block = all_blocks[0];
-
-  if (block->as<ir::ScheduleBlockRealize>() == nullptr ||
-      block->as<ir::ScheduleBlockRealize>()
-              ->schedule_block->as<ir::ScheduleBlock>() == nullptr) {
-    std::string err_msg =
-        "Group scheduling, the Expr is not wrapped by ScheduleBlockRealize or "
-        "ScheduleBlock, cannot be scheduled.";
-    std::ostringstream detail_info;
-    detail_info << "Expr:\n";
-    detail_info << block;
-    throw CompileErrorHandler(CompilationStatus::LOWERING_FAIL,
-                              err_msg,
-                              detail_info.str(),
-                              __FILE__,
-                              __LINE__);
-  }
-  auto is_tensor_block = true;
-  auto tensor_name = block->as<ir::ScheduleBlockRealize>()
-                         ->schedule_block->as<ir::ScheduleBlock>()
-                         ->name;
-  if (!IsInTensorMap(tensor_name, tensor_map)) {
-    is_tensor_block = false;
-  }
-  if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block &&
-      (group->op_pattern_kind == framework::kElementWise ||
-       group->op_pattern_kind == framework::kInjective ||
-       group->op_pattern_kind == framework::kBroadcast)) {
-    // auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
-    auto loops = ir_sch.GetLoops(block);
-    VLOG(4) << "Op Pattern : " << loops.size();
-    if (loops.size() >= 1) {
-      VLOG(4) << "Before vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto loop_inner = loops.back();
-      int vector_width = 1;
-      auto psize = ir::GetLoopExtent(loop_inner);
-      auto dtype = GetTensorDtype(tensor_name, tensor_map);
-      VLOG(4) << tensor_name << " dtype " << dtype;
-      if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) {
-        vector_width = 8;
-      } else if (psize % 4 == 0) {
-        vector_width = 4;
-      } else if (psize % 2 == 0) {
-        vector_width = 2;
-      }
-      if (vector_width > 1) {
-        auto splited = ir_sch.Split(loop_inner, {-1, vector_width});
-        splited[0].As<ir::For>()->set_bind_info(
-            loop_inner.As<ir::For>()->bind_info());
-        splited[1].As<ir::For>()->set_serial();
-        ir_sch.Vectorize(splited[1], vector_width);
-      }
-      VLOG(4) << "After vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-    }
-  }
-
-  VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  SyncThreadWithShared(
-      ir_sch, group, name_gene_, ops_inline, ops_set, tensor_map);
-  VLOG(4) << "After IRSchedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
+  VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
+  std::unordered_set<std::string> output_tensor_names;
+  std::transform(
+      group->output_ops.begin(),
+      group->output_ops.end(),
+      std::inserter(output_tensor_names, output_tensor_names.begin()),
+      [&](::pir::Operation* op) { return ValueName(op->result(0)); });
+  std::unique_ptr<ir::GroupScheduler> group_scheduler =
+      ir::GroupScheduler::Make(
+          &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
+  group_scheduler->Schedule();
+  return ir_sch.GetModule().GetExprs().at(0);
 }

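Note on the removed legacy path: it first walked the group's ops in topological order via `BFSTopologicalOrderWithPriority`. The core of such a traversal is a Kahn-style BFS topological sort; a minimal standalone sketch (illustrative names, the priority tie-breaking omitted, not the CINN implementation):

```cpp
#include <queue>
#include <vector>

// Kahn-style BFS topological order over an op graph given as adjacency
// lists from each op to its consumers. Ops whose producers have all been
// visited become "ready" and are emitted in BFS order.
std::vector<int> BfsTopoOrder(const std::vector<std::vector<int>>& consumers) {
  const int n = static_cast<int>(consumers.size());
  std::vector<int> indegree(n, 0);
  for (const auto& outs : consumers)
    for (int v : outs) ++indegree[v];
  std::queue<int> ready;
  for (int v = 0; v < n; ++v)
    if (indegree[v] == 0) ready.push(v);
  std::vector<int> order;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop();
    order.push_back(u);
    for (int v : consumers[u])
      if (--indegree[v] == 0) ready.push(v);
  }
  return order;  // order.size() < n would indicate a cycle
}
```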
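Also for context: the deleted `CanbeInline` branch was dead code (note the `&& false` and the `TODO(Aurelius84)`), but the idea behind `ComputeInline` is worth recalling when reviewing its removal. A toy, non-CINN before/after showing a producer substituted into its consumer so the intermediate buffer is never materialized:

```cpp
#include <cstddef>
#include <vector>

// Before inlining: B is computed and stored, then read back by C.
std::vector<float> BeforeInline(const std::vector<float>& A) {
  std::vector<float> B(A.size()), C(A.size());
  for (std::size_t i = 0; i < A.size(); ++i) B[i] = A[i] + 1.0f;
  for (std::size_t i = 0; i < A.size(); ++i) C[i] = B[i] * 2.0f;
  return C;
}

// After inlining: B's expression is substituted into C's body, so the
// intermediate buffer (and a round trip through memory) disappears.
std::vector<float> AfterInline(const std::vector<float>& A) {
  std::vector<float> C(A.size());
  for (std::size_t i = 0; i < A.size(); ++i) C[i] = (A[i] + 1.0f) * 2.0f;
  return C;
}
```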
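The `master` branch fused an op's loops and then re-split them by the master's loop extents so `LoopComputeAt` could align the two nests. The index arithmetic behind that fuse-then-split round trip, as a self-contained check (the extents 8 and 128 are made up for illustration):

```cpp
#include <cassert>

int main() {
  // A fused loop of extent 1024, re-split to match hypothetical master
  // extents {8, 128}: the fused index i is recovered as o * 128 + in,
  // so the iteration space and order are unchanged by the split.
  int next = 0;
  for (int o = 0; o < 8; ++o) {
    for (int in = 0; in < 128; ++in) {
      int i = o * 128 + in;
      assert(i == next);
      ++next;
    }
  }
  assert(next == 1024);
  return 0;
}
```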
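Finally, the removed CUDA vectorize step chose a lane width from the innermost loop extent and the output dtype (8 lanes only for fp16/bf16, otherwise 4 or 2 when the extent divides evenly). The selection rule restated as a standalone helper (illustrative, not a CINN API):

```cpp
#include <cassert>

// Mirrors the deleted width selection: half-precision types may use
// 8 lanes; otherwise fall back to 4 or 2 when the extent divides evenly.
int ChooseVectorWidth(int extent, bool is_fp16_or_bf16) {
  if (extent % 8 == 0 && is_fp16_or_bf16) return 8;
  if (extent % 4 == 0) return 4;
  if (extent % 2 == 0) return 2;
  return 1;  // no even split: leave the loop scalar
}

int main() {
  assert(ChooseVectorWidth(1024, true) == 8);   // fp16, divisible by 8
  assert(ChooseVectorWidth(1024, false) == 4);  // fp32 never takes width 8
  assert(ChooseVectorWidth(6, false) == 2);     // 6 % 4 != 0, 6 % 2 == 0
  assert(ChooseVectorWidth(7, true) == 1);      // odd extent stays scalar
  return 0;
}
```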