@@ -108,8 +108,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
if (!IsSameDim(x_dims, output_shape)) {
// add broadcast to input 0
if (auto full_op = op->operand_source(0)
- .dyn_cast<pir::OpResult>()
- .owner()
+ .defining_op()
->dyn_cast<paddle::dialect::FullOp>()) {
auto new_full = rewriter->Build<paddle::dialect::FullOp>(
output_shape,
@@ -133,8 +132,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {

if (!IsSameDim(y_dims, output_shape)) {
if (auto full_op = op->operand_source(1)
- .dyn_cast<pir::OpResult>()
- .owner()
+ .defining_op()
->dyn_cast<paddle::dialect::FullOp>()) {
auto new_full = rewriter->Build<paddle::dialect::FullOp>(
output_shape,
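Every hunk in this PR makes the same mechanical swap: `value.dyn_cast<pir::OpResult>().owner()` becomes `value.defining_op()`. The pir implementation itself is not part of this diff, so the following is only a minimal sketch of what the accessor presumably does, under the assumption that it returns nullptr for values with no producing op:

```cpp
// Hypothetical sketch, not the actual pir source. Assumption: defining_op()
// folds the dyn_cast + owner() chain into one call and yields nullptr when
// the value is not an OpResult (e.g. a block argument), rather than calling
// owner() through a null cast.
pir::Operation* Value::defining_op() const {
  if (auto result = dyn_cast<pir::OpResult>()) {
    return result.owner();
  }
  return nullptr;
}
```

If that holds, every call site below keeps its behavior for op-produced values and gains a well-defined null result for everything else.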
@@ -33,8 +33,7 @@ class DynamicReshapeOpPattern

bool MatchAndRewrite(paddle::dialect::ReshapeOp op,
pir::PatternRewriter& rewriter) const override {
- auto scale_factor_gen_op =
- op->operand_source(1).dyn_cast<pir::OpResult>().owner();
+ auto scale_factor_gen_op = op->operand_source(1).defining_op();
auto output = op.result(0);

// The value of shape attribute is fake, we only use the output shape info
@@ -43,8 +42,8 @@
output.type().dyn_cast<pir::ShapedTypeInterface>().GetRank(), 1);
shape[0] = -1;

- auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
- op->operand_source(0).dyn_cast<pir::OpResult>(), shape);
+ auto cinn_reshape =
+ rewriter.Build<cinn::dialect::ReshapeOp>(op->operand_source(0), shape);

auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
@@ -154,8 +154,7 @@ std::vector<pir::Operation*> GetOutputOpList(
auto yield_op = op_list.back();

for (size_t i = 0; i < yield_op->num_operands(); ++i) {
- vec_res.push_back(
- yield_op->operand(i).source().dyn_cast<pir::OpResult>().owner());
+ vec_res.push_back(yield_op->operand(i).source().defining_op());
}

return vec_res;
@@ -360,8 +360,7 @@ inline bool horizontal_relation(const std::shared_ptr<ir::Group>& first,
// visit all producer node
// Get all the input Op
for (size_t i = 0; i < candidate->num_operands(); ++i) {
- auto producer =
- candidate->operand_source(i).dyn_cast<pir::OpResult>().owner();
+ auto producer = candidate->operand_source(i).defining_op();
// check dependency.
if (first_set.count(producer)) {
return true;
@@ -43,7 +43,7 @@ std::vector<pir::Operation*> GetProducerOpsReverseSort(
if (!operand || !(operand.source())) {
continue;
}
- auto* source_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* source_op = operand.source().defining_op();

if (!op2id.count(source_op)) {
continue;
@@ -73,7 +73,7 @@ std::unordered_set<pir::Operation*> GetProducerOps(pir::Operation* op) {
if (!operand || !(operand.source())) {
continue;
}
- auto* source_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* source_op = operand.source().defining_op();
producers.insert(source_op);
}
return producers;
@@ -109,7 +109,7 @@ std::vector<pir::Operation*> TopologicalSort(
continue;
}

- if (inner_set.count(operand.source().dyn_cast<pir::OpResult>().owner())) {
+ if (inner_set.count(operand.source().defining_op())) {
count++;
}
}
@@ -281,7 +281,7 @@ class OpFusionPassHelper {
// input op

for (size_t i = 0; i < op->num_operands(); ++i) {
- auto input = op->operand_source(i).dyn_cast<pir::OpResult>().owner();
+ auto input = op->operand_source(i).defining_op();
if (input && (local_ops_.count(input))) {
group->input_ops[input] = 1;
}
@@ -186,8 +186,7 @@ inline bool is_horizontal_relation(::pir::Operation* producer,
candidates.pop();
// visit all producer op
for (size_t i = 0; i < candidate->num_operands(); ++i) {
- auto tmp_op =
- candidate->operand_source(i).dyn_cast<pir::OpResult>().owner();
+ auto tmp_op = candidate->operand_source(i).defining_op();
// check depency.
if (producer == tmp_op) {
return true;
@@ -374,8 +373,7 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer,
candidates.pop();

for (size_t i = 0; i < candidate->num_operands(); ++i) {
- auto producer =
- candidate->operand_source(i).dyn_cast<pir::OpResult>().owner();
+ auto producer = candidate->operand_source(i).defining_op();
if (producer == reducer) {
return true;
}
@@ -22,9 +22,7 @@ namespace cinn {
namespace dialect {
namespace ir {

- OpNode TensorNode::producer() const {
- return OpNode(node_data_.dyn_cast<pir::OpResult>().owner());
- }
+ OpNode TensorNode::producer() const { return OpNode(node_data_.defining_op()); }

OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const {
return OpNode(iter_.owner());
@@ -102,8 +102,7 @@ class MergeReshapeWithBroadcastPattern
bool MatchAndRewrite(cinn::dialect::BroadcastOp op,
pir::PatternRewriter& rewriter) const override {
auto reshape_op = op->operand_source(0)
- .dyn_cast<pir::OpResult>()
- .owner()
+ .defining_op()
->dyn_cast<cinn::dialect::ReshapeOp>();

if (reshape_op && CanMerge(reshape_op)) {
22 changes: 11 additions & 11 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -155,7 +155,7 @@ class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
full_op.attribute("value").dyn_cast<pir::FloatAttribute>().data();

auto cinn_scale = rewriter.Build<cinn::dialect::ScaleOp>(
- op->operand_source(0).dyn_cast<pir::OpResult>(),
+ op->operand_source(0),
scale_value,
op->attributes().at("bias").dyn_cast<pir::FloatAttribute>().data(),
op->attributes()
@@ -220,7 +220,7 @@ class ReshapeOpPattern
}

auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
- op->operand_source(0).dyn_cast<pir::OpResult>(), vec_out_shape);
+ op->operand_source(0), vec_out_shape);
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
rewriter.EraseOp(op);

@@ -264,8 +264,8 @@ class Pool2dOpPattern
attrs.erase("paddings");
attrs.erase("pooling_type");

- auto cinn_reshape = rewriter.Build<cinn::dialect::Pool2dOp>(
- op->operand_source(0).dyn_cast<pir::OpResult>(), attrs);
+ auto cinn_reshape =
+ rewriter.Build<cinn::dialect::Pool2dOp>(op->operand_source(0), attrs);
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
rewriter.EraseOp(op);

@@ -337,13 +337,13 @@ class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
cinn::dialect::ir::GetVectorAttr(op, "decrease_axis");
auto infer_flags = cinn::dialect::ir::GetVectorAttr(op, "infer_flags");

- auto cinn_slice = rewriter.Build<cinn::dialect::SliceOp>(
- op->operand_source(0).dyn_cast<pir::OpResult>(),
- axes,
- start_vec,
- end_vec,
- infer_flags,
- decrease_axis);
+ auto cinn_slice =
+ rewriter.Build<cinn::dialect::SliceOp>(op->operand_source(0),
+ axes,
+ start_vec,
+ end_vec,
+ infer_flags,
+ decrease_axis);
// NOTE(Aurelius84): In SliceRawInferMeta, it not always share_lod, so
// we need to update it maually.
cinn_slice.result(0).set_type(op.result(0).type());
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/group.h
@@ -193,7 +193,7 @@ struct Group {
continue;
}

- if (!ops_set.count(value.dyn_cast<::pir::OpResult>().owner())) {
+ if (!ops_set.count(value.defining_op())) {
// if the input value owner op is not in OpSet, it's the group's input
group_inputs.insert(value);
continue;
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_util.cc
@@ -61,7 +61,7 @@ std::vector<::pir::Operation*> GetConsumersInSet(
std::vector<::pir::Operation*> GetProducers(::pir::Operation* op) {
std::vector<::pir::Operation*> producers;
for (auto& source : op->operands_source()) {
- auto* producer_op = source.dyn_cast<::pir::OpResult>().owner();
+ auto* producer_op = source.defining_op();
CHECK(producer_op);
producers.push_back(producer_op);
}
@@ -1565,7 +1565,7 @@ void SyncThreadWithShared(
continue;
}
auto op_data = op_out_set.find(block->name)->second;
- auto* op = op_data.dyn_cast<::pir::OpResult>().owner();
+ auto* op = op_data.defining_op();
auto op_shape = CompatibleInfo::ValueShape(op_data);

auto masters = GetMasters(op, pretty_name, ops_inline, ops_set);
@@ -242,7 +242,7 @@ Variable* CreateVar(pir::Value value,
const std::string& var_name_prefix,
bool force_persisable,
ValueExecutionInfo* value_exe_info) {
- pir::Operation* def_op = value.dyn_cast<pir::OpResult>().owner();
+ pir::Operation* def_op = value.defining_op();
bool is_persisable = false;
if (def_op->isa<::pir::ParameterOp>()) {
is_persisable = true;
8 changes: 2 additions & 6 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
@@ -109,9 +109,7 @@ def gen_cpp_file_code(self, cpp_file_path):
ir_op_name = self.dialect_name + "." + phi_op_name
params_no_mutable_attr = []
for i in range(len(op_info_item.input_name_list)):
- params_no_mutable_attr.append(
- f"inputs[{i}].dyn_cast<pir::OpResult>()"
- )
+ params_no_mutable_attr.append(f"inputs[{i}]")
if len(op_info_item.attribute_name_list) > 0:
params_no_mutable_attr.append("attrs")

@@ -131,9 +129,7 @@ def gen_cpp_file_code(self, cpp_file_path):
len(op_info_item.input_name_list)
+ len(op_info_item.mutable_attribute_name_list)
):
- params_with_mutable_attr.append(
- f"inputs[{i}].dyn_cast<pir::OpResult>()"
- )
+ params_with_mutable_attr.append(f"inputs[{i}]")
if len(op_info_item.attribute_name_list) > len(
op_info_item.mutable_attribute_name_list
):
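With the generator now emitting bare `inputs[{i}]`, the `Build` calls it produces pass `pir::Value` straight through, which implies the op builders accept `pir::Value` rather than requiring a pre-cast `pir::OpResult`. A hedged before/after of one emitted line, where the op name and arity are invented for illustration:

```cpp
// Hypothetical generated call for a two-input op; FooOp is a stand-in name.
// Previously the generator emitted:
//   rewriter.Build<paddle::dialect::FooOp>(inputs[0].dyn_cast<pir::OpResult>(),
//                                          inputs[1].dyn_cast<pir::OpResult>());
// Now it emits:
rewriter.Build<paddle::dialect::FooOp>(inputs[0], inputs[1]);
```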
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py
@@ -251,9 +251,9 @@ def GenBuildOutputsPart2(
"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
- if ({name}_.dyn_cast<pir::OpResult>() && {name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
+ if ({name}_.isa<pir::OpResult>() && {name}_.defining_op()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector(
- {name}_.dyn_cast<pir::OpResult>().owner()
+ {name}_.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value"))));
}} else if ({name}_.type().isa<pir::VectorType>()) {{
@@ -281,9 +281,9 @@ def GenBuildOutputsPart2(
}}\n"""

CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector<int64_t> {name};
- if ({name}_.dyn_cast<pir::OpResult>() && {name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
+ if ({name}_.isa<pir::OpResult>() && {name}_.defining_op()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = paddle::dialect::GetInt64Vector(
- {name}_.dyn_cast<pir::OpResult>().owner()
+ {name}_.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value"));
}} else if ({name}_.type().isa<pir::VectorType>()) {{
@@ -308,8 +308,8 @@ def GenBuildOutputsPart2(
}}\n"""

CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
- if ({name}_.dyn_cast<pir::OpResult>() && {name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullOp>()) {{
- {name} = std::move(phi::Scalar({name}_.dyn_cast<pir::OpResult>().owner()
+ if ({name}_.isa<pir::OpResult>() && {name}_.defining_op()->isa<paddle::dialect::FullOp>()) {{
+ {name} = std::move(phi::Scalar({name}_.defining_op()
->dyn_cast<paddle::dialect::FullOp>()
.attribute("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
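Substituting a concrete name into the IntArray template above shows the shape of the generated guard. This is a hypothetical expansion: `sections` is an invented attribute name, and the real generated code varies per op:

```cpp
// Hypothetical expansion for a mutable attribute named "sections".
phi::IntArray sections;
if (sections_.isa<pir::OpResult>() &&
    sections_.defining_op()->isa<paddle::dialect::FullIntArrayOp>()) {
  // Constant case: read the value straight off the producing FullIntArrayOp.
  sections = std::move(phi::IntArray(paddle::dialect::GetInt64Vector(
      sections_.defining_op()
          ->dyn_cast<paddle::dialect::FullIntArrayOp>()
          .attribute("value"))));
}
```

The ordering of the guard matters: `isa<pir::OpResult>()` is evaluated first, so `defining_op()` is never dereferenced for a value that has no producing op.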
7 changes: 2 additions & 5 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -2810,13 +2810,10 @@ void SliceArrayOp::VerifySig() {

phi::IntArray CalcSliceBoundsFromValue(pir::Value starts_or_ends) {
phi::IntArray starts_or_ends_list;
- if (starts_or_ends.dyn_cast<pir::OpResult>()
- .owner()
- ->isa<paddle::dialect::FullIntArrayOp>()) {
+ if (starts_or_ends.defining_op()->isa<paddle::dialect::FullIntArrayOp>()) {
starts_or_ends_list =
std::move(phi::IntArray(paddle::dialect::GetInt64Vector(
- starts_or_ends.dyn_cast<pir::OpResult>()
- .owner()
+ starts_or_ends.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value"))));
} else if (starts_or_ends.type().isa<pir::VectorType>()) {
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/rewrite_pattern.cc
@@ -199,7 +199,7 @@ void DrrRewritePattern::DfsVisitor(
ir_operand_value.use_count()) {
return;
}
- auto* ir_producer_op = ir_operand_value.dyn_cast<pir::OpResult>().owner();
+ auto* ir_producer_op = ir_operand_value.defining_op();
drr_visited_ops->insert(drr_producer_op);
DfsVisitor(drr_producer_op,
ir_producer_op,
@@ -428,7 +428,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
}
auto ir_val = res_match_ctx.GetIrValue(input->name());
if (ir_val) {
- pir::Operation* ir_input_op = ir_val.dyn_cast<pir::OpResult>().owner();
+ pir::Operation* ir_input_op = ir_val.defining_op();
if (op_2_temp_program_index.count(ir_input_op) == 0) {
max_input_op_index = 0UL;
} else if (max_input_op_index <
8 changes: 4 additions & 4 deletions paddle/fluid/pir/transforms/sub_graph_detector.cc
@@ -50,7 +50,7 @@ std::vector<pir::Operation*> InverselyTopologicalSort(pir::Block* block) {
if (!operand || !(operand.source())) {
continue;
}
- auto* defined_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* defined_op = operand.source().defining_op();
if (pending_count.find(defined_op) != pending_count.end()) {
++pending_count[defined_op];
} else {
@@ -76,7 +76,7 @@ std::vector<pir::Operation*> InverselyTopologicalSort(pir::Block* block) {
if (!operand || !(operand.source())) {
continue;
}
- auto* defined_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* defined_op = operand.source().defining_op();
--pending_count[defined_op];
if (pending_count[defined_op] == 0) {
queue.push(defined_op);
@@ -103,7 +103,7 @@ std::vector<pir::Operation*> GetProducerOpsReverseSort(
if (!operand || !(operand.source())) {
continue;
}
- auto* source_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* source_op = operand.source().defining_op();
if (!producers.count(source_op)) {
producers.insert(source_op);
PADDLE_ENFORCE(
@@ -129,7 +129,7 @@ std::unordered_set<pir::Operation*> GetProducerOps(pir::Operation* op) {
if (!operand || !(operand.source())) {
continue;
}
- auto* source_op = operand.source().dyn_cast<pir::OpResult>().owner();
+ auto* source_op = operand.source().defining_op();
producers.insert(source_op);
}
return producers;
@@ -44,7 +44,7 @@ return vjp_res;
{% macro get_mutable_attribute(attrs, api_name) %}
{% for i in attrs %}
{%- if i is mutable_attribute -%}
- auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<pir::OpResult>().owner();
+ auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().defining_op();
{% if i.typename is scalar %}
if({{i.name}}_define_op->name() != "pd_op.full") {
PADDLE_THROW(platform::errors::Unimplemented(
3 changes: 1 addition & 2 deletions paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc
@@ -172,8 +172,7 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter {
// that single use values often have more canonicalization opportunities.
if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return;

- if (auto* def_op = operand.dyn_cast<pir::OpResult>().owner())
- AddToWorklist(def_op);
+ if (auto* def_op = operand.defining_op()) AddToWorklist(def_op);
}

void AddOperandsToWorklist(const std::vector<pir::Value> operands) {
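This last hunk is the one place that leans directly on null-safety rather than mere brevity: the old code called `.owner()` on a possibly null `OpResult`, while the new guard treats a null `defining_op()` as "no producer, nothing to revisit". A hedged sketch of that pattern, where the wrapper name is hypothetical and `AddToWorklist` is the driver method from the hunk above:

```cpp
// For op-produced values both spellings agree; for a block argument the old
// chain went through a null OpResult, while defining_op() (assumed to return
// nullptr in that case) simply skips the worklist push.
void GreedyPatternRewriteDriver::AddOperandDefiningOpToWorklist(
    pir::Value operand) {
  if (auto* def_op = operand.defining_op()) {
    AddToWorklist(def_op);  // revisit the producer; it may now canonicalize
  }
}
```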