diff --git a/python/paddle/optimizer/lbfgs.py b/python/paddle/optimizer/lbfgs.py index 215473ff3a7406..936b71b232d4d9 100644 --- a/python/paddle/optimizer/lbfgs.py +++ b/python/paddle/optimizer/lbfgs.py @@ -155,12 +155,7 @@ def _strong_wolfe( gtd_new = paddle.dot(grad_new, d) # bracket an interval containing a point satisfying the Wolfe criteria - t_prev, f_prev, g_prev, gtd_prev = ( - paddle.to_tensor(0, dtype=grad.dtype), - loss, - grad, - gtd, - ) + t_prev, f_prev, g_prev, gtd_prev = (0, loss, grad, gtd) done = False ls_iter = 0 while ls_iter < max_ls: @@ -227,7 +222,10 @@ def _strong_wolfe( low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) while not done and ls_iter < max_ls: # line-search bracket is so small - if paddle.abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: + bracket_ls = bracket[1] - bracket[0] + if not isinstance(bracket_ls, paddle.Tensor): + bracket_ls = paddle.to_tensor(bracket_ls, dtype=gtd_new.dtype) + if paddle.abs(bracket_ls) * d_norm < tolerance_change: break # compute new trial value diff --git a/test/legacy_test/test_lbfgs_class.py b/test/legacy_test/test_lbfgs_class.py index 47c0d36b9ecddc..631d21962e398b 100644 --- a/test/legacy_test/test_lbfgs_class.py +++ b/test/legacy_test/test_lbfgs_class.py @@ -498,6 +498,16 @@ def func3(x, alpha, d): paddle.to_tensor([1.0]), max_ls=1, ) + lbfgs._strong_wolfe( + func2, + paddle.to_tensor([1.0]), + -0.001, + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + max_ls=1, + ) lbfgs._strong_wolfe( func3,