diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index ff6c42613d06bf..0c82ce1aeaf213 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -444,6 +444,14 @@ def all_stop_gradient_true(block):
     return True
 
 
+def all_input_stop_gradient_true(list_of_list):
+    for list_ in list_of_list:
+        for stop_gradient in list_:
+            if stop_gradient is False:
+                return False
+    return True
+
+
 def all_output_grad_none(list_of_list):
     for list_ in list_of_list:
         for value in list_:
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 551e55a18b942a..8b72bb35a04cc5 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -22,6 +22,7 @@
     ValueDict,
     ValueSet,
     _as_list,
+    all_input_stop_gradient_true,
     all_output_grad_none,
     all_stop_gradient_true,
     argument_to_value,
@@ -649,6 +650,14 @@ def append_yield(
             ]:
                 continue
 
+            if all_input_stop_gradient_true(
+                input_grad_stopgradients
+            ) and op.name() not in [
+                "pd_op.array_read",
+                "pd_op.array_write_",
+                "pd_op.increment_",
+            ]:
+                continue
             if op.name() == "pd_op.if":
                 origin_inputs = get_real_op_inputs(op)
                 for sub_block in op.blocks():
diff --git a/test/auto_parallel/pir/test_to_static_pir_program.py b/test/auto_parallel/pir/test_to_static_pir_program.py
index 3085e3a726de0b..486011ad0e77b5 100644
--- a/test/auto_parallel/pir/test_to_static_pir_program.py
+++ b/test/auto_parallel/pir/test_to_static_pir_program.py
@@ -139,8 +139,6 @@ def test_to_static_program(self):
         backward_op_list = [
             "pd_op.sgd_",
             "pd_op.sgd_",
-            "pd_op.relu_grad",
-            "pd_op.c_allreduce_sum_",
             "pd_op.matmul_grad",
             "pd_op.relu_grad",
             "pd_op.matmul_grad",
diff --git a/test/ir/pir/test_ir_backward.py b/test/ir/pir/test_ir_backward.py
index 5e4f5386a1cdac..3f8a77eed354fe 100644
--- a/test/ir/pir/test_ir_backward.py
+++ b/test/ir/pir/test_ir_backward.py
@@ -292,6 +292,31 @@ def false_func():
         self.assertEqual((grad_x == res).all(), True)
 
 
+class TestBackward_5(unittest.TestCase):
+    def tearDown(self) -> None:
+        paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
+
+    def test_skip_vjp(self):
+        if not paddle.framework.in_pir_mode():
+            return
+        program = paddle.static.Program()
+        with paddle.static.program_guard(program):
+            x = paddle.static.data('x', [4, 4], 'float32')
+            x.stop_gradient = True
+            y = paddle.nn.functional.relu(x)
+            y.stop_gradient = False
+            z = paddle.nn.functional.relu(y)
+            loss = paddle.mean(z)
+
+            paddle.autograd.ir_backward.append_backward(loss)
+            relu_grad_number = 0
+            for op in program.global_block().ops:
+                if op.name() == "pd_op.relu_grad":
+                    relu_grad_number += 1
+
+            self.assertEqual(relu_grad_number, 1)
+
+
 class TestValueSet(unittest.TestCase):
     def setUp(self) -> None:
         with paddle.pir_utils.IrGuard():