Commit ce50c3d

[CINN+PIR]Clean Old GroupScheduler logic and switch into new_group_scheduler (#60642)
1 parent 3324c9d commit ce50c3d
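
Net effect of this change: DoGroupSchedule no longer branches on FLAGS_cinn_new_group_scheduler and the legacy hand-written fusion/vectorize schedule is deleted; the function now always delegates to the static-shape GroupScheduler. A condensed sketch of the resulting body, assembled from the added lines in the diff below (the leading parameters, including ir_sch, sit above the hunk and are elided here; the comments are editorial, not from the source):

ir::Expr OpLowererImpl::DoGroupSchedule(
    /* leading parameters, including `ir_sch`, are declared above the hunk */
    const GroupPtr& group,
    const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
    const std::unordered_map<std::string, ir::Tensor>& tmp_tensor_info) {
  VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";

  // Collect the names of the tensors produced by the group's output ops;
  // the scheduler needs to know which tensors are group outputs.
  std::unordered_set<std::string> output_tensor_names;
  std::transform(
      group->output_ops.begin(),
      group->output_ops.end(),
      std::inserter(output_tensor_names, output_tensor_names.begin()),
      [&](::pir::Operation* op) { return ValueName(op->result(0)); });

  // is_dy_shape = false requests the static-shape group scheduler.
  std::unique_ptr<ir::GroupScheduler> group_scheduler =
      ir::GroupScheduler::Make(
          &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
  group_scheduler->Schedule();

  // Return the first (now scheduled) expression of the module.
  return ir_sch.GetModule().GetExprs().at(0);
}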

File tree

1 file changed: +11 -202 lines changed


paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 11 additions & 202 deletions
@@ -699,208 +699,17 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
     const GroupPtr& group,
     const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
     const std::unordered_map<std::string, ir::Tensor>& tmp_tensor_info) {
-  if (FLAGS_cinn_new_group_scheduler) {
-    VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
-    std::unordered_set<std::string> output_tensor_names;
-    std::transform(
-        group->output_ops.begin(),
-        group->output_ops.end(),
-        std::inserter(output_tensor_names, output_tensor_names.begin()),
-        [&](::pir::Operation* op) { return ValueName(op->result(0)); });
-    std::unique_ptr<ir::GroupScheduler> group_scheduler =
-        ir::GroupScheduler::Make(
-            &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
-    group_scheduler->Schedule();
-    return ir_sch.GetModule().GetExprs().at(0);
-  }
-  // topological order.
-  auto ops_set = group->OpSet();
-  auto v_consumers = BuildVirtualConsumer(group);
-  auto ops_in_order = BFSTopologicalOrderWithPriority(group, v_consumers);
-  // find reducer.
-  std::unordered_set<::pir::Operation*> ops_inline;
-  auto greducer = FindGlobalReducer(ops_in_order);
-
-  // do schedule
-  for (auto op : ops_in_order) {
-    VLOG(4) << "Try FUSION " << op->name();
-    std::string op_name = CompatibleInfo::OpName(*op);
-    auto op_kind = CompatibleInfo::OpKind(*op);
-    // consumers.
-    auto consumers = GetConsumersInSet(op, ops_set);
-    auto* reducer = greducer ? FindNearestReducer(op, ops_set) : greducer;
-    if (!reducer && greducer) {
-      reducer = v_consumers.count(op) ? v_consumers.find(op)->second : reducer;
-      if (reducer &&
-          CompatibleInfo::OpKind(*reducer) != framework::kReduction) {
-        reducer = nullptr;
-      }
-    }
-
-    auto masters = GetMasters(op, name_gene_, ops_inline, ops_set);
-    // TODO(Aurelius84): support inline later.
-    if (CanbeInline(
-            op, reducer, name_gene_, consumers, masters, group, ops_set) &&
-        false) {
-      VLOG(3) << "Before compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto block = ir_sch.GetBlock(ValueName(op->result(0)));
-      ir::ComputeInlineChecker checker(ir_sch, block);
-      if (!checker.Check()) {
-        checker.BuildDataDependency();
-        continue;
-      }
-
-      // if exist global reduce node.
-      if (greducer) {
-        auto loops = ir_sch.GetLoops(ValueName(op->result(0)));
-        if (op_kind == framework::kElementWise) {
-          ir_sch.FlattenLoops(loops, true);
-        } else {
-          ir_sch.FlattenLoops(loops, false);
-        }
-      }
-
-      ir_sch.ComputeInline(block);
-      ops_inline.insert(op);
-      VLOG(3) << "After compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      continue;
-    }
-    // find master to computeat.
-    auto master = GetMasterToComputeAt(
-        op, name_gene_, ops_in_order, ops_inline, ops_set, v_consumers);
-    std::string op_out_name = ValueName(op->result(0));
-    // assign to reducer/master loop.
-    if (reducer) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into vertical link reducer "
-              << CompatibleInfo::OpName(*reducer) << ", ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is vertical with reduce, loop assign reducer.
-      LoopAssignReduce(ir_sch,
-                       op,
-                       reducer,
-                       name_gene_,
-                       this->target_,
-                       tensor_map,
-                       tmp_tensor_info);
-    } else if (greducer) {
-      auto greducer_out_shape = CompatibleInfo::ValueShape(greducer->result(0));
-      auto op_out_shape = CompatibleInfo::ValueShape(op->result(0));
-      if (CompatibleInfo::ShapeProduct(greducer_out_shape) !=
-          CompatibleInfo::ShapeProduct(op_out_shape)) {
-        LoopAssignReduce(ir_sch,
-                         op,
-                         greducer,
-                         name_gene_,
-                         this->target_,
-                         tensor_map,
-                         tmp_tensor_info);
-      }
-    } else if (master) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into horizontal link reducer, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is horizontal with reduce or node is reduce, loop assign
-      // master.
-      auto loops = ir_sch.GetLoops(op_out_name);
-      ir_sch.Fuse(loops);
-
-      if (master && op_kind != framework::kReduction) {
-        auto master_loops = ir_sch.GetLoops(ValueName(master->result(0)));
-        std::vector<int> splits;
-        for (auto loop : master_loops) {
-          splits.push_back(loop.As<ir::For>()->extent.as_int32());
-        }
-        loops = ir_sch.GetLoops(op_out_name);
-        ir_sch.Split(loops[0], splits);
-      }
-    }
-    VLOG(3) << "Before loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-    // do loop fuse.
-    LoopComputeAt(ir_sch,
-                  op,
-                  master ? master : ops_in_order.front(),
-                  name_gene_,
-                  group,
-                  tensor_map,
-                  tmp_tensor_info);
-    VLOG(3) << "After loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-  }
-
-  // do vectorize
-  auto all_blocks = ir_sch.GetAllBlocks();
-  VLOG(4) << "Size of blocks: " << all_blocks.size();
-  VLOG(4) << "Op Pattern : " << group->op_pattern_kind;
-
-  // only support first block?
-  auto block = all_blocks[0];
-
-  if (block->as<ir::ScheduleBlockRealize>() == nullptr ||
-      block->as<ir::ScheduleBlockRealize>()
-              ->schedule_block->as<ir::ScheduleBlock>() == nullptr) {
-    std::string err_msg =
-        "Group scheduling, the Expr is not wrapped by ScheduleBlockRealize or "
-        "ScheduleBlock, cannot be scheduled.";
-    std::ostringstream detail_info;
-    detail_info << "Expr:\n";
-    detail_info << block;
-    throw CompileErrorHandler(CompilationStatus::LOWERING_FAIL,
-                              err_msg,
-                              detail_info.str(),
-                              __FILE__,
-                              __LINE__);
-  }
-  auto is_tensor_block = true;
-  auto tensor_name = block->as<ir::ScheduleBlockRealize>()
-                         ->schedule_block->as<ir::ScheduleBlock>()
-                         ->name;
-  if (!IsInTensorMap(tensor_name, tensor_map)) {
-    is_tensor_block = false;
-  }
-  if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block &&
-      (group->op_pattern_kind == framework::kElementWise ||
-       group->op_pattern_kind == framework::kInjective ||
-       group->op_pattern_kind == framework::kBroadcast)) {
-    // auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
-    auto loops = ir_sch.GetLoops(block);
-    VLOG(4) << "Op Pattern : " << loops.size();
-    if (loops.size() >= 1) {
-      VLOG(4) << "Before vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto loop_inner = loops.back();
-      int vector_width = 1;
-      auto psize = ir::GetLoopExtent(loop_inner);
-      auto dtype = GetTensorDtype(tensor_name, tensor_map);
-      VLOG(4) << tensor_name << " dtype " << dtype;
-      if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) {
-        vector_width = 8;
-      } else if (psize % 4 == 0) {
-        vector_width = 4;
-      } else if (psize % 2 == 0) {
-        vector_width = 2;
-      }
-      if (vector_width > 1) {
-        auto splited = ir_sch.Split(loop_inner, {-1, vector_width});
-        splited[0].As<ir::For>()->set_bind_info(
-            loop_inner.As<ir::For>()->bind_info());
-        splited[1].As<ir::For>()->set_serial();
-        ir_sch.Vectorize(splited[1], vector_width);
-      }
-      VLOG(4) << "After vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-    }
-  }
-
-  VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  SyncThreadWithShared(
-      ir_sch, group, name_gene_, ops_inline, ops_set, tensor_map);
-  VLOG(4) << "After IRSchedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
+  VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
+  std::unordered_set<std::string> output_tensor_names;
+  std::transform(
+      group->output_ops.begin(),
+      group->output_ops.end(),
+      std::inserter(output_tensor_names, output_tensor_names.begin()),
+      [&](::pir::Operation* op) { return ValueName(op->result(0)); });
+  std::unique_ptr<ir::GroupScheduler> group_scheduler =
+      ir::GroupScheduler::Make(
+          &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
+  group_scheduler->Schedule();
   return ir_sch.GetModule().GetExprs().at(0);
 }

0 commit comments
