diff --git a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
index d5ec3042186e3a..651968c6434ea7 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
@@ -78,6 +78,7 @@ class BlockDimExprsAsserter {
     auto VisitEachInputAndDimExprs = [&](const auto& Visit) {
       for (int i = 0; i < op.num_operands(); ++i) {
         pir::Value input = op.operand_source(i);
+        if (!input || !input.type()) continue;
         const auto& value_dim_exprs = GraphDimExprs4Value(input);
         Visit(input, value_dim_exprs);
       }
@@ -125,6 +126,7 @@ class BlockDimExprsAsserter {
       return std::visit(patterns, value_dim_exprs.variant());
     };
     VisitEachInputAndDimExprs([&](auto value, const auto& value_dim_exprs) {
+      if (!value || !value.type()) return;
       const auto& new_symbol_replaced = GetNewSymbolReplaced(value_dim_exprs);
       shape_analysis->SetShapeOrDataForValue(value, new_symbol_replaced);
     });
@@ -155,16 +157,19 @@ class BlockDimExprsAsserter {

   void AssertDimExprForOutput(pir::Operation* op) {  // NOLINT
     VLOG(5) << "Add assert for result of [ " << op->name() << " ]";
+    if (op->num_results() == 0) return;
     if (!op->HasInterface<paddle::dialect::InferSymbolicShapeInterface>()) {
       LOG(INFO) << "skip the checking for [ " << op->name() << " ]";
       return;
     }
+    auto OpDimExprs4Value = MakeOpDimExprs4Value(op);
     const auto& inputs = [&] {
       std::vector<pir::Value> inputs;
       inputs.reserve(op->num_operands());
       for (int i = 0; i < op->num_operands(); ++i) {
         const auto& input = op->operand_source(i);
+        if (!input || !input.type()) continue;
         if (input.type().isa<pir::VectorType>()) {
           return std::vector<pir::Value>{};
         }
@@ -176,18 +181,20 @@ class BlockDimExprsAsserter {
     builder_.SetInsertionPointAfter(op);
     for (std::size_t i = 0; i < op->num_results(); ++i) {
       pir::Value output = op->result(i);
+      if (!output || !output.type()) continue;
       const auto& shape_or_data_dim_expr = GraphDimExprs4Value(output);
       if (!shape_or_data_dim_expr.isa<symbol::TensorShapeOrDataDimExprs>())
         continue;
       if (shape_or_data_dim_expr.data().has_value()) {
-        TryAssertDimExprsForOutputData(inputs, output, OpDimExprs4Value);
+        TryAssertDimExprsForOutputData(op, inputs, output, OpDimExprs4Value);
       } else {
-        TryAssertDimExprsForOutputShape(inputs, output, OpDimExprs4Value);
+        TryAssertDimExprsForOutputShape(op, inputs, output, OpDimExprs4Value);
       }
     }
   }

   void TryAssertDimExprsForOutputShape(
+      const pir::Operation* op,
       const std::vector<pir::Value>& inputs,
       pir::Value output,
       const DimExprs4ValueT& OpDimExprs4Value) {
@@ -203,14 +210,15 @@ class BlockDimExprsAsserter {
     const auto& shape_tensor_from_dim_exprs =
         opt_shape_tensor_from_dim_exprs.value();
     auto shape_tensor_from_infer_meta = BuildShapeTensorFromInferMeta(output);
-    AddAssertEqual(shape_tensor_from_dim_exprs, shape_tensor_from_infer_meta);
+    AddAssertEqual(
+        op, shape_tensor_from_dim_exprs, shape_tensor_from_infer_meta);
   }

   std::optional<pir::Value> BuildShapeTensorFromShapeDimExprs(
       const std::vector<pir::Value>& inputs,
       pir::Value output,
       const DimExprs4ValueT& OpDimExprs4Value) {
-    const auto& shape_or_data = GraphDimExprs4Value(output);
+    const auto& shape_or_data = OpDimExprs4Value(output);
     const auto& dim_exprs = shape_or_data.shape();
     return BuildShapeTensorFromDimExprs(inputs, dim_exprs, OpDimExprs4Value);
   }

@@ -219,7 +227,7 @@ class BlockDimExprsAsserter {
       const std::vector<pir::Value>& inputs,
       pir::Value output,
       const DimExprs4ValueT& OpDimExprs4Value) {
-    const auto& shape_or_data = GraphDimExprs4Value(output);
+    const auto& shape_or_data = OpDimExprs4Value(output);
     const auto& dim_exprs = shape_or_data.data();
     if (!dim_exprs.has_value()) return std::nullopt;
     return BuildShapeTensorFromDimExprs(
@@ -260,13 +268,14 @@ class BlockDimExprsAsserter {
     return builder_.Build<paddle::dialect::ShapeOp>(output).out();
   }

-  void TryAssertDimExprsForOutputData(const std::vector<pir::Value>& inputs,
+  void TryAssertDimExprsForOutputData(const pir::Operation* op,
+                                      const std::vector<pir::Value>& inputs,
                                       pir::Value output,
                                       const DimExprs4ValueT& OpDimExprs4Value) {
     auto opt_shape_tensor_from_dim_exprs =
         BuildShapeTensorFromDataDimExprs(inputs, output, OpDimExprs4Value);
     if (!opt_shape_tensor_from_dim_exprs.has_value()) return;
-    AddAssertEqual(opt_shape_tensor_from_dim_exprs.value(), output);
+    AddAssertEqual(op, opt_shape_tensor_from_dim_exprs.value(), output);
   }

   size_t GetNumel(pir::Value value) {
@@ -281,7 +290,9 @@ class BlockDimExprsAsserter {
     return numel;
   }

-  void AddAssertEqual(pir::Value lhs, pir::Value rhs) {
+  void AddAssertEqual(const pir::Operation* op,
+                      pir::Value lhs,
+                      pir::Value rhs) {
     size_t lhs_numel = GetNumel(lhs);
     size_t rhs_numel = GetNumel(rhs);
     PADDLE_ENFORCE_EQ(lhs_numel,
@@ -295,7 +306,16 @@ class BlockDimExprsAsserter {
         builder_.Build<paddle::dialect::EqualOp>(lhs, rhs).out();
     pir::Value all_eq =
         builder_.Build<paddle::dialect::AllOp>(lhs_eq_rhs).out();
-    builder_.Build<paddle::dialect::AssertOp>(all_eq, lhs_eq_rhs, lhs_numel);
+    pir::Value assert_data =
+        builder_.Build<pir::CombineOp>(std::vector<pir::Value>{lhs, rhs}).out();
+    auto assert_op = builder_.Build<paddle::dialect::AssertOp>(
+        all_eq, assert_data, lhs_numel);
+    const std::string error_msg = "Check [" + op->name() + "_" +
+                                  std::to_string(op->id()) +
+                                  "] infer symbolic shape failed.";
+    assert_op->set_attribute(
+        paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME,
+        pir::StrAttribute::get(pir::IrContext::Instance(), error_msg));
   }

   DimExprs4ValueT GraphDimExprs4Value;
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
index 6ef8dd56edebc9..83d3cdce2173aa 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
@@ -60,21 +60,19 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
   const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
   const auto& out_shape =
       shape_analysis.GetShapeOrDataForValue(op->result(0));
-  if (x_shape == y_shape) {
+  if (x_shape.shape() == y_shape.shape()) {
     return false;
   }

   pir::Value output_dim_tensor =
       GetOutputDimTensor(rewriter, x, y, &shape_analysis);
-  if (x_shape.shape() != out_shape.shape() ||
-      x_shape.data() != out_shape.data()) {
+  if (x_shape.shape() != out_shape.shape()) {
     pir::Value broadcasted_x =
         rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
     op->operand(0).set_source(broadcasted_x);
     shape_analysis.SetShapeOrDataForValue(broadcasted_x, out_shape);
   }
-  if (y_shape.shape() != out_shape.shape() ||
-      y_shape.data() != out_shape.data()) {
+  if (y_shape.shape() != out_shape.shape()) {
     pir::Value broadcasted_y =
         rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
     op->operand(1).set_source(broadcasted_y);
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc
index 0f15edcd0b8d6e..edb57fa8e15eaa 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc
@@ -193,8 +193,10 @@ struct CachedDimExprToValueConverter {
     pir::Value prod = ConvertToValue(operands->at(0));
     for (int i = 1; i < operands->size(); ++i) {
       if (operands->at(i).isa<symbol::Reciprocal<symbol::DimExpr>>()) {
-        const auto& [operand] =
-            *operands->at(i).dyn_cast<symbol::Reciprocal<symbol::DimExpr>>();
+        const auto& operand =
+            operands->at(i)
+                .dyn_cast<symbol::Reciprocal<symbol::DimExpr>>()
+                ->data;
         pir::Value operand_value = ConvertToValue(operand);
         prod = rewriter->Build<paddle::dialect::DivideOp>(prod, operand_value)
                    .out();
@@ -218,7 +220,8 @@ struct CachedDimExprToValueConverter {
     pir::Value max = ConvertToValue(operands->at(0));
     for (int i = 1; i < operands->size(); ++i) {
       pir::Value operand_value = ConvertToValue(operands->at(i));
-      max = rewriter->Build<paddle::dialect::MaxOp>(max, operand_value).out();
+      max =
+          rewriter->Build<paddle::dialect::MaximumOp>(max, operand_value).out();
     }
     return max;
   }
@@ -234,7 +237,8 @@ struct CachedDimExprToValueConverter {
     pir::Value min = ConvertToValue(operands->at(0));
     for (int i = 1; i < operands->size(); ++i) {
       pir::Value operand_value = ConvertToValue(operands->at(i));
-      min = rewriter->Build<paddle::dialect::MinOp>(min, operand_value).out();
+      min =
+          rewriter->Build<paddle::dialect::MinimumOp>(min, operand_value).out();
     }
     return min;
   }
diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc
index d2835dd65ccad1..e25afc34212cf1 100644
--- a/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc
@@ -82,11 +82,20 @@ void AssertInstruction::Run() {
           value_exe_info_->GetVarByValue(val)->Get<phi::DenseTensor>();
       formatter.Print(tensor, name);
     }
-
+    const std::string& error_msg = [&]() -> std::string {
+      if (op_->HasAttribute(paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME)) {
+        return op_
+            ->attribute<pir::StrAttribute>(
+                paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME)
+            .AsString();
+      }
+      return {};
+    }();
     PADDLE_THROW(platform::errors::InvalidArgument(
         "The condition variable '%s' of AssertOp must be "
-        "true, but received false",
-        value_exe_info_->GetVarName(cond_var_)));
+        "true, but received false. %s",
%s", + value_exe_info_->GetVarName(cond_var_), + error_msg)); } } // namespace framework diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index 0e294991449c16..069c646fc60edb 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -26,37 +26,17 @@ bool ArangeOpInferSymbolicShape( const auto &step_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); - const auto start = [&] { - symbol::DimExpr expr; - if (start_shape_or_data.data().has_value()) { - expr = start_shape_or_data.data().value()[0]; - } else { - expr = start_shape_or_data.shape()[0]; - } - return expr; - }(); - - const auto end = [&] { - symbol::DimExpr expr; - if (end_shape_or_data.data().has_value()) { - expr = end_shape_or_data.data().value()[0]; - } else { - expr = end_shape_or_data.shape()[0]; - } - return expr; - }(); - - const auto step = [&] { - symbol::DimExpr expr; - if (step_shape_or_data.data().has_value()) { - expr = step_shape_or_data.data().value()[0]; - } else { - expr = step_shape_or_data.shape()[0]; - } - return expr; - }(); - const symbol::ShapeOrDataDimExprs &shape_data = [&] { + if (!start_shape_or_data.data().has_value() || + !end_shape_or_data.data().has_value() || + !step_shape_or_data.data().has_value()) { + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(std::vector{ + symbol::DimExpr(shape_analysis->GetNextSymName())})}; + } + const auto &start = start_shape_or_data.data()->at(0); + const auto &end = end_shape_or_data.data()->at(0); + const auto &step = step_shape_or_data.data()->at(0); std::vector out_dims; // TODO(lanxianghit, jiahy0825): here should be ceil((end - start) / step), // but DimExpr doesn't support ceil and float now @@ -135,10 +115,32 @@ bool DataOpInferSymbolicShape(pir::Operation *op, return sym_dims; }(); - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(sym_dims)}; + auto IsOneNumel = [&](pir::Value value) { + const auto &dims = value.type().dyn_cast().dims(); + if (dims.size() == 1 && dims[0] == 1) { + return true; + } + return false; + }; + + auto IsIntType = [&](pir::Value value) { + const auto &dtype = value.type().dyn_cast().dtype(); + return dtype.isa() || dtype.isa(); + }; + + const auto &shape_or_data = [&]() { + if (IsOneNumel(op->result(0)) && IsIntType(op->result(0))) { + std::vector data{ + symbol::DimExpr(shape_analysis->GetNextSymName())}; + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(sym_dims, data)}; + } else { + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + } + }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data); return true; } @@ -164,9 +166,17 @@ bool EmptyOpInferSymbolicShape(pir::Operation *op, pir::Value operand_source = op->operand_source(0); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = shape_analysis->GetShapeOrDataForValue(operand_source); - - shape_analysis->SetShapeOrDataForValue(op->result(0), - operand_shape_or_data); + PADDLE_ENFORCE_EQ( + operand_shape_or_data.data().has_value(), + true, + common::errors::InvalidArgument( + "The data of input dim_expr shape is null. 
+            "is a tensor, the data of input dim_expr shape must have value."));
+
+    shape_analysis->SetShapeOrDataForValue(
+        op->result(0),
+        symbol::TensorShapeOrDataDimExprs{
+            operand_shape_or_data.data().value()});
     return true;
   }
 }
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index 476f97304530a2..d109ced69babd6 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -852,6 +852,7 @@ void HasElementsOp::VerifySig() {
 }

 const char *AssertOp::attributes_name[1] = {"summarize"};
+const char AssertOp::ERROR_INFO_ATTR_NAME[] = "error_info";

 void AssertOp::Build(pir::Builder &builder,             // NOLINT
                      pir::OperationArgument &argument,  // NOLINT
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
index 9b9bcd97b78fe5..9f32413743ce96 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
@@ -177,6 +177,7 @@ class AssertOp : public pir::Op {
  public:
   using Op::Op;
+  static const char ERROR_INFO_ATTR_NAME[];
   static const char *name() { return "pd_op.assert"; }
   static constexpr uint32_t attributes_num = 1;
   static const char *attributes_name[1];
diff --git a/test/ir/pir/cinn/symbolic/CMakeLists.txt b/test/ir/pir/cinn/symbolic/CMakeLists.txt
index e90301a149bfb4..4851fdb22151f7 100644
--- a/test/ir/pir/cinn/symbolic/CMakeLists.txt
+++ b/test/ir/pir/cinn/symbolic/CMakeLists.txt
@@ -34,7 +34,7 @@ if(WITH_GPU)
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
       FLAGS_check_infer_symbolic=1 FLAGS_enable_pir_api=1
       FLAGS_cinn_bucket_compile=True FLAGS_prim_enable_dynamic=true
-      FLAGS_pir_apply_shape_optimization_pass=1
+      FLAGS_prim_all=True FLAGS_pir_apply_shape_optimization_pass=1
       FLAGS_group_schedule_tiling_first=1 FLAGS_cinn_new_group_scheduler=1
       ${PYTHON_EXECUTABLE}
       ${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py
diff --git a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py
index 1f5704eef2f083..6275ba7c833c02 100644
--- a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py
+++ b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py
@@ -46,17 +46,17 @@ def forward(self, in_0, in_1, in_2):

 class ArangeOpInferSymbolicShapeTest(TestBase):
     def prepare_data(self):
-        self.start = paddle.full([1], 0)
-        self.end = paddle.full([1], 5)
-        self.step = paddle.full([1], 1)
+        self.start = paddle.full([1], 0, dtype='int32')
+        self.end = paddle.full([1], 5, dtype='int32')
+        self.step = paddle.full([1], 1, dtype='int32')
         self.expected = ['shape[Mul(Add(S1, -S0), 1 / (S2))], data[NULL]']

     def test_eval_symbolic(self):
         net = ArangeNet()
         input_spec = [
-            InputSpec(shape=[None], dtype='float32'),
-            InputSpec(shape=[None], dtype='float32'),
-            InputSpec(shape=[None], dtype='float32'),
+            InputSpec(shape=[1], dtype='int32'),
+            InputSpec(shape=[1], dtype='int32'),
+            InputSpec(shape=[1], dtype='int32'),
         ]
         net = apply_to_static(net, False, input_spec)
         net.eval()
@@ -100,7 +100,7 @@ def __init__(self):

     def forward(self, x):
         out = paddle.empty(shape=[128, 32])
-        out = paddle.empty(shape=x)
+        out = paddle.empty(shape=x.shape)
         return out


diff --git a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
index 7a3507d44bc203..f103350cbb380d 100644
--- a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
+++ b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
@@ -635,9 +635,9 @@ class SplitOpInferSymbolicShapeTest(TestBase):
     def prepare_data(self):
         self.cases = [np.random.rand(4, 6, 5)]
         self.expected = [
-            'shape[S0, S1, S2], data[NULL]',
-            'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
-            'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
+            'shape[S0, 6, S2], data[NULL]',
+            'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
+            'shape[S0, 1, S2], data[NULL], shape[S0, 5, S2], data[NULL]',
             'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
             'shape[S0, 6, S2], data[NULL]',
             'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py
index 2e266168892cfd..23fcc791e5bda9 100644
--- a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py
+++ b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py
@@ -15,10 +15,13 @@
 # repo: llm_sub_graphs
 # model: chatglm2
 # api:paddle.nn.functional.input.embedding||method:transpose||api:paddle.tensor.creation.ones||api:paddle.tensor.creation.tril||method:astype||api:paddle.tensor.creation.ones||method:astype||method:__and__||api:paddle.tensor.creation.arange||method:__truediv__||method:__rpow__||method:__rtruediv__||api:paddle.tensor.creation.arange||api:paddle.tensor.math.outer||method:astype||api:paddle.tensor.ops.cos||api:paddle.tensor.ops.sin||api:paddle.tensor.manipulation.stack||method:__getitem__||method:transpose
+import os
 import unittest

 import numpy as np

+os.environ["FLAGS_prim_all"] = "False"
+
 import paddle