Merged
32 commits
19d6b05
init functional jacobian api
Sep 22, 2021
f47b48f
merge upstream/develop
Sep 22, 2021
be2b30d
finish test with dtype float32
Sep 23, 2021
36b8c34
add float64 test case
Sep 23, 2021
35b1ce8
polish code
Sep 24, 2021
3a35a00
use atol=1e-5 with dtype float64
Sep 24, 2021
a3ea12e
fix for ci
Sep 24, 2021
8738cf8
set timeout for test_jacobian
Sep 24, 2021
c72565d
init hessian API
Sep 24, 2021
c2d12cc
save status
Sep 26, 2021
0bd8287
polish API docstring
Sep 26, 2021
a3e8585
Merge remote-tracking branch 'upstream/develop' into lml/jacobian
Sep 26, 2021
4d94e5a
modify docstring
Sep 26, 2021
546bcd1
merge jacobian
Sep 26, 2021
ae0f883
add utils.py
Sep 26, 2021
03d3feb
save status
Sep 26, 2021
9378769
fix dygraph double grad dtype error when calling for high differentia…
JiabinYang Sep 26, 2021
034011d
reinvoke ci
JiabinYang Sep 26, 2021
4aa5813
test_hessian.py is ok
Sep 27, 2021
94e1ed2
polish hessian API
Sep 27, 2021
cbd4d3b
init vhp
Sep 27, 2021
0758ad1
Revert "init vhp"
Sep 27, 2021
871d114
merge develop
Sep 27, 2021
d140eed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Sep 27, 2021
f478f18
add test for partial_engine.cc
JiabinYang Sep 27, 2021
3ac1303
modify numerical_delta with dtype float32
Sep 27, 2021
50377ea
Merge remote-tracking branch 'jiabin/fix_dygraph_high_differential' i…
Sep 27, 2021
fd82d43
merge fix for dtype float64
Sep 27, 2021
73382ec
spell fix
Sep 27, 2021
6633eca
polish code
Sep 28, 2021
b7ea5e1
Merge remote-tracking branch 'upstream/develop' into lml/hessian
Sep 28, 2021
1a6e2b3
rm _stop_gradient_pre_process
Sep 28, 2021
2 changes: 1 addition & 1 deletion python/paddle/autograd/__init__.py
@@ -18,6 +18,6 @@
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian # noqa: F401
from .functional import jacobian, hessian # noqa: F401

__all__ = ['backward', 'PyLayer', 'PyLayerContext']
152 changes: 127 additions & 25 deletions python/paddle/autograd/functional.py
@@ -13,34 +13,10 @@
# limitations under the License.

from paddle.fluid import framework
from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
import paddle


def _check_tensors(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)

if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} connot be empyt".format(name)
for each_var in in_out_list:
assert isinstance(
each_var,
paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
name)
return in_out_list
else:
assert isinstance(
in_out_list,
paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]


def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None


@framework.dygraph_only
def jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
@@ -183,3 +159,129 @@ def func(x, y):
return jacobian[0]
else:
return jacobian


@framework.dygraph_only
def hessian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in imperative mode.**

This API computes the Hessian matrix of ``func`` with respect to ``inputs``.

Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs and returns a Tensor with a single element.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher order derivatives
can be computed; when it is False, the gradient graphs of the
computing process are discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise an error or return None
if some Tensors of ``inputs`` are unreachable in the graph. An error
is raised if allow_unused=False, and None is returned as their
gradients if allow_unused=True. Defaults to ``False``.
Returns:
Hessian (Tensor or a tuple of tuples of Tensors): if function ``func``
takes a Tensor as ``inputs``, Hessian will be a single Tensor containing
the Hessian matrix for the linearized ``inputs`` Tensor. If function
``func`` takes a Tensor list/tuple as ``inputs``, then Hessian will be a
tuple of tuples of Tensors where ``Hessian[i][j]`` contains the Hessian
matrix of the ``i``th input and the ``j``th input with size ``m * n``.
Here ``m`` and ``n`` denote the number of elements of the ``i``th input
and the ``j``th input respectively.

Examples 1:
.. code-block:: python

import paddle

def func(x):
return paddle.sum(paddle.matmul(x, x))

x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
hessian = paddle.autograd.hessian(func, x)
print(hessian)
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 1., 1., 0.],
# [1., 0., 2., 1.],
# [1., 2., 0., 1.],
# [0., 1., 1., 2.]])

Examples 2:
.. code-block:: python

import paddle

def func(x, y):
return paddle.sum(paddle.matmul(x, y))

x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [x, y])
print(hessian)
# ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 1., 0., 0.],
# [0., 0., 1., 1.],
# [1., 1., 0., 0.],
# [0., 0., 1., 1.]])),
# (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0.],
# [1., 0., 1., 0.],
# [0., 1., 0., 1.],
# [0., 1., 0., 1.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])))

Examples 3:
.. code-block:: python

import paddle

def func(x, y):
return paddle.sum(paddle.matmul(x, x))

x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [x, y], allow_unused=True)
print(hessian)
# ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 1., 1., 0.],
# [1., 0., 2., 1.],
# [1., 2., 0., 1.],
# [0., 1., 1., 2.]]), None), (None, None))

'''
inputs = _check_tensors(inputs, "inputs")
outputs = func(*inputs)
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
1
], "The function to compute Hessian matrix should return a Tensor with a single element"

def jac_func(*ins):
grad_inputs = paddle.grad(
outputs,
ins,
create_graph=True,
retain_graph=True,
allow_unused=allow_unused)
return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs)))

return jacobian(
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
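
Implementation note: as the code above shows, hessian is the Jacobian of the first-order gradient, with jac_func wrapping paddle.grad and being handed back to jacobian. The following minimal sketch is not part of this PR; it assumes double-grad support for matmul (which the tests below rely on) and only illustrates that equivalence for a single input:

    import paddle

    def func(x):
        return paddle.sum(paddle.matmul(x, x))

    x = paddle.ones(shape=[2, 2], dtype='float32')
    x.stop_gradient = False

    # First-order gradient of the scalar output, kept differentiable so it
    # can be differentiated a second time.
    def grad_func(x):
        return paddle.grad(func(x), x, create_graph=True)[0]

    # Differentiating the gradient again gives the Hessian; this mirrors
    # what hessian() does by passing jac_func to jacobian().
    h_via_jacobian = paddle.autograd.jacobian(grad_func, x)
    h_direct = paddle.autograd.hessian(func, x)
    # Both are expected to be 4 x 4 Tensors for the flattened 2 x 2 input.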
49 changes: 49 additions & 0 deletions python/paddle/autograd/utils.py
@@ -0,0 +1,49 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def _check_tensors(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)

if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} cannot be empty".format(name)
for each_var in in_out_list:
assert isinstance(
each_var,
paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
name)
return list(in_out_list)
else:
assert isinstance(
in_out_list,
paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]


def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Cannot stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None


def _replace_none_with_zero_tensor(t, spec_t):
if t is None:
zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype)
zero_t.stop_gradient = spec_t.stop_gradient
return zero_t
else:
return t
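
A brief usage sketch of these private helpers (illustrative only; the import path assumes the module is importable as paddle.autograd.utils, and these names are internal):

    import paddle
    from paddle.autograd.utils import _check_tensors, _replace_none_with_zero_tensor

    x = paddle.ones(shape=[2, 2], dtype='float32')

    # A bare Tensor is wrapped in a single-element list; a list/tuple is
    # validated and returned as a list, so callers can always iterate.
    ins = _check_tensors(x, "inputs")        # [x]
    ins = _check_tensors((x, x), "inputs")   # [x, x]

    # When paddle.grad runs with allow_unused=True, unreachable inputs get
    # None gradients; hessian's jac_func swaps those for zero Tensors so
    # that jacobian can keep differentiating a well-defined Tensor.
    g = _replace_none_with_zero_tensor(None, x)
    print(g.shape, g.dtype)  # same shape and dtype as x, filled with zeros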
python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
@@ -7,3 +7,4 @@
endforeach(TEST_OP)

set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 20)
140 changes: 140 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_hessian.py
@@ -0,0 +1,140 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from utils import _compute_numerical_hessian


class TestHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))

numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)

self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)

def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))

numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)

self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
for i in range(len(hessian)):
for j in range(len(hessian[0])):
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)

def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))

try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0

def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))

numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
else:
assert hessian[i][j] is None

def test_create_graph_false(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))

numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0

# TODO(levi): enable this test case when matmul_grad_grad_grad is ok
def _test_create_graph_true(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))

numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None


class TestHessianFloat64(TestHessian):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)


if __name__ == "__main__":
unittest.main()
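
The helper _compute_numerical_hessian is imported from a local utils module that is not part of this diff. A hypothetical central-difference version, shown only to make the tests' reference values concrete (the actual helper in the PR may differ), could look like:

    import numpy as np
    import paddle


    def _numerical_hessian_sketch(func, xs, delta, np_dtype):
        # Hypothetical stand-in for _compute_numerical_hessian: estimates
        # each Hessian block with central differences, returning
        # blocks[i][j] of shape (numel(xs[i]), numel(xs[j])).
        xs = list(xs) if isinstance(xs, (list, tuple)) else [xs]
        flats = [x.numpy().astype(np_dtype) for x in xs]

        def f_at(bumps):
            # bumps: list of (input_index, flat_index, epsilon) perturbations.
            arrays = [a.copy() for a in flats]
            for i, p, eps in bumps:
                arrays[i].ravel()[p] += eps
            out = func(*[paddle.to_tensor(a) for a in arrays])
            return float(out.numpy())

        blocks = [[np.zeros((ai.size, aj.size), dtype=np_dtype) for aj in flats]
                  for ai in flats]
        for i, ai in enumerate(flats):
            for j, aj in enumerate(flats):
                for p in range(ai.size):
                    for q in range(aj.size):
                        fpp = f_at([(i, p, delta), (j, q, delta)])
                        fpm = f_at([(i, p, delta), (j, q, -delta)])
                        fmp = f_at([(i, p, -delta), (j, q, delta)])
                        fmm = f_at([(i, p, -delta), (j, q, -delta)])
                        blocks[i][j][p, q] = (fpp - fpm - fmp + fmm) / (
                            4.0 * delta * delta)
        return blocks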