diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
index da8cb090b7a2d3..bfd9348e39aed4 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
+++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -268,7 +268,7 @@ void SplitOp::Build(pir::Builder& builder,  // NOLINT
                     pir::Value input,
                     const std::vector<int>& sections,
                     int axis) {
-  VLOG(4) << "Start build ConcatOp";
+  VLOG(4) << "Start build SplitOp";

   argument.inputs.push_back(input);
diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml
index 593fcca6b3348a..8028775504f2d3 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml
+++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -58,21 +58,25 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : reduce_prod
-  args : (Tensor x, int64_t[] dim, bool keep_dim)
+  args : (Tensor x, int64_t[] dim, bool keep_dim, bool reduce_all)
   output : Tensor(out)
   infer_meta :
     func : ReduceInferMeta
+    param : [x, dim, keep_dim]
   kernel :
     func : frobenius_norm
+    param : [x, dim, keep_dim]
   interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : reduce_sum
-  args : (Tensor x, int64_t[] dim, bool keep_dim)
+  args : (Tensor x, int64_t[] dim, bool keep_dim, DataType dtype=DataType::UNDEFINED)
   output : Tensor(out)
   infer_meta :
     func : ReduceInferMeta
+    param : [x, dim, keep_dim]
   kernel :
     func : frobenius_norm
+    param : [x, dim, keep_dim]
   interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : reshape
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc
index 94701ec6aad4bf..34dc952e1f71cf 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc
@@ -75,10 +75,7 @@ class AddAccuracyCheckPattern
       }
     }
     pir::Operation* pd_op =
-        cinn::dialect::details::RewriteCinnOpToPdOp(op, builder);
-    for (uint32_t i = 0; i < op->num_results(); ++i) {
-      ir_mapping.Add(op->result(i), pd_op->result(i));
-    }
+        cinn::dialect::details::RewriteCinnOpToPdOp(op, ir_mapping, builder);
   };

   const auto& ClonePdOp = [&](::pir::Operation* op) -> void {
@@ -106,7 +103,7 @@

 class AccuarcyCheckPass : public pir::Pass {
  public:
-  AccuarcyCheckPass() : pir::Pass("accuarcy_check_pass", /*opt_level=*/4) {}
+  AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {}

   bool Initialize(pir::IrContext* context) override {
     pir::RewritePatternSet ps(context);
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
index fe512cd76bfa50..f02604544957e5 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -45,6 +45,7 @@
 #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
+#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
 #include "paddle/fluid/pir/transforms/build_cinn_pass.h"
 #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
@@ -156,6 +157,10 @@ void ApplyDivideGroupOpToFusionOpPass(
     pass_manager->AddPass(
         cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
   }
+
+  pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
+  pass_manager->AddPass(cinn::dialect::ir::CreateShapeOpsFallbackToPhiPass());
+
   pass_manager->Run(program);
 }

@@ -176,7 +181,6 @@ void ApplyCinnLowerPass(
     pass_manager->AddPass(std::move(pass.value()));
   }

-  pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
   if (FLAGS_enable_cinn_accuracy_check) {
     VLOG(0) << "Enable CINN Accuracy Check Pass";
     pass_manager->AddPass(cinn::dialect::ir::CreateAccuarcyCheckPass());
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
index 387e69ff42d633..2275f737006eac 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
@@ -23,6 +23,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/pir/include/core/builtin_dialect.h"
+#include "paddle/pir/include/core/ir_mapping.h"

 namespace cinn::dialect::details {
 pir::Attribute ArrayAttributeToIntArrayAttribute(
@@ -36,37 +37,147 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
   return attr_data;
 }

+const auto& handler_reduce_sum_op =
+    [](::pir::Operation* op,
+       ::pir::IrMapping& ir_mapping,  // NOLINT
+       ::pir::Builder& builder) -> ::pir::Operation* {  // NOLINT
+  VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
+  auto attrs = op->attributes();
+
+  pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
+      attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
+  attrs.insert({"axis", attr_axis});
+  attrs.insert({"dtype", attrs["dtype"]});
+  attrs.insert({"keepdim", attrs["keep_dim"]});
+  attrs.erase("dim");
+  attrs.erase("keep_dim");
+
+  auto pd_op = builder.Build<paddle::dialect::SumOp>(
+      ir_mapping.Lookup(op->operand_source(0)), attrs);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+};
+
 const auto& handler_reduce_max_op =
-    [&](::pir::Operation* op,
-        const ::pir::Builder& builder) -> ::pir::Operation* {
+    [](::pir::Operation* op,
+       ::pir::IrMapping& ir_mapping,  // NOLINT
+       ::pir::Builder& builder) -> ::pir::Operation* {  // NOLINT
   VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
-  auto cinn_op = op->dyn_cast<cinn::dialect::ReduceMaxOp>();
-  auto attr = cinn_op.attributes();
+  auto attrs = op->attributes();
   // TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
   // Normalization
   pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
-      attr.at("dim").dyn_cast<::pir::ArrayAttribute>());
-  attr.insert({"axis", attr_axis});
-  attr.insert({"keepdim", attr["keep_dim"]});
-  attr.erase("dim");
-  attr.erase("keep_dim");
-
-  auto pd_op =
-      const_cast<::pir::Builder*>(&builder)->Build<paddle::dialect::MaxOp>(
-          cinn_op.operand_source(0), attr);
+      attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
+  attrs.insert({"axis", attr_axis});
+  attrs.insert({"keepdim", attrs["keep_dim"]});
+  attrs.erase("dim");
+  attrs.erase("keep_dim");
+
+  auto pd_op = builder.Build<paddle::dialect::MaxOp>(
+      ir_mapping.Lookup(op->operand_source(0)), attrs);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+};
+
+const auto& handler_reduce_min_op =
+    [](::pir::Operation* op,
+       ::pir::IrMapping& ir_mapping,  // NOLINT
+       ::pir::Builder& builder) -> ::pir::Operation* {  // NOLINT
+  VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
+  auto attrs = op->attributes();
+
+  pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
+      attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
+  attrs.insert({"axis", attr_axis});
+  attrs.insert({"keepdim", attrs["keep_dim"]});
+  attrs.erase("dim");
+  attrs.erase("keep_dim");
+
+  auto pd_op = builder.Build<paddle::dialect::MinOp>(
+      ir_mapping.Lookup(op->operand_source(0)), attrs);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+};
+
+const auto& handler_reduce_prod_op =
+    [](::pir::Operation* op,
+       ::pir::IrMapping& ir_mapping,  // NOLINT
+       ::pir::Builder& builder) -> ::pir::Operation* {  // NOLINT
+  VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
+  auto attrs = op->attributes();
+
+  pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
+      attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
+  attrs.insert({"dims", attr_axis});
+  attrs.erase("dim");
+
+  auto pd_op = builder.Build<paddle::dialect::ProdOp>(
+      ir_mapping.Lookup(op->operand_source(0)), attrs);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+};
+
+::pir::Operation* ConvertSliceOp(::pir::Operation* op,
+                                 ::pir::IrMapping& ir_mapping,  // NOLINT
+                                 ::pir::Builder& builder) {  // NOLINT
+  VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
+  auto attrs = op->attributes();
+  pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
+      attrs.at("starts").dyn_cast<::pir::ArrayAttribute>());
+  pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
+      attrs.at("ends").dyn_cast<::pir::ArrayAttribute>());
+  attrs["starts"] = starts;
+  attrs["ends"] = ends;
+  auto pd_op = builder.Build<paddle::dialect::SliceOp>(
+      ir_mapping.Lookup(op->operand_source(0)), attrs);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+}
+
+::pir::Operation* ConvertConcatOp(::pir::Operation* op,
+                                  ::pir::IrMapping& ir_mapping,  // NOLINT
+                                  ::pir::Builder& builder) {  // NOLINT
+  VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
+  auto attrs = op->attributes();
+  for (auto item : attrs) {
+    VLOG(0) << item.first;
+  }
+  std::vector<pir::Value> vec_inputs;
+  for (uint32_t i = 0; i < op->num_operands(); ++i) {
+    vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i)));
+  }
+  auto op_input = builder.Build<pir::CombineOp>(vec_inputs).result(0);
+
+  int axis = attrs.at("axis").dyn_cast<::pir::Int32Attribute>().data();
+
+  auto pd_op = builder.Build<paddle::dialect::ConcatOp>(op_input, axis);
+  for (uint32_t i = 0; i < op->num_results(); ++i) {
+    ir_mapping.Add(op->result(i), pd_op->result(i));
+  }
+  return pd_op;
+}
+
 bool CanApplyOn(::pir::Operation* op) {
   return op->dialect()->name() == "cinn_op";
 }

 ::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation* op,
-                                      const ::pir::Builder& builder) {
+                                      ::pir::IrMapping& ir_mapping,  // NOLINT
+                                      ::pir::Builder& builder) {  // NOLINT
   VLOG(8) << "Rewrite CinnOp to PdOp for op: " << op->name();
   auto& op_transformers = TransformContext::Instance();
-  return op_transformers[op->name()](op, builder);
+  return op_transformers[op->name()](op, ir_mapping, builder);
 }

 void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
@@ -91,20 +202,37 @@ void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
     }
     ::pir::Operation* new_op;
     if (CanApplyOn(&op)) {
-      new_op = RewriteCinnOpToPdOp(&op, builder);
+      new_op = RewriteCinnOpToPdOp(&op, ir_mapping, builder);
       new_op->MoveTo(target_block, target_block->end());
     } else {
       new_op = op.Clone(ir_mapping, clone_options);
       new_op->MoveTo(target_block, target_block->end());
     }
-    for (uint32_t i = 0; i < op.num_results(); ++i) {
-      ir_mapping.Add(op.result(i), new_op->result(i));
-    }
   }
 }

 }  // namespace cinn::dialect::details

+REGISTER_TRANSFORM_RULES(reduce_sum_op,
+                         cinn::dialect::ReduceSumOp::name(),
+                         cinn::dialect::details::handler_reduce_sum_op);
+
 REGISTER_TRANSFORM_RULES(reduce_max_op,
                          cinn::dialect::ReduceMaxOp::name(),
                          cinn::dialect::details::handler_reduce_max_op);
+
+REGISTER_TRANSFORM_RULES(reduce_min_op,
+                         cinn::dialect::ReduceMinOp::name(),
+                         cinn::dialect::details::handler_reduce_min_op);
+
+REGISTER_TRANSFORM_RULES(reduce_prod_op,
+                         cinn::dialect::ReduceProdOp::name(),
+                         cinn::dialect::details::handler_reduce_prod_op);
+
+REGISTER_TRANSFORM_RULES(slice_op,
+                         cinn::dialect::SliceOp::name(),
+                         cinn::dialect::details::ConvertSliceOp);
+
+REGISTER_TRANSFORM_RULES(concat_op,
+                         cinn::dialect::ConcatOp::name(),
+                         cinn::dialect::details::ConvertConcatOp);
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h
index f6055f60473e4a..01ebe5056156a5 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h
+++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h
@@ -20,15 +20,17 @@
 #include "paddle/common/enforce.h"

 namespace pir {
+class IrMapping;
 class Block;
 class Operation;
 class Builder;
+class IrMapping;
 }  // namespace pir

 namespace cinn::dialect::details {

-using TRule =
-    std::function<::pir::Operation*(::pir::Operation*, const ::pir::Builder&)>;
+using TRule = std::function<::pir::Operation*(
+    ::pir::Operation*, ::pir::IrMapping&, ::pir::Builder&)>;

 class TransformContext {
  private:
@@ -86,6 +88,8 @@ class TransformRegistrar {

 void RewriteCinnOpToPdOp(const ::pir::Block& src_block,
                          ::pir::Block* target_block);
-::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*, const ::pir::Builder&);
+::pir::Operation* RewriteCinnOpToPdOp(::pir::Operation*,
+                                      ::pir::IrMapping&,  // NOLINT
+                                      ::pir::Builder&);  // NOLINT

 }  // namespace cinn::dialect::details
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
index 3271517ee94f8c..4c784779ec5db3 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -56,6 +56,7 @@ class SumOpPattern : public paddle::drr::DrrPatternBase {
     const auto &cinn_reduce_sum =
         res.Op(cinn::dialect::ReduceSumOp::name(),
                {{"dim", pattern.Attr("axis_info")},
+                {"dtype", pattern.Attr("dtype")},
                 {"keep_dim", pattern.Attr("keep_dim")}});
     res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
   }
@@ -128,8 +129,10 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase {
                    {"dtype", pattern.Attr("dtype_2")},
                    {"place", pattern.Attr("place_2")}});

-    const auto &pd_max = pattern.Op(paddle::dialect::ProdOp::name(),
-                                    {{"keep_dim", pattern.Attr("keep_dim")}});
+    const auto &pd_max =
+        pattern.Op(paddle::dialect::ProdOp::name(),
+                   {{"keep_dim", pattern.Attr("keep_dim")},
+                    {"reduce_all", pattern.Attr("reduce_all")}});
     pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

     // Result patterns
@@ -137,7 +140,8 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase {
     const auto &cinn_reduce_max =
         res.Op(cinn::dialect::ReduceProdOp::name(),
                {{"dim", pattern.Attr("axis_info")},
-                {"keep_dim", pattern.Attr("keep_dim")}});
+                {"keep_dim", pattern.Attr("keep_dim")},
+                {"reduce_all", pattern.Attr("reduce_all")}});
     res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
   }
 };
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.cc
new file mode 100644
index 00000000000000..3da8780f3e61ab
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
+
+#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
+#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/pir/include/core/dialect.h"
+#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
+
+namespace cinn {
+namespace dialect {
+namespace ir {
+
+namespace {
+
+class FusionShapeOpsPattern
+    : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
+ public:
+  explicit FusionShapeOpsPattern(::pir::IrContext* context)
+      : pir::OpRewritePattern<cinn::dialect::FusionOp>(context) {}
+
+  bool Match(cinn::dialect::FusionOp fusion_op) const override {
+    auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(
+        fusion_op->GetParentProgram());
+    if (fusion_op.num_results() == 1) {
+      const auto& shape =
+          shape_analysis.GetShapeOrDataForValue(fusion_op.result(0));
+      if (shape.data() && shape.data()->size() <= 9) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void Rewrite(cinn::dialect::FusionOp fusion_op,
+               ::pir::PatternRewriter& rewriter) const override {
+    ::pir::IrMapping ir_mapping;
+    for (auto& op : *fusion_op.block()) {
+      if (op.isa<::pir::YieldOp>()) {
+        for (uint32_t i = 0; i < op.num_operands(); ++i) {
+          rewriter.ReplaceAllUsesWith(
+              fusion_op->result(i),
+              ir_mapping.Lookup<::pir::Value>(op.operand_source(i)));
+        }
+        continue;
+      }
+      for (size_t i = 0; i < op.num_operands(); ++i) {
+        if (!ir_mapping.GetMap<::pir::Value>().count(op.operand_source(i))) {
+          ir_mapping.Add(op.operand_source(i), op.operand_source(i));
+        }
+      }
+      if (op.dialect()->name() == "cinn_op") {
+        auto new_pd_op =
+            details::RewriteCinnOpToPdOp(&op, ir_mapping, rewriter);
+      } else {
+        auto* new_op = op.Clone(ir_mapping, {true, true, true});
+        rewriter.Insert(new_op);
+      }
+    }
+
+    rewriter.EraseOp(fusion_op);
+  }
+};
+
+class ShapeOpsFallbackToPhiPass : public pir::PatternRewritePass {
+ public:
+  ShapeOpsFallbackToPhiPass()
+      : pir::PatternRewritePass("shape_ops_fallback_to_phi_pass", 1) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add<FusionShapeOpsPattern>(context);
+
+    return ps;
+  }
+
+  bool CanApplyOn(pir::Operation* op) const override {
+    return op->num_regions() > 0;
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<::pir::Pass> CreateShapeOpsFallbackToPhiPass() {
+  return std::make_unique<ShapeOpsFallbackToPhiPass>();
+}
+
+}  // namespace ir
+}  // namespace dialect
+}  // namespace cinn
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h
new file mode 100644
index 00000000000000..8cc378dd555ac6
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/include/pass/pass.h"
+
+namespace cinn {
+namespace dialect {
+namespace ir {
+std::unique_ptr<::pir::Pass> CreateShapeOpsFallbackToPhiPass();
+}  // namespace ir
+}  // namespace dialect
+}  // namespace cinn
diff --git a/test/ir/pir/cinn/inference/CMakeLists.txt b/test/ir/pir/cinn/inference/CMakeLists.txt
index a5882bb6388c8c..497f0e3b474b66 100644
--- a/test/ir/pir/cinn/inference/CMakeLists.txt
+++ b/test/ir/pir/cinn/inference/CMakeLists.txt
@@ -24,17 +24,4 @@ if(WITH_GPU)

   set_tests_properties(test_llama_forward PROPERTIES TIMEOUT 300)
   set_tests_properties(test_llama_postprocess PROPERTIES TIMEOUT 300)
-  add_test(
-    NAME test_llama_postprocess_cinn
-    COMMAND
-      ${CMAKE_COMMAND} -E env
-      PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
-      FLAGS_prim_enable_dynamic=True FLAGS_prim_all=True FLAGS_enable_pir_api=1
-      FLAGS_cinn_bucket_compile=True FLAGS_group_schedule_tiling_first=1
-      FLAGS_pd_unittest_use_cinn=1 FLAGS_pir_apply_shape_optimization_pass=1
-      ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_llama_postprocess.py
-    WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
-  set_tests_properties(test_llama_postprocess_cinn
-                       PROPERTIES LABELS "RUN_TYPE=CINN" TIMEOUT 300)
-
 endif()
diff --git a/test/ir/pir/cinn/inference/test_llama_postprocess.py b/test/ir/pir/cinn/inference/test_llama_postprocess.py
index c3b235fde15ef4..cfff921719f955 100644
--- a/test/ir/pir/cinn/inference/test_llama_postprocess.py
+++ b/test/ir/pir/cinn/inference/test_llama_postprocess.py
@@ -90,8 +90,8 @@ def prepare_data(self):
         self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64")

     def check_jit_kernel_info(self, static_fn):
-        utils.check_jit_kernel_number(static_fn, 8)
-        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 8})
+        utils.check_jit_kernel_number(static_fn, 7)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 7})

     def eval(self, use_cinn):
         paddle.seed(2024)
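
Note: every lowering rule added in this patch follows the same shape: a handler with the (op, ir_mapping, builder) signature that builds the corresponding pd_op, records the result mapping in ir_mapping, and is registered through REGISTER_TRANSFORM_RULES so that RewriteCinnOpToPdOp can dispatch on the op name. A minimal sketch of what an additional rule might look like, for illustration only (FooOp is a placeholder, not an op touched by this patch):

    // Sketch only: "FooOp" is hypothetical; a real rule maps a concrete
    // cinn_op to its paddle::dialect counterpart, as the handlers above do.
    ::pir::Operation* ConvertFooOp(::pir::Operation* op,
                                   ::pir::IrMapping& ir_mapping,  // NOLINT
                                   ::pir::Builder& builder) {     // NOLINT
      // Real handlers rename/convert attributes here (e.g. "dim" -> "axis")
      // before building the pd op; this sketch reuses them unchanged.
      auto attrs = op->attributes();
      auto pd_op = builder.Build<paddle::dialect::FooOp>(
          ir_mapping.Lookup(op->operand_source(0)), attrs);
      // Record the result mapping so later ops (and the fusion yield) can
      // resolve the replacement values.
      for (uint32_t i = 0; i < op->num_results(); ++i) {
        ir_mapping.Add(op->result(i), pd_op->result(i));
      }
      return pd_op;
    }

    REGISTER_TRANSFORM_RULES(foo_op,
                             cinn::dialect::FooOp::name(),
                             cinn::dialect::details::ConvertFooOp);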