diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
index d56fc994fdcea3..9f7a52d97fb178 100644
--- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
+++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -15,31 +15,37 @@
 #include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
+#include <climits>
 #include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
+#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
+#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
 #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
+#include "paddle/cinn/ir/op/ir_operators.h"
 
 namespace cinn {
 namespace ir {
 
 void DynamicShapeGroupScheduler::Init() {
+  schedule_context_.target = target_;
   schedule_context_.output_names = OutputTensorNames();
   schedule_context_.global_master = FindGlobalMasterNode();
   schedule_context_.iter_space_info =
       ConstructIterSpaceInfo(schedule_context_.global_master);
-  schedule_context_.target = target_;
+  // Only 1 hard-coded bucket for test now (INT_MAX comes from <climits>).
+  schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024,
+                                   /* sp_upper_bound = */ INT_MAX,
+                                   /* rb_lower_bound = */ 64,
+                                   /* rb_upper_bound = */ INT_MAX};
   tactics_.emplace_back(new AlignIterSpaceTactic());
+  tactics_.emplace_back(new TileTactic());
   tactics_.emplace_back(new ComputeInlineTactic());
+  tactics_.emplace_back(new BindCudaTactic());
   tactics_.emplace_back(new ArrangeStorageTactic());
 }
 
 void DynamicShapeGroupScheduler::Schedule() {
-  // Fake schedule for test
   ApplyTactics();
-  std::vector all_blocks = ir_sch_->GetAllBlocks();
-  auto block0_loops = ir_sch_->GetLoops(all_blocks[0]);
-  auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});
-  ir_sch_->Bind(splited_loops1[0], "threadIdx.x");
-
+  // Fake bucket for test
   ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
   std::unique_ptr new_ir_sch1 =
       std::make_unique(*ir_sch_);
@@ -55,12 +61,12 @@ void DynamicShapeGroupScheduler::ApplyTactics() {
       VLOG(6) << "before applying [" << tactic->TacticName()
               << "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
               << ir_sch_->GetModule().GetExprs().front();
-      tactic->Init(&schedule_context_);
       tactic->Apply(ir_sch_, node->id());
       VLOG(6) << "after applying [" << tactic->TacticName()
               << "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
               << ir_sch_->GetModule().GetExprs().front();
     };
+    tactic->Init(&schedule_context_);
     schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
     schedule_block_graph_->Update(*ir_sch_);
     VLOG(5) << "[End " << tactic->TacticName()
@@ -96,6 +102,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
   std::unordered_map iter_var2value =
       analyzer::GetIterVarToValueOfSBlock(block);
 
+  // init iter info
   if (!reduce_iter_vars.empty()) {
     std::set reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
         block,
@@ -161,6 +168,20 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
       info.rb_last_order.push_back(i);
     }
   }
+  // init total extents: the product of the axis extents of each space
+  ir::Expr sp_extent = ir::Expr(1);
+  ir::Expr rb_extent = ir::Expr(1);
+  for (const auto& axis : info.sp_space) {
+    const ir::Expr& extent = std::get<0>(axis);
+    sp_extent = sp_extent * extent;
+  }
+  for (const auto& axis : info.rb_space) {
+    const ir::Expr& extent = std::get<0>(axis);
+    rb_extent = rb_extent * extent;
+  }
+  info.total_sp_extent = sp_extent;
+  info.total_rb_extent = rb_extent;
+
   return info;
 }
 
diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
index da964e770ae9ba..b12e669b8c2d07 100644
--- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
+++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
@@ -1,5 +1,7 @@
 core_gather_headers()
 
 gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
+gather_srcs(cinnapi_src SRCS tile_tactic.cc)
 gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
+gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
 gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
index 69729ce2bfb8c6..ef30f80ce470b2 100644
--- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
+++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
@@ -15,7 +15,6 @@
 #pragma once
 
 #include
-#include
 #include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
 
 namespace cinn {
diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc
new file mode 100644
index 00000000000000..0da0ce3bcb396c
--- /dev/null
+++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
+#include <string>
+#include <unordered_map>
+#include "paddle/cinn/ir/ir.h"
+
+namespace cinn {
+namespace ir {
+
+namespace {
+
+// Maps a CUDA axis type to the index name passed to IRSchedule::Bind.
+// Axis types that are missing from this table are left unbound by Apply().
+const std::unordered_map<IterativeSpaceInfo::AxisType, std::string>
+    axis_type2bind_info = {
+        {IterativeSpaceInfo::AxisType::kCudaBlockX, "blockIdx.x"},
+        {IterativeSpaceInfo::AxisType::kCudaBlockY, "blockIdx.y"},
+        {IterativeSpaceInfo::AxisType::kCudaBlockZ, "blockIdx.z"},
+        {IterativeSpaceInfo::AxisType::kCudaThreadX, "threadIdx.x"},
+        {IterativeSpaceInfo::AxisType::kCudaThreadY, "threadIdx.y"},
+        {IterativeSpaceInfo::AxisType::kCudaThreadZ, "threadIdx.z"},
+};
+
+}  // namespace
+
+// Stores the schedule context; Apply() reads the axis layout from it.
+void BindCudaTactic::Init(ScheduleContext* context) { context_ = context; }
+
+// Walks the loops of `block_id` in order, pairing them first with the sp axes
+// and then with the rb axes of context_->iter_space_info, and binds every loop
+// whose axis type has an entry in axis_type2bind_info. Pairing stops as soon
+// as either the loops or the axes run out.
+void BindCudaTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) {
+  auto loops = sch->GetLoops(block_id);
+  size_t loop_idx = 0;
+  // `loop_idx` is shared by both passes so that the rb axes continue with the
+  // loop that follows the last sp axis.
+  const auto bind_space = [&](const auto& space) {
+    for (size_t i = 0; i < space.size() && loop_idx < loops.size();
+         ++i, ++loop_idx) {
+      const auto it = axis_type2bind_info.find(std::get<1>(space[i]));
+      if (it != axis_type2bind_info.end()) {
+        sch->Bind(loops[loop_idx], it->second);
+      }
+    }
+  };
+  bind_space(context_->iter_space_info.sp_space);
+  bind_space(context_->iter_space_info.rb_space);
+}
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h
new file mode 100644
index 00000000000000..b66c7d1fb802c0
--- /dev/null
+++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
+
+namespace cinn {
+namespace ir {
+
+class BindCudaTactic final : public ScheduleTactic {
+ public:
+  void Init(ScheduleContext* context) override;
+
+  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
+
+  std::string TacticName() const override { return "BindCudaTactic"; }
+
+ private:
+  ScheduleContext* context_{nullptr};  // assigned by Init()
+};
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h
index 4084c69bf493ae..05c258b82c47ce 100644
--- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h
+++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h
@@ -36,6 +36,10 @@ struct IterativeSpaceInfo {
   std::vector> sp_space;
   // reduce or broadcast iterative space
   std::vector> rb_space;
+  // total sp extent (product of the extents of all sp axes)
+  ir::Expr total_sp_extent;
+  // total rb extent (product of the extents of all rb axes)
+  ir::Expr total_rb_extent;
   // original loop order with same iteration order as the memory order
   std::vector memory_consistent_order_space;
   // index that transform from memory consistent order to rb last order
@@ -45,11 +49,22 @@ struct IterativeSpaceInfo {
   std::vector rb_last_order;
 };
 
+// Bounds of one bucket; presumably matched against the total sp / rb extents
+// (TODO confirm). The upper bounds default to INT_MAX rather than UINT_MAX
+// because UINT_MAX does not fit in an int.
+struct BucketInfo {
+  int sp_lower_bound = 0;
+  int sp_upper_bound = INT_MAX;
+  int rb_lower_bound = 0;
+  int rb_upper_bound = INT_MAX;
+};
+
 struct ScheduleContext {
std::unordered_set output_names; ScheduleBlockNode* global_master; IterativeSpaceInfo iter_space_info; Target target; + BucketInfo bucket_info; }; class ScheduleTactic { diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc new file mode 100644 index 00000000000000..3cace2636f2d39 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace ir { + +void TileTactic::Init(ScheduleContext* context) { + context_ = context; + // fake strategy + auto GetFirstFactor = [](int num) { + int factor = 1; + for (int i = num - 1; i >= 1; --i) { + if (num % i == 0) { + return i; + } + } + }; + + bool has_rb_iter = !context_->iter_space_info.rb_space.empty(); + bool has_sp_iter = !context_->iter_space_info.sp_space.empty(); + context_->iter_space_info.rb_space.clear(); + context_->iter_space_info.sp_space.clear(); + + if (has_sp_iter) { + int sp_factor = GetFirstFactor(context_->bucket_info.sp_lower_bound); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor), + IterativeSpaceInfo::AxisType::kCudaBlockX); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(sp_factor), + has_rb_iter ? 
IterativeSpaceInfo::AxisType::kCudaThreadY + : IterativeSpaceInfo::AxisType::kCudaThreadX); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + } + + if (has_rb_iter) { + context_->iter_space_info.rb_space.emplace_back( + ir::Expr(context_->bucket_info.rb_lower_bound), + IterativeSpaceInfo::AxisType::kCudaThreadX); + context_->iter_space_info.rb_space.emplace_back( + ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + } +} + +void TileTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { + std::vector loops = sch->GetLoops(block_id); + CHECK(loops.size() == 1 || loops.size() == 2) + << "All loops must be unified as sp_loop or rb_loop."; + if (loops.size() == 2) { + std::vector rb_factors; + for (const auto& axis : context_->iter_space_info.rb_space) { + rb_factors.push_back(std::get<0>(axis)); + } + sch->Split(loops[1], rb_factors); + loops = sch->GetLoops(block_id); + VLOG(6) << "after split rb loop of " << block_id << ": " + << sch->GetModule().GetExprs()[0]; + } + std::vector sp_factors; + for (const auto& axis : context_->iter_space_info.sp_space) { + sp_factors.push_back(std::get<0>(axis)); + } + sch->Split(loops[0], sp_factors); + VLOG(6) << "after split sp loop of " << block_id << ": " + << sch->GetModule().GetExprs()[0]; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h new file mode 100644 index 00000000000000..8a6d2bb8dd7668 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
+
+namespace cinn {
+namespace ir {
+
+class TileTactic final : public ScheduleTactic {
+ public:
+  void Init(ScheduleContext* context) override;
+
+  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
+
+  std::string TacticName() const override { return "TileTactic"; }
+
+ private:
+  ScheduleContext* context_{nullptr};  // assigned by Init()
+};
+
+}  // namespace ir
+}  // namespace cinn