diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index f68c852f4dcc65..b1cf65915282d1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -108,8 +108,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { if (!IsSameDim(x_dims, output_shape)) { // add broadcast to input 0 if (auto full_op = op->operand_source(0) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast()) { auto new_full = rewriter->Build( output_shape, @@ -133,8 +132,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { if (!IsSameDim(y_dims, output_shape)) { if (auto full_op = op->operand_source(1) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast()) { auto new_full = rewriter->Build( output_shape, diff --git a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc index 38bf51de75019b..c3e95ed2491512 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc @@ -33,8 +33,7 @@ class DynamicReshapeOpPattern bool MatchAndRewrite(paddle::dialect::ReshapeOp op, pir::PatternRewriter& rewriter) const override { - auto scale_factor_gen_op = - op->operand_source(1).dyn_cast().owner(); + auto scale_factor_gen_op = op->operand_source(1).defining_op(); auto output = op.result(0); // The value of shape attribute is fake, we only use the output shape info @@ -43,8 +42,8 @@ class DynamicReshapeOpPattern output.type().dyn_cast().GetRank(), 1); shape[0] = -1; - auto cinn_reshape = rewriter.Build( - op->operand_source(0).dyn_cast(), shape); + auto cinn_reshape = + rewriter.Build(op->operand_source(0), shape); auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc index 61bccc598fb826..a6d91355d67743 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc @@ -154,8 +154,7 @@ std::vector GetOutputOpList( auto yield_op = op_list.back(); for (size_t i = 0; i < yield_op->num_operands(); ++i) { - vec_res.push_back( - yield_op->operand(i).source().dyn_cast().owner()); + vec_res.push_back(yield_op->operand(i).source().defining_op()); } return vec_res; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h index d59f673d53f7ba..4381ca7234b013 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h @@ -360,8 +360,7 @@ inline bool horizontal_relation(const std::shared_ptr& first, // visit all producer node // Get all the input Op for (size_t i = 0; i < candidate->num_operands(); ++i) { - auto producer = - candidate->operand_source(i).dyn_cast().owner(); + auto producer = candidate->operand_source(i).defining_op(); // check dependency. if (first_set.count(producer)) { return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc index e982a2c2e7a40f..a1eeadb50accf7 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc @@ -43,7 +43,7 @@ std::vector GetProducerOpsReverseSort( if (!operand || !(operand.source())) { continue; } - auto* source_op = operand.source().dyn_cast().owner(); + auto* source_op = operand.source().defining_op(); if (!op2id.count(source_op)) { continue; @@ -73,7 +73,7 @@ std::unordered_set GetProducerOps(pir::Operation* op) { if (!operand || !(operand.source())) { continue; } - auto* source_op = operand.source().dyn_cast().owner(); + auto* source_op = operand.source().defining_op(); producers.insert(source_op); } return producers; @@ -109,7 +109,7 @@ std::vector TopologicalSort( continue; } - if (inner_set.count(operand.source().dyn_cast().owner())) { + if (inner_set.count(operand.source().defining_op())) { count++; } } @@ -281,7 +281,7 @@ class OpFusionPassHelper { // input op for (size_t i = 0; i < op->num_operands(); ++i) { - auto input = op->operand_source(i).dyn_cast().owner(); + auto input = op->operand_source(i).defining_op(); if (input && (local_ops_.count(input))) { group->input_ops[input] = 1; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h index 10128578e93c19..b1b94ea276a643 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h @@ -186,8 +186,7 @@ inline bool is_horizontal_relation(::pir::Operation* producer, candidates.pop(); // visit all producer op for (size_t i = 0; i < candidate->num_operands(); ++i) { - auto tmp_op = - candidate->operand_source(i).dyn_cast().owner(); + auto tmp_op = candidate->operand_source(i).defining_op(); // check depency. if (producer == tmp_op) { return true; @@ -374,8 +373,7 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer, candidates.pop(); for (size_t i = 0; i < candidate->num_operands(); ++i) { - auto producer = - candidate->operand_source(i).dyn_cast().owner(); + auto producer = candidate->operand_source(i).defining_op(); if (producer == reducer) { return true; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.cc index aefcdb1db18172..4e0b6ed89667e8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/tensor_node.cc @@ -22,9 +22,7 @@ namespace cinn { namespace dialect { namespace ir { -OpNode TensorNode::producer() const { - return OpNode(node_data_.dyn_cast().owner()); -} +OpNode TensorNode::producer() const { return OpNode(node_data_.defining_op()); } OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const { return OpNode(iter_.owner()); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/merge_reshape_with_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/merge_reshape_with_broadcast_pass.cc index cdd2326fc1216c..f4df56ddc98f23 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/merge_reshape_with_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/merge_reshape_with_broadcast_pass.cc @@ -102,8 +102,7 @@ class MergeReshapeWithBroadcastPattern bool MatchAndRewrite(cinn::dialect::BroadcastOp op, pir::PatternRewriter& rewriter) const override { auto reshape_op = op->operand_source(0) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); if (reshape_op && CanMerge(reshape_op)) { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index c9b85a517d4bee..31421d52971548 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -155,7 +155,7 @@ class ScaleOpPattern : public pir::OpRewritePattern { full_op.attribute("value").dyn_cast().data(); auto cinn_scale = rewriter.Build( - op->operand_source(0).dyn_cast(), + op->operand_source(0), scale_value, op->attributes().at("bias").dyn_cast().data(), op->attributes() @@ -220,7 +220,7 @@ class ReshapeOpPattern } auto cinn_reshape = rewriter.Build( - op->operand_source(0).dyn_cast(), vec_out_shape); + op->operand_source(0), vec_out_shape); rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0)); rewriter.EraseOp(op); @@ -264,8 +264,8 @@ class Pool2dOpPattern attrs.erase("paddings"); attrs.erase("pooling_type"); - auto cinn_reshape = rewriter.Build( - op->operand_source(0).dyn_cast(), attrs); + auto cinn_reshape = + rewriter.Build(op->operand_source(0), attrs); rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0)); rewriter.EraseOp(op); @@ -337,13 +337,13 @@ class SliceOpPattern : public pir::OpRewritePattern { cinn::dialect::ir::GetVectorAttr(op, "decrease_axis"); auto infer_flags = cinn::dialect::ir::GetVectorAttr(op, "infer_flags"); - auto cinn_slice = rewriter.Build( - op->operand_source(0).dyn_cast(), - axes, - start_vec, - end_vec, - infer_flags, - decrease_axis); + auto cinn_slice = + rewriter.Build(op->operand_source(0), + axes, + start_vec, + end_vec, + infer_flags, + decrease_axis); // NOTE(Aurelius84): In SliceRawInferMeta, it not always share_lod, so // we need to update it maually. cinn_slice.result(0).set_type(op.result(0).type()); diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 94cfe542990b0f..29933c67020e80 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -193,7 +193,7 @@ struct Group { continue; } - if (!ops_set.count(value.dyn_cast<::pir::OpResult>().owner())) { + if (!ops_set.count(value.defining_op())) { // if the input value owner op is not in OpSet, it's the group's input group_inputs.insert(value); continue; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_util.cc b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc index da234602810bee..e176c164a6dcdb 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc @@ -61,7 +61,7 @@ std::vector<::pir::Operation*> GetConsumersInSet( std::vector<::pir::Operation*> GetProducers(::pir::Operation* op) { std::vector<::pir::Operation*> producers; for (auto& source : op->operands_source()) { - auto* producer_op = source.dyn_cast<::pir::OpResult>().owner(); + auto* producer_op = source.defining_op(); CHECK(producer_op); producers.push_back(producer_op); } @@ -1565,7 +1565,7 @@ void SyncThreadWithShared( continue; } auto op_data = op_out_set.find(block->name)->second; - auto* op = op_data.dyn_cast<::pir::OpResult>().owner(); + auto* op = op_data.defining_op(); auto op_shape = CompatibleInfo::ValueShape(op_data); auto masters = GetMasters(op, pretty_name, ops_inline, ops_set); diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index d7342b7773c185..5b0eed964f5986 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -242,7 +242,7 @@ Variable* CreateVar(pir::Value value, const std::string& var_name_prefix, bool force_persisable, ValueExecutionInfo* value_exe_info) { - pir::Operation* def_op = value.dyn_cast().owner(); + pir::Operation* def_op = value.defining_op(); bool is_persisable = false; if (def_op->isa<::pir::ParameterOp>()) { is_persisable = true; diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py index 18dc70f9fa7a7c..3dccbe512e7440 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -109,9 +109,7 @@ def gen_cpp_file_code(self, cpp_file_path): ir_op_name = self.dialect_name + "." + phi_op_name params_no_mutable_attr = [] for i in range(len(op_info_item.input_name_list)): - params_no_mutable_attr.append( - f"inputs[{i}].dyn_cast()" - ) + params_no_mutable_attr.append(f"inputs[{i}]") if len(op_info_item.attribute_name_list) > 0: params_no_mutable_attr.append("attrs") @@ -131,9 +129,7 @@ def gen_cpp_file_code(self, cpp_file_path): len(op_info_item.input_name_list) + len(op_info_item.mutable_attribute_name_list) ): - params_with_mutable_attr.append( - f"inputs[{i}].dyn_cast()" - ) + params_with_mutable_attr.append(f"inputs[{i}]") if len(op_info_item.attribute_name_list) > len( op_info_item.mutable_attribute_name_list ): diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py index a900a378cfd772..6b50ca341ee60e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py @@ -251,9 +251,9 @@ def GenBuildOutputsPart2( """ CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; - if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ + if ({name}_.isa() && {name}_.defining_op()->isa()) {{ {name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( - {name}_.dyn_cast().owner() + {name}_.defining_op() ->dyn_cast() .attribute("value")))); }} else if ({name}_.type().isa()) {{ @@ -281,9 +281,9 @@ def GenBuildOutputsPart2( }}\n""" CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; - if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ + if ({name}_.isa() && {name}_.defining_op()->isa()) {{ {name} = paddle::dialect::GetInt64Vector( - {name}_.dyn_cast().owner() + {name}_.defining_op() ->dyn_cast() .attribute("value")); }} else if ({name}_.type().isa()) {{ @@ -308,8 +308,8 @@ def GenBuildOutputsPart2( }}\n""" CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; - if ({name}_.dyn_cast() && {name}_.dyn_cast().owner()->isa()) {{ - {name} = std::move(phi::Scalar({name}_.dyn_cast().owner() + if ({name}_.isa() && {name}_.defining_op()->isa()) {{ + {name} = std::move(phi::Scalar({name}_.defining_op() ->dyn_cast() .attribute("value") .dyn_cast() diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 9e68fe2421bceb..49eba696fa0037 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -2810,13 +2810,10 @@ void SliceArrayOp::VerifySig() { phi::IntArray CalcSliceBoundsFromValue(pir::Value starts_or_ends) { phi::IntArray starts_or_ends_list; - if (starts_or_ends.dyn_cast() - .owner() - ->isa()) { + if (starts_or_ends.defining_op()->isa()) { starts_or_ends_list = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( - starts_or_ends.dyn_cast() - .owner() + starts_or_ends.defining_op() ->dyn_cast() .attribute("value")))); } else if (starts_or_ends.type().isa()) { diff --git a/paddle/fluid/pir/drr/rewrite_pattern.cc b/paddle/fluid/pir/drr/rewrite_pattern.cc index 2cc6dbb6688f3c..86483abb9eb1ce 100644 --- a/paddle/fluid/pir/drr/rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/rewrite_pattern.cc @@ -199,7 +199,7 @@ void DrrRewritePattern::DfsVisitor( ir_operand_value.use_count()) { return; } - auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); + auto* ir_producer_op = ir_operand_value.defining_op(); drr_visited_ops->insert(drr_producer_op); DfsVisitor(drr_producer_op, ir_producer_op, @@ -428,7 +428,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } auto ir_val = res_match_ctx.GetIrValue(input->name()); if (ir_val) { - pir::Operation* ir_input_op = ir_val.dyn_cast().owner(); + pir::Operation* ir_input_op = ir_val.defining_op(); if (op_2_temp_program_index.count(ir_input_op) == 0) { max_input_op_index = 0UL; } else if (max_input_op_index < diff --git a/paddle/fluid/pir/transforms/sub_graph_detector.cc b/paddle/fluid/pir/transforms/sub_graph_detector.cc index f2a7e82a840b12..38c53c4fadab13 100644 --- a/paddle/fluid/pir/transforms/sub_graph_detector.cc +++ b/paddle/fluid/pir/transforms/sub_graph_detector.cc @@ -50,7 +50,7 @@ std::vector InverselyTopologicalSort(pir::Block* block) { if (!operand || !(operand.source())) { continue; } - auto* defined_op = operand.source().dyn_cast().owner(); + auto* defined_op = operand.source().defining_op(); if (pending_count.find(defined_op) != pending_count.end()) { ++pending_count[defined_op]; } else { @@ -76,7 +76,7 @@ std::vector InverselyTopologicalSort(pir::Block* block) { if (!operand || !(operand.source())) { continue; } - auto* defined_op = operand.source().dyn_cast().owner(); + auto* defined_op = operand.source().defining_op(); --pending_count[defined_op]; if (pending_count[defined_op] == 0) { queue.push(defined_op); @@ -103,7 +103,7 @@ std::vector GetProducerOpsReverseSort( if (!operand || !(operand.source())) { continue; } - auto* source_op = operand.source().dyn_cast().owner(); + auto* source_op = operand.source().defining_op(); if (!producers.count(source_op)) { producers.insert(source_op); PADDLE_ENFORCE( @@ -129,7 +129,7 @@ std::unordered_set GetProducerOps(pir::Operation* op) { if (!operand || !(operand.source())) { continue; } - auto* source_op = operand.source().dyn_cast().owner(); + auto* source_op = operand.source().defining_op(); producers.insert(source_op); } return producers; diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index fbf082c72fddd3..9d97be3ed34729 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -44,7 +44,7 @@ return vjp_res; {% macro get_mutable_attribute(attrs, api_name) %} {% for i in attrs %} {%- if i is mutable_attribute -%} -auto* {{i.name}}_define_op = std::static_pointer_cast({{i.name~'_'}}.impl())->value().dyn_cast().owner(); +auto* {{i.name}}_define_op = std::static_pointer_cast({{i.name~'_'}}.impl())->value().defining_op(); {% if i.typename is scalar %} if({{i.name}}_define_op->name() != "pd_op.full") { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index dc70d287e768e8..765eaee822ebcc 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -172,8 +172,7 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter { // that single use values often have more canonicalization opportunities. if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return; - if (auto* def_op = operand.dyn_cast().owner()) - AddToWorklist(def_op); + if (auto* def_op = operand.defining_op()) AddToWorklist(def_op); } void AddOperandsToWorklist(const std::vector operands) {