diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index 262922a7ef7b97..6bd7513da39d71 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -18,7 +18,6 @@ #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h" -#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt index b6a2f067606468..f92d2caa966c2d 100644 --- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt @@ -6,5 +6,4 @@ gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc) gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc) gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc) gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) -gather_srcs(cinnapi_src SRCS loop_reorder_alignment_tactic.cc) gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc) diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc deleted file mode 100644 index 8bf8a98cce2514..00000000000000 --- a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" -#include -#include -#include "paddle/cinn/ir/ir.h" - -namespace cinn { -namespace ir { - -class LoopReorderAlignmentTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { - return "LoopReorderAlignmentTactic"; - } - - private: - bool NeedReorderLoops(); - - std::vector GetNewOrder(); - - void UpdateBaseRank(ir::IRSchedule* sch, const std::string& block_id); - - void DoReorder(ir::IRSchedule* sch, const std::string& block_id); - - private: - ScheduleContext* context_; - size_t base_rank_; - bool need_reorder_loops_; - std::vector new_order_; -}; - -void LoopReorderAlignmentTactic::Init(ScheduleContext* context) { - context_ = context; - base_rank_ = 0; - need_reorder_loops_ = NeedReorderLoops(); - new_order_ = GetNewOrder(); -} - -void LoopReorderAlignmentTactic::Apply(ir::IRSchedule* sch, - const std::string& block_id) { - if (!ir::IsReduceInitTensorName(block_id)) { - UpdateBaseRank(sch, block_id); - } - - if (need_reorder_loops_ && !ir::IsReduceInitTensorName(block_id)) { - DoReorder(sch, block_id); - } -} - -void LoopReorderAlignmentTactic::UpdateBaseRank(ir::IRSchedule* sch, - const std::string& block_id) { - auto loops = sch->GetLoops(block_id); - if (base_rank_ == 0) { - base_rank_ = loops.size(); - } else { - if (base_rank_ != loops.size()) { - throw std::runtime_error("loops rank not same "); - } - } -} - -bool LoopReorderAlignmentTactic::NeedReorderLoops() { - const auto HasReduceAxis = [&]() { - return context_->config.base_info->reduce_axis.size() > 0; - }; - if (!HasReduceAxis()) { - return false; - } - - const auto HasNonLastDimReduce = [&]() { - std::vector vec_reduce_axis = - context_->config.base_info->reduce_axis; - std::sort(vec_reduce_axis.begin(), vec_reduce_axis.end()); - return vec_reduce_axis.front() != - context_->config.base_info->data_rank - vec_reduce_axis.size(); - }; - - return HasNonLastDimReduce(); -} - -std::vector LoopReorderAlignmentTactic::GetNewOrder() { - std::set reduce_set(context_->config.base_info->reduce_axis.begin(), - context_->config.base_info->reduce_axis.end()); - - std::vector new_order; - for (int32_t i = 0; i < context_->config.base_info->data_rank; ++i) { - if (!reduce_set.count(i)) { - new_order.push_back(i); - } - } - for (auto axis : context_->config.base_info->reduce_axis) { - new_order.push_back(axis); - } - - return new_order; -} - -void LoopReorderAlignmentTactic::DoReorder(ir::IRSchedule* sch, - const std::string& block_id) { - const auto IsReduceBlock = [&](const std::string& block_id) { - return context_->config.base_info->reduce_tensor_names.count(block_id) > 0; - }; - if (IsReduceBlock(block_id)) { - return; - } - - sch->Reorder(block_id, new_order_); -} - -std::unique_ptr CreateLoopReorderAlignmentTactic() { - return std::make_unique(); -} - -} // namespace ir -} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h deleted file mode 100644 index ee4864a5ecf926..00000000000000 --- a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2024 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" - -namespace cinn { -namespace ir { - -std::unique_ptr CreateLoopReorderAlignmentTactic(); - -} // namespace ir -} // namespace cinn diff --git a/paddle/fluid/pir/transforms/sub_graph_detector.cc b/paddle/fluid/pir/transforms/sub_graph_detector.cc index ebc0f1e9f9d115..1491c235bf03e8 100644 --- a/paddle/fluid/pir/transforms/sub_graph_detector.cc +++ b/paddle/fluid/pir/transforms/sub_graph_detector.cc @@ -26,6 +26,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" @@ -484,34 +485,6 @@ std::vector AnalysisOutputs( return outputs; } -std::vector AnalysisExternalInputs(Operation* op) { // NOLINT - if (!op->isa()) { - return op->operands_source(); - } - // Get all ops in group - const auto all_ops = [&]() -> decltype(auto) { - const auto all_ops = op->dyn_cast().GetOperators(); - return std::unordered_set(all_ops.begin(), all_ops.end()); - }(); - std::unordered_set value_set; - const auto& IsOutsideInput = [&](const pir::Value& value) -> bool { - const bool is_outside = - value && value.defining_op() && !all_ops.count(value.defining_op()); - const bool has_visited = value_set.count(value); - if (!has_visited) value_set.insert(value); - return is_outside && !has_visited; - }; - - std::vector<::pir::Value> inputs; - // count all op's input Value - for (auto inner_op : all_ops) { - for (auto& value : inner_op->operands_source()) { - if (IsOutsideInput(value)) inputs.push_back(value); - } - } - return inputs; -} - namespace { pir::Operation* FindInsertPoint(const GroupOpsVec& group_ops, @@ -576,7 +549,7 @@ std::unordered_set GetUpstreamOpsAfterPosition( } return false; }; - std::vector op_inputs = AnalysisExternalInputs(op); + std::vector op_inputs = pir::GetUsedExternalValue(*op); for (auto value : op_inputs) { if (!value || !value.defining_op()) continue; pir::Operation* defining_op = value.defining_op();