@@ -253,8 +253,8 @@ void SliceOp::Build(Builder &builder,
253253void SliceOp::PassStopGradients (OperationArgument &argument, int index) {
254254 std::vector<pir::Attribute> outs_stop_gradient (
255255 1 , pir::BoolAttribute::get (pir::IrContext::Instance (), true ));
256- if (auto input = argument.inputs [0 ]. dyn_cast <pir::OpResult>() ) {
257- auto *defining_op = input.owner ();
256+ if (auto input = argument.inputs [0 ]) {
257+ auto *defining_op = input.defining_op ();
258258 if (defining_op && defining_op->isa <CombineOp>()) {
259259 IR_ENFORCE (defining_op->HasAttribute (kStopGradientAttrName ),
260260 " Required CombineOp must have attribute %s" ,
@@ -274,8 +274,8 @@ void SliceOp::RefreshStopGradients() {
274274 std::vector<pir::Attribute> outs_stop_gradient (
275275 1 , pir::BoolAttribute::get (pir::IrContext::Instance (), true ));
276276 auto index = attribute (" index" ).dyn_cast <pir::Int32Attribute>().data ();
277- if (auto input = (*this )->operand_source (0 ). dyn_cast <pir::OpResult>() ) {
278- auto *defining_op = input.owner ();
277+ if (auto input = (*this )->operand_source (0 )) {
278+ auto *defining_op = input.defining_op ();
279279 if (defining_op && defining_op->isa <CombineOp>()) {
280280 IR_ENFORCE (defining_op->HasAttribute (kStopGradientAttrName ),
281281 " Required CombineOp must have attribute %s" ,
@@ -350,8 +350,8 @@ void SplitOp::Build(Builder &builder,
350350
351351void SplitOp::PassStopGradients (OperationArgument &argument) {
352352 std::vector<bool > defaut_stop_gradients (argument.output_types .size (), true );
353- if (auto input = argument.inputs [0 ]. dyn_cast <OpResult>() ) {
354- auto *defining_op = input.owner ();
353+ if (auto input = argument.inputs [0 ]) {
354+ auto *defining_op = input.defining_op ();
355355 if (defining_op && defining_op->isa <CombineOp>()) {
356356 IR_ENFORCE (argument.output_types .size (),
357357 defining_op->num_operands (),
@@ -391,8 +391,8 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
391391
392392void SplitOp::RefreshStopGradients () {
393393 std::vector<bool > default_stop_gradients ((*this )->num_results (), true );
394- if (auto input = (*this )->operand_source (0 ). dyn_cast <OpResult>() ) {
395- auto *defining_op = input.owner ();
394+ if (auto input = (*this )->operand_source (0 )) {
395+ auto *defining_op = input.defining_op ();
396396 if (defining_op && defining_op->isa <CombineOp>()) {
397397 IR_ENFORCE ((*this )->num_results (),
398398 defining_op->num_operands (),
@@ -403,7 +403,7 @@ void SplitOp::RefreshStopGradients() {
403403 for (uint32_t i = 0 ; i < defining_op->num_operands (); ++i) {
404404 auto value = defining_op->operand_source (i);
405405 if (!value) continue ;
406- auto *operand_defining_op = value.dyn_cast <OpResult>(). owner ();
406+ auto *operand_defining_op = value.defining_op ();
407407 if (operand_defining_op->HasAttribute (kStopGradientAttrName )) {
408408 auto attrs = operand_defining_op->attribute (kStopGradientAttrName )
409409 .dyn_cast <pir::ArrayAttribute>()
0 commit comments