diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index bdd2756e09cd66..3eb4c01406704c 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -25,6 +25,73 @@
 )
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
 
+# TODO: Consider a better way to mark that these ops have no grad op,
+# such as using a new trait to mark them.
+ALLOW_NO_GRAD_OPS = [
+    # Compare ops
+    "pd_op.equal",
+    "pd_op.equal_",
+    "pd_op.not_equal",
+    "pd_op.not_equal_",
+    "pd_op.less_than",
+    "pd_op.less_than_",
+    "pd_op.less_equal",
+    "pd_op.less_equal_",
+    "pd_op.greater_than",
+    "pd_op.greater_than_",
+    "pd_op.greater_equal",
+    "pd_op.greater_equal_",
+    # Logical ops
+    "pd_op.logical_and",
+    "pd_op.logical_and_",
+    "pd_op.logical_not",
+    "pd_op.logical_not_",
+    "pd_op.logical_or",
+    "pd_op.logical_or_",
+    "pd_op.logical_xor",
+    "pd_op.logical_xor_",
+    # Bitwise ops
+    "pd_op.bitwise_and",
+    "pd_op.bitwise_and_",
+    "pd_op.bitwise_left_shift",
+    "pd_op.bitwise_left_shift_",
+    "pd_op.bitwise_not",
+    "pd_op.bitwise_not_",
+    "pd_op.bitwise_or",
+    "pd_op.bitwise_or_",
+    "pd_op.bitwise_right_shift",
+    "pd_op.bitwise_right_shift_",
+    "pd_op.bitwise_xor",
+    "pd_op.bitwise_xor_",
+    # Array ops
+    "pd_op.assign_array",
+    "pd_op.array_length",
+    "pd_op.slice_array",
+    "pd_op.slice_array_dense",
+    "pd_op.assign_array",
+    "pd_op.assign_array_",
+    "pd_op.create_array",
+    "pd_op.create_array_like",
+    "pd_op.array_read",
+    "pd_op.array_write_",
+    "pd_op.array_pop",
+    # Others
+    "pd_op.remainder",
+    "pd_op.argmax",
+    "pd_op.print",
+    "pd_op.accuracy",
+    "pd_op.uniform",
+    "pd_op.gaussian",
+    "pd_op.bernoulli",
+    "pd_op.full_like",
+    "pd_op.assign_value_",
+    "pd_op.nextafter",
+    "pd_op.isnan",
+    "pd_op.isinf",
+    "pd_op.all",
+    "pd_op.any",
+]
+
 
 class ValueWrapper:
     def __init__(self, value) -> None:
@@ -281,6 +348,11 @@ def is_control_flow(op):
     return op.name() == "pd_op.if" or op.name() == "pd_op.while"
 
 
+def is_builtin_op(op):
+    dialect_name, opname = op.name().split(".")
+    return dialect_name == "builtin"
+
+
 def update_no_grad_set_by_stopgradient(block, no_grad_set):
     for op in block.ops:
         if is_control_flow(op):
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 4614856ed86ae9..07e283b7617f79 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -18,6 +18,7 @@
 
 import paddle.pir
 from paddle.autograd.backward_utils import (
+    ALLOW_NO_GRAD_OPS,
     State,
     ValueDict,
     ValueSet,
@@ -32,6 +33,7 @@
     get_real_op_inputs,
     get_split_op,
     inverse_sort_op,
+    is_builtin_op,
     is_control_flow,
     is_inplace_net,
     parent_total_ops,
@@ -834,7 +836,13 @@ def append_yield(
                     else:
                         state.op_to_opgrad[op] = []
             else:
-                logging.warning("%s op has no grad op", op.name())
+                if (
+                    not is_builtin_op(op)
+                    and op.name() not in ALLOW_NO_GRAD_OPS
+                ):
+                    raise ValueError(
+                        f"op '{op.name()}' has no grad op, consider enabling prim to decompose it."
+                    )
                 state.op_to_opgrad[op] = []
 
     if fwd_block != bwd_block:
@@ -1202,9 +1210,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
         input_inputs_grad.append(
             (
                 input,
-                input_to_inputgrad_map[input][0][0]
-                if input_to_inputgrad_map[input] != []
-                else None,
+                (
+                    input_to_inputgrad_map[input][0][0]
+                    if input_to_inputgrad_map[input] != []
+                    else None
+                ),
             )
         )
 
diff --git a/test/dygraph_to_static/test_high_order_net.py b/test/dygraph_to_static/test_high_order_net.py
new file mode 100644
index 00000000000000..9d116528ea649d
--- /dev/null
+++ b/test/dygraph_to_static/test_high_order_net.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_pir_only,
+)
+
+import paddle
+
+
+class HighOrderNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear = paddle.nn.Linear(3, 4, bias_attr=False)
+
+    def forward(self, x):
+        y = self.linear(x)
+        z = paddle.pow(y, 2)
+        x_grad = paddle.grad(z, x, create_graph=True)[0]
+        x_grad_grad = paddle.grad(x_grad, x, create_graph=True)[0]
+        return x_grad_grad.mean()
+
+
+class TestBackwardHasNoGradError(Dy2StTestBase):
+    @test_ast_only
+    @test_pir_only
+    def test_backward_has_no_grad_error(self):
+        net = HighOrderNet()
+        static_net = paddle.jit.to_static(net, full_graph=True)
+
+        x = paddle.to_tensor([[1, 1, 1], [1, 1, 1]], 'float32')
+        x.stop_gradient = False
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "op 'pd_op.matmul_double_grad' has no grad op, consider enabling prim to decompose it.",
+        ):
+            x_grad_grad = static_net(x)
+            x_grad_grad.backward()
+
+
+if __name__ == "__main__":
+    unittest.main()
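A minimal sketch (not part of the patch, for review context only) of the gating condition the new check applies before raising in ir_backward.py. The helper name missing_grad_is_allowed is hypothetical; the patch inlines this condition, and ALLOW_NO_GRAD_OPS plus the builtin-dialect test come from backward_utils.py as added above.

# Hypothetical illustration only; the patch inlines this condition in ir_backward.py.
from paddle.autograd.backward_utils import ALLOW_NO_GRAD_OPS  # added by this patch


def missing_grad_is_allowed(op_name: str) -> bool:
    # Ops from the builtin dialect and allow-listed ops may silently lack a grad op;
    # any other op without a vjp now raises ValueError instead of only logging a warning.
    dialect = op_name.split(".")[0]
    return dialect == "builtin" or op_name in ALLOW_NO_GRAD_OPS


# missing_grad_is_allowed("pd_op.equal")              -> True  (allow-listed)
# missing_grad_is_allowed("builtin.combine")          -> True  (builtin dialect)
# missing_grad_is_allowed("pd_op.matmul_double_grad") -> False (raises during backward)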