Skip to content

Commit 8048114

Browse files
huangjiyieee4017
authored andcommitted
1 parent bd482a8 commit 8048114

17 files changed

Lines changed: 37 additions & 53 deletions

paddle/fluid/pir/dialect/op_generator/api_gen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@
137137
optional_{name} = {name};
138138
}}"""
139139

140-
OPTIONAL_OPRESULT_OUTPUT_TEMPLATE = """
140+
OPTIONAL_VALUE_OUTPUT_TEMPLATE = """
141141
paddle::optional<pir::Value> optional_{name};
142142
if (!IsEmptyValue({op_name}_op.result({index}))) {{
143143
optional_{name} = paddle::make_optional<pir::Value>({op_name}_op.result({index}));
144144
}}"""
145145

146-
OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE = """
146+
OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE = """
147147
paddle::optional<std::vector<pir::Value>> optional_{name};
148148
if (!IsEmptyValue({op_name}_op.result({index}))) {{
149149
auto optional_{name}_slice_op = ApiBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({op_name}_op.result({index}));
@@ -423,13 +423,13 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
423423
continue
424424
if self._is_optional_output(op_info, name):
425425
if VECTOR_TYPE in type:
426-
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
426+
ret += OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE.format(
427427
name=name,
428428
op_name=op_name,
429429
index=i,
430430
)
431431
else:
432-
ret += OPTIONAL_OPRESULT_OUTPUT_TEMPLATE.format(
432+
ret += OPTIONAL_VALUE_OUTPUT_TEMPLATE.format(
433433
name=name,
434434
op_name=op_name,
435435
index=i,

paddle/fluid/pir/drr/ir_operation_factory.cc

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,15 @@ void OperationFactory::RegisterManualOpCreator() {
3535
const pir::AttributeMap& attrs,
3636
pir::PatternRewriter& rewriter) {
3737
return rewriter.Build<paddle::dialect::FusedGemmEpilogueOp>(
38-
inputs[0].dyn_cast<pir::OpResult>(),
39-
inputs[1].dyn_cast<pir::OpResult>(),
40-
inputs[2].dyn_cast<pir::OpResult>(),
41-
attrs);
38+
inputs[0], inputs[1], inputs[2], attrs);
4239
});
4340
RegisterOperationCreator(
4441
"pd_op.fused_gemm_epilogue_grad",
4542
[](const std::vector<pir::Value>& inputs,
4643
const pir::AttributeMap& attrs,
4744
pir::PatternRewriter& rewriter) {
4845
return rewriter.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
49-
inputs[0].dyn_cast<pir::OpResult>(),
50-
inputs[1].dyn_cast<pir::OpResult>(),
51-
inputs[2].dyn_cast<pir::OpResult>(),
52-
inputs[3].dyn_cast<pir::OpResult>(),
53-
attrs);
46+
inputs[0], inputs[1], inputs[2], inputs[3], attrs);
5447
});
5548
RegisterOperationCreator("builtin.combine",
5649
[](const std::vector<pir::Value>& inputs,
@@ -64,8 +57,8 @@ void OperationFactory::RegisterManualOpCreator() {
6457
const pir::AttributeMap& attrs,
6558
pir::PatternRewriter& rewriter) {
6659
return rewriter.Build<paddle::dialect::ScaleOp>(
67-
inputs[0].dyn_cast<pir::OpResult>(),
68-
inputs[1].dyn_cast<pir::OpResult>(),
60+
inputs[0],
61+
inputs[1],
6962
attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
7063
attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
7164
});

paddle/fluid/pir/transforms/constant_folding_pass.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
346346
prev_op->name()));
347347
}
348348
} else {
349-
op_inputs.push_back(
350-
op->operand_source(i).dyn_cast<pir::OpResult>() /*nullptr*/);
349+
op_inputs.push_back(nullptr);
351350
}
352351
}
353352

paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class ReplaceFetchWithShadowOutputPattern
2929
paddle::dialect::FetchOp op,
3030
pir::PatternRewriter& rewriter) const override { // NOLINT
3131
rewriter.Build<pir::ShadowOutputOp>(
32-
op->operand_source(0).dyn_cast<pir::OpResult>(),
32+
op->operand_source(0),
3333
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
3434
rewriter.EraseOp(op);
3535
return true;

paddle/fluid/pir/transforms/transform_general_functions.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void GetUsedExternalValueImpl(
6161
namespace pir {
6262

6363
std::string GetParameterNameFromValue(pir::Value value) {
64-
pir::Operation* owner = value.dyn_cast<OpResult>().owner();
64+
pir::Operation* owner = value.defining_op();
6565
std::string name;
6666
if (owner->isa<ParameterOp>()) {
6767
pir::ParameterOp op = owner->dyn_cast<pir::ParameterOp>();
@@ -104,7 +104,7 @@ Operation* GetDefiningOpForInput(const Operation* op, uint32_t index) {
104104
index < op->num_operands() && op->operand_source(index),
105105
true,
106106
phi::errors::InvalidArgument("Intput operand's index must be valid."));
107-
return op->operand_source(index).dyn_cast<OpResult>().owner();
107+
return op->operand_source(index).defining_op();
108108
}
109109

110110
std::vector<std::pair<Operation*, int32_t>> GetUseOpsForOutput(

paddle/fluid/pybind/manual_static_op_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ static PyObject *static_api_set_parameter(PyObject *self,
5959
VLOG(6) << "Add set_parameter op into program";
6060
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
6161

62-
// Get OpResult from args
62+
// Get Value from args
6363
PyObject *parameter_obj = PyTuple_GET_ITEM(args, 0);
6464
auto parameter = CastPyArg2Value(parameter_obj, "parameter", 0);
6565

paddle/fluid/pybind/op_function_common.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -865,18 +865,14 @@ void CastPyArg2AttrValues(PyObject* obj,
865865
Py_ssize_t len = PyList_Size(obj);
866866
PyObject* item = nullptr;
867867
for (Py_ssize_t i = 0; i < len; i++) {
868-
// TODO(xiongkun): judge OpResult or Value;
868+
// TODO(xiongkun): judge Value;
869869
item = PyList_GetItem(obj, i);
870870
::pybind11::detail::instance* inst =
871871
(::pybind11::detail::instance*)item; // NOLINT
872872
void** vh = inst->simple_layout ? inst->simple_value_holder
873873
: &inst->nonsimple.values_and_holders[0];
874-
::pir::OpResult* opresult = reinterpret_cast<::pir::OpResult*>(vh[0]);
875-
if (opresult->impl() == nullptr) {
876-
results.emplace_back(pir::Value(nullptr));
877-
} else {
878-
results.emplace_back(pir::Value(opresult->Value::impl()));
879-
}
874+
::pir::Value* value = reinterpret_cast<::pir::Value*>(vh[0]);
875+
results.emplace_back(pir::Value(value->impl()));
880876
}
881877
} else {
882878
PADDLE_THROW(platform::errors::InvalidType(

paddle/pir/core/builtin_op.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ void SliceOp::Build(Builder &builder,
253253
void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
254254
std::vector<pir::Attribute> outs_stop_gradient(
255255
1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
256-
if (auto input = argument.inputs[0].dyn_cast<pir::OpResult>()) {
257-
auto *defining_op = input.owner();
256+
if (auto input = argument.inputs[0]) {
257+
auto *defining_op = input.defining_op();
258258
if (defining_op && defining_op->isa<CombineOp>()) {
259259
IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
260260
"Required CombineOp must have attribute %s",
@@ -274,8 +274,8 @@ void SliceOp::RefreshStopGradients() {
274274
std::vector<pir::Attribute> outs_stop_gradient(
275275
1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
276276
auto index = attribute("index").dyn_cast<pir::Int32Attribute>().data();
277-
if (auto input = (*this)->operand_source(0).dyn_cast<pir::OpResult>()) {
278-
auto *defining_op = input.owner();
277+
if (auto input = (*this)->operand_source(0)) {
278+
auto *defining_op = input.defining_op();
279279
if (defining_op && defining_op->isa<CombineOp>()) {
280280
IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
281281
"Required CombineOp must have attribute %s",
@@ -350,8 +350,8 @@ void SplitOp::Build(Builder &builder,
350350

351351
void SplitOp::PassStopGradients(OperationArgument &argument) {
352352
std::vector<bool> defaut_stop_gradients(argument.output_types.size(), true);
353-
if (auto input = argument.inputs[0].dyn_cast<OpResult>()) {
354-
auto *defining_op = input.owner();
353+
if (auto input = argument.inputs[0]) {
354+
auto *defining_op = input.defining_op();
355355
if (defining_op && defining_op->isa<CombineOp>()) {
356356
IR_ENFORCE(argument.output_types.size(),
357357
defining_op->num_operands(),
@@ -391,8 +391,8 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
391391

392392
void SplitOp::RefreshStopGradients() {
393393
std::vector<bool> default_stop_gradients((*this)->num_results(), true);
394-
if (auto input = (*this)->operand_source(0).dyn_cast<OpResult>()) {
395-
auto *defining_op = input.owner();
394+
if (auto input = (*this)->operand_source(0)) {
395+
auto *defining_op = input.defining_op();
396396
if (defining_op && defining_op->isa<CombineOp>()) {
397397
IR_ENFORCE((*this)->num_results(),
398398
defining_op->num_operands(),
@@ -403,7 +403,7 @@ void SplitOp::RefreshStopGradients() {
403403
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
404404
auto value = defining_op->operand_source(i);
405405
if (!value) continue;
406-
auto *operand_defining_op = value.dyn_cast<OpResult>().owner();
406+
auto *operand_defining_op = value.defining_op();
407407
if (operand_defining_op->HasAttribute(kStopGradientAttrName)) {
408408
auto attrs = operand_defining_op->attribute(kStopGradientAttrName)
409409
.dyn_cast<pir::ArrayAttribute>()

paddle/pir/core/ir_context.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class TypeId;
3030
class Dialect;
3131
class OpInfo;
3232
class Type;
33-
class OpResult;
3433
class Attribute;
3534
class Operation;
3635
class InterfaceValue;

paddle/pir/core/op_info.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
namespace pir {
2222
class OpInfoImpl;
2323
class IrContext;
24-
class OpResult;
2524
class Type;
2625
class Attribute;
2726
class Dialect;

0 commit comments

Comments
 (0)