@@ -699,208 +699,17 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
     const GroupPtr& group,
     const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
     const std::unordered_map<std::string, ir::Tensor>& tmp_tensor_info) {
-  if (FLAGS_cinn_new_group_scheduler) {
-    VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
-    std::unordered_set<std::string> output_tensor_names;
-    std::transform(
-        group->output_ops.begin(),
-        group->output_ops.end(),
-        std::inserter(output_tensor_names, output_tensor_names.begin()),
-        [&](::pir::Operation* op) { return ValueName(op->result(0)); });
-    std::unique_ptr<ir::GroupScheduler> group_scheduler =
-        ir::GroupScheduler::Make(
-            &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
-    group_scheduler->Schedule();
-    return ir_sch.GetModule().GetExprs().at(0);
-  }
-  // topological order.
-  auto ops_set = group->OpSet();
-  auto v_consumers = BuildVirtualConsumer(group);
-  auto ops_in_order = BFSTopologicalOrderWithPriority(group, v_consumers);
-  // find reducer.
-  std::unordered_set<::pir::Operation*> ops_inline;
-  auto greducer = FindGlobalReducer(ops_in_order);
-
-  // do schedule
-  for (auto op : ops_in_order) {
-    VLOG(4) << "Try FUSION " << op->name();
-    std::string op_name = CompatibleInfo::OpName(*op);
-    auto op_kind = CompatibleInfo::OpKind(*op);
-    // consumers.
-    auto consumers = GetConsumersInSet(op, ops_set);
-    auto* reducer = greducer ? FindNearestReducer(op, ops_set) : greducer;
-    if (!reducer && greducer) {
-      reducer = v_consumers.count(op) ? v_consumers.find(op)->second : reducer;
-      if (reducer &&
-          CompatibleInfo::OpKind(*reducer) != framework::kReduction) {
-        reducer = nullptr;
-      }
-    }
-
-    auto masters = GetMasters(op, name_gene_, ops_inline, ops_set);
-    // TODO(Aurelius84): support inline later.
-    if (CanbeInline(
-            op, reducer, name_gene_, consumers, masters, group, ops_set) &&
-        false) {
-      VLOG(3) << "Before compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto block = ir_sch.GetBlock(ValueName(op->result(0)));
-      ir::ComputeInlineChecker checker(ir_sch, block);
-      if (!checker.Check()) {
-        checker.BuildDataDependency();
-        continue;
-      }
-
-      // if exist global reduce node.
-      if (greducer) {
-        auto loops = ir_sch.GetLoops(ValueName(op->result(0)));
-        if (op_kind == framework::kElementWise) {
-          ir_sch.FlattenLoops(loops, true);
-        } else {
-          ir_sch.FlattenLoops(loops, false);
-        }
-      }
-
-      ir_sch.ComputeInline(block);
-      ops_inline.insert(op);
-      VLOG(3) << "After compute inline, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      continue;
-    }
-    // find master to computeat.
-    auto master = GetMasterToComputeAt(
-        op, name_gene_, ops_in_order, ops_inline, ops_set, v_consumers);
-    std::string op_out_name = ValueName(op->result(0));
-    // assign to reducer/master loop.
-    if (reducer) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into vertical link reducer "
-              << CompatibleInfo::OpName(*reducer) << ", ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is vertical with reduce, loop assign reducer.
-      LoopAssignReduce(ir_sch,
-                       op,
-                       reducer,
-                       name_gene_,
-                       this->target_,
-                       tensor_map,
-                       tmp_tensor_info);
-    } else if (greducer) {
-      auto greducer_out_shape = CompatibleInfo::ValueShape(greducer->result(0));
-      auto op_out_shape = CompatibleInfo::ValueShape(op->result(0));
-      if (CompatibleInfo::ShapeProduct(greducer_out_shape) !=
-          CompatibleInfo::ShapeProduct(op_out_shape)) {
-        LoopAssignReduce(ir_sch,
-                         op,
-                         greducer,
-                         name_gene_,
-                         this->target_,
-                         tensor_map,
-                         tmp_tensor_info);
-      }
-    } else if (master) {
-      VLOG(3) << "Before assign node " << op_name
-              << " into horizontal link reducer, ir is:\n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      // if node is horizontal with reduce or node is reduce, loop assign
-      // master.
-      auto loops = ir_sch.GetLoops(op_out_name);
-      ir_sch.Fuse(loops);
-
-      if (master && op_kind != framework::kReduction) {
-        auto master_loops = ir_sch.GetLoops(ValueName(master->result(0)));
-        std::vector<int> splits;
-        for (auto loop : master_loops) {
-          splits.push_back(loop.As<ir::For>()->extent.as_int32());
-        }
-        loops = ir_sch.GetLoops(op_out_name);
-        ir_sch.Split(loops[0], splits);
-      }
-    }
-    VLOG(3) << "Before loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-    // do loop fuse.
-    LoopComputeAt(ir_sch,
-                  op,
-                  master ? master : ops_in_order.front(),
-                  name_gene_,
-                  group,
-                  tensor_map,
-                  tmp_tensor_info);
-    VLOG(3) << "After loop fusion, ir is:\n"
-            << ir_sch.GetModule().GetExprs().at(0);
-  }
-
-  // do vectorize
-  auto all_blocks = ir_sch.GetAllBlocks();
-  VLOG(4) << "Size of blocks: " << all_blocks.size();
-  VLOG(4) << "Op Pattern : " << group->op_pattern_kind;
-
-  // only support first block?
-  auto block = all_blocks[0];
-
-  if (block->as<ir::ScheduleBlockRealize>() == nullptr ||
-      block->as<ir::ScheduleBlockRealize>()
-              ->schedule_block->as<ir::ScheduleBlock>() == nullptr) {
-    std::string err_msg =
-        "Group scheduling, the Expr is not wrapped by ScheduleBlockRealize or "
-        "ScheduleBlock, cannot be scheduled.";
-    std::ostringstream detail_info;
-    detail_info << "Expr:\n";
-    detail_info << block;
-    throw CompileErrorHandler(CompilationStatus::LOWERING_FAIL,
-                              err_msg,
-                              detail_info.str(),
-                              __FILE__,
-                              __LINE__);
-  }
-  auto is_tensor_block = true;
-  auto tensor_name = block->as<ir::ScheduleBlockRealize>()
-                         ->schedule_block->as<ir::ScheduleBlock>()
-                         ->name;
-  if (!IsInTensorMap(tensor_name, tensor_map)) {
-    is_tensor_block = false;
-  }
-  if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block &&
-      (group->op_pattern_kind == framework::kElementWise ||
-       group->op_pattern_kind == framework::kInjective ||
-       group->op_pattern_kind == framework::kBroadcast)) {
-    // auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
-    auto loops = ir_sch.GetLoops(block);
-    VLOG(4) << "Op Pattern : " << loops.size();
-    if (loops.size() >= 1) {
-      VLOG(4) << "Before vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-      auto loop_inner = loops.back();
-      int vector_width = 1;
-      auto psize = ir::GetLoopExtent(loop_inner);
-      auto dtype = GetTensorDtype(tensor_name, tensor_map);
-      VLOG(4) << tensor_name << " dtype " << dtype;
-      if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) {
-        vector_width = 8;
-      } else if (psize % 4 == 0) {
-        vector_width = 4;
-      } else if (psize % 2 == 0) {
-        vector_width = 2;
-      }
-      if (vector_width > 1) {
-        auto splited = ir_sch.Split(loop_inner, {-1, vector_width});
-        splited[0].As<ir::For>()->set_bind_info(
-            loop_inner.As<ir::For>()->bind_info());
-        splited[1].As<ir::For>()->set_serial();
-        ir_sch.Vectorize(splited[1], vector_width);
-      }
-      VLOG(4) << "After vectorize, ir is: \n"
-              << ir_sch.GetModule().GetExprs().at(0);
-    }
-  }
-
-  VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  SyncThreadWithShared(
-      ir_sch, group, name_gene_, ops_inline, ops_set, tensor_map);
-  VLOG(4) << "After IRSchedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
+  VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
+  std::unordered_set<std::string> output_tensor_names;
+  std::transform(
+      group->output_ops.begin(),
+      group->output_ops.end(),
+      std::inserter(output_tensor_names, output_tensor_names.begin()),
+      [&](::pir::Operation* op) { return ValueName(op->result(0)); });
+  std::unique_ptr<ir::GroupScheduler> group_scheduler =
+      ir::GroupScheduler::Make(
+          &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
+  group_scheduler->Schedule();
   return ir_sch.GetModule().GetExprs().at(0);
 }
 
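For reference, the surviving path collects output tensor names via std::transform into an unordered_set. Below is a minimal, self-contained sketch of that STL pattern; the Op struct and its name field are hypothetical stand-ins for ::pir::Operation and ValueName(op->result(0)), and only the std::transform/std::inserter usage mirrors the patch:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for ::pir::Operation; `name` plays the role of
// ValueName(op->result(0)) in the real code.
struct Op {
  std::string name;
};

int main() {
  std::vector<Op> output_ops = {{"var_0"}, {"var_1"}, {"var_0"}};
  std::unordered_set<std::string> output_tensor_names;
  // Copy each op's output name into the set; duplicates collapse, which is
  // why a set (rather than a vector) is handed to the group scheduler.
  std::transform(
      output_ops.begin(),
      output_ops.end(),
      std::inserter(output_tensor_names, output_tensor_names.begin()),
      [](const Op& op) { return op.name; });
  for (const auto& n : output_tensor_names) std::cout << n << "\n";  // 2 names
  return 0;
}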
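The deleted branch also carried the CUDA vectorization heuristic: choose the widest lane count that evenly divides the innermost loop extent, allowing width 8 only for 16-bit float types. A hedged standalone restatement follows; the function name and the is_fp16_or_bf16 flag are illustrative only, while the original read the extent via ir::GetLoopExtent and the dtype via GetTensorDtype:

// Sketch of the removed width-selection logic, not CINN's API.
int ChooseVectorWidth(int extent, bool is_fp16_or_bf16) {
  if (extent % 8 == 0 && is_fp16_or_bf16) return 8;  // fp16/bf16: 8 lanes
  if (extent % 4 == 0) return 4;
  if (extent % 2 == 0) return 2;
  return 1;  // extent indivisible: leave the loop scalar
}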