Commit ed3486b

Support n-order differential testing (#62074)

* init
* fix some typos
* opt
* add full jacobian test mode
* remove dyn numerical jvp
* msg fix
* msg fix
* fix unused
* add TODO
* fix
* fix
* rm ano

1 parent 2ca34a7 commit ed3486b

1 file changed

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from logging import warning

import numpy as np

import paddle
from paddle import base
from paddle.autograd.backward_utils import ValueDict
from paddle.base import core
from paddle.base.backward import _as_list

__all__ = ['check_vjp']

EPS = 1e-4

default_gradient_tolerance = {
    np.float16: 1e-2,
    np.float32: 2e-3,
    np.float64: 1e-5,
    np.complex64: 1e-3,
    np.complex128: 1e-5,
}
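

# Illustrative sketch (editor's addition, not part of the original commit):
# how check_vjp below resolves tolerances from this table when the caller
# passes no explicit atol/rtol. `_example_resolve_tolerance` is a
# hypothetical helper for illustration only.
def _example_resolve_tolerance(np_dtype, atol=None, rtol=None):
    # Fall back to the per-dtype default when the caller passes None.
    atol = atol if atol else default_gradient_tolerance[np_dtype]
    rtol = rtol if rtol else default_gradient_tolerance[np_dtype]
    return atol, rtol  # e.g. (2e-3, 2e-3) for np.float32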


def _product(t):
    return int(np.prod(t))


def make_jacobian(x, y_size, np_dtype):
    if isinstance(x, (base.framework.Variable, paddle.pir.Value)):
        return np.zeros((_product(x.shape), y_size), dtype=np_dtype)
    elif isinstance(x, Sequence):
        jacobians = list(
            filter(
                lambda t: t is not None,
                (make_jacobian(item, y_size, np_dtype) for item in x),
            )
        )
        return jacobians
    else:
        # Unsupported leaf types yield None, which callers filter out.
        pass
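

# Illustrative sketch (editor's addition): make_jacobian allocates one zero
# matrix of shape (x_size, y_size) per variable, recursing into sequences.
# A hypothetical check, assuming a static (PIR) program context:
def _example_make_jacobian_shape():
    paddle.enable_static()
    x = paddle.static.data('example_x', [2, 3], dtype='float32')
    jac = make_jacobian(x, 4, np.float32)
    paddle.disable_static()
    return jac.shape  # (6, 4): 6 input elements by 4 output elements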


def compute_numerical_jacobian(program, inputs, outputs, feeds, eps):
    paddle.enable_static()
    numerical = []
    for input in inputs:
        numerical.append(
            _compute_numerical_jacobian(program, input, outputs, feeds, eps)
        )
    paddle.disable_static()
    return numerical


def _compute_numerical_jacobian(program, x, y, feeds, eps):
    if not isinstance(x, paddle.pir.Value):
        raise TypeError('x is not Value')

    # To compute the Jacobian, treat x and y as one-dimensional vectors.
    y = _as_list(y)
    exe = paddle.static.Executor()

    def run():
        res = exe.run(program, feeds, fetch_list=[y])
        y_res = res[: len(y)]
        return [yi.flatten() for yi in y_res]

    x_name = x.get_defining_op().attrs()['name']
    x_shape = x.shape
    x_size = _product(x_shape)
    np_type = dtype_to_np_dtype(x.dtype)
    np_t = np.array(feeds[x_name]).astype(np_type)
    np_t = np_t.flatten()
    jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]

    for i in range(x_size):
        orig = np_t[i]
        x_pos = orig + eps
        np_t[i] = x_pos
        np_f = np_t.reshape(x_shape)
        feeds[x_name] = np_f
        y_pos = run()

        x_neg = orig - eps
        np_t[i] = x_neg
        np_f = np_t.reshape(x_shape)
        feeds[x_name] = np_f
        y_neg = run()

        np_t[i] = orig
        for j in range(len(y)):
            ret = (y_pos[j] - y_neg[j]) / eps / 2.0
            jacobian[j][i, :] = ret

    return jacobian
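

# Illustrative sketch (editor's addition): the loop above is plain central
# differencing, J[i, j] ~ (y_j(x + eps*e_i) - y_j(x - eps*e_i)) / (2*eps).
# The same rule in pure NumPy for the elementwise square f(x) = x**2:
def _example_central_difference(eps=EPS):
    def f(t):
        return t**2

    x = np.array([1.0, 2.0], dtype=np.float64)
    jac = np.zeros((x.size, x.size))
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        # Row i holds the sensitivities of every output to input i,
        # matching the jacobian[j][i, :] layout used above.
        jac[i, :] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return jac  # approximately diag([2.0, 4.0])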


def compute_analytical_jacobian(
    program, inputs, outputs, last_grads_in, feeds, fetch_list
):
    paddle.enable_static()
    analytical = []
    for i in range(len(outputs)):
        name = last_grads_in[i].name
        feeds.update(
            {
                name: np.zeros(
                    outputs[i].shape, dtype=dtype_to_np_dtype(outputs[i].dtype)
                )
            }
        )
    for i in range(len(outputs)):
        analytical.append(
            _compute_analytical_jacobian(
                program,
                inputs,
                i,
                outputs,
                fetch_list,
                feeds,
                last_grads_in[i].name,
            )
        )
    paddle.disable_static()
    return analytical


def _compute_analytical_jacobian(program, x, i, y, grads, feeds, name):
    if not isinstance(x, (list, paddle.pir.Value)):
        raise TypeError('x is not Value or list of Value')
    np_type = dtype_to_np_dtype(y[i].dtype)
    exe = paddle.static.Executor()
    y_size = _product(y[i].shape)
    x = _as_list(x)
    jacobian = make_jacobian(x, y_size, np_type)

    # `name` is the feed key of dy_i; sweep a one-hot vector through it.
    np_t = np.array(feeds[name]).astype(np_type)
    shape = np_t.shape
    np_t = np_t.flatten()
    for i in range(y_size):
        np_t[i] = 1
        np_f = np_t.reshape(shape)
        feeds[name] = np_f
        res = exe.run(program, feed=feeds, fetch_list=[grads])
        dx_res = res[: len(grads)]
        for j in range(len(grads)):
            if dx_res[j] is not None:
                jacobian[j][:, i] = dx_res[j].flatten()
            else:
                jacobian[j][:, i] = np.zeros(
                    grads[j].shape, dtype=np_type
                ).flatten()

        np_t[i] = 0
        np_f = np_t.reshape(shape)
        feeds[name] = np_f

    return jacobian
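

# Illustrative sketch (editor's addition): feeding a one-hot cotangent, as
# the loop above does, extracts one Jacobian column, since vjp(e_k) = J^T e_k.
# The same trick in eager mode (hypothetical example function):
def _example_jacobian_column():
    x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
    y = x * x
    e0 = paddle.to_tensor([1.0, 0.0, 0.0])  # one-hot cotangent for y[0]
    (col,) = paddle.grad(y, x, grad_outputs=e0)
    return col  # [2., 0., 0.], the first column of dy/dx = diag(2x)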


def dtype_to_np_dtype(dtype):
    if dtype == core.VarDesc.VarType.FP32 or dtype == core.DataType.FLOAT32:
        return np.float32
    elif dtype == core.VarDesc.VarType.FP64 or dtype == core.DataType.FLOAT64:
        return np.float64
    elif dtype == core.VarDesc.VarType.FP16 or dtype == core.DataType.FLOAT16:
        return np.float16
    else:
        raise ValueError("Unsupported data type " + str(dtype))
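

# Illustrative note (editor's addition): PIR programs report core.DataType
# values while legacy programs report core.VarDesc.VarType values; the
# branches above accept either. A hypothetical usage:
def _example_dtype_mapping():
    return dtype_to_np_dtype(core.DataType.FLOAT32)  # -> np.float32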


def get_eager_vjp(func, inputs, cotangents=None, order=1):
    for x in inputs:
        x.stop_gradient = False
    outputs = func(inputs)
    return _get_eager_vjp(inputs, outputs, cotangents, order)


def _get_eager_vjp(inputs, outputs, tangents, order):
    # Higher orders must keep the graph alive so it can be differentiated again.
    create_graph = order > 1

    d_inputs = paddle.grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=tangents,
        create_graph=create_graph,
        allow_unused=True,
    )
    d_inputs = [d_input for d_input in d_inputs if d_input is not None]
    if order > 1:
        ddys = []
        for d_input in d_inputs:
            d_input.stop_gradient = False
            ddy = paddle.ones(shape=d_input.shape, dtype=d_input.dtype)
            ddy.stop_gradient = False
            ddys.append(ddy)
        return _get_eager_vjp(inputs, d_inputs, ddys, order - 1)

    return d_inputs
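

# Illustrative sketch (editor's addition): unrolled, an order-2 call to the
# recursion above is paddle.grad applied twice, with all-ones cotangents
# feeding the second pass (hypothetical standalone example):
def _example_second_order_eager():
    x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
    y = x**3
    v = paddle.ones_like(y)
    (dx,) = paddle.grad(y, x, grad_outputs=v, create_graph=True)
    dx.stop_gradient = False
    (ddx,) = paddle.grad(dx, x, grad_outputs=paddle.ones_like(dx))
    return ddx  # 6 * x = [6., 12.]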


def get_static_vjp(program, feeds, fetch):
    paddle.enable_static()
    exe = paddle.static.Executor()
    res = exe.run(program, feed=feeds, fetch_list=[fetch])
    paddle.disable_static()
    return res


def get_static_vjp_program(func, inputs, order):
    cotangents = []
    paddle.enable_static()
    input_vars = []
    feeds = {}
    for idx, input in enumerate(inputs):
        np_type = dtype_to_np_dtype(input.dtype)
        input_var = paddle.static.data(
            'input_' + str(idx), input.shape, dtype=np_type
        )
        input_vars.append(input_var)
        feeds.update({'input_' + str(idx): input.numpy()})
    outputs = func(input_vars)
    outputs = _as_list(outputs)
    # TODO(GGBond8488): Needs to be fixed when paddle uses PIR by default.
    program, (keys, values) = paddle.base.libpaddle.pir.clone_program(
        paddle.static.default_main_program()
    )
    op_map = ValueDict()
    for key, value in zip(keys, values):
        op_map[key] = value
    pir_inputs = []
    for input in input_vars:
        pir_inputs.append(op_map[input])
    pir_outputs = []
    grads_in_init = []
    with paddle.static.program_guard(program):
        # Make sure the grad_in var is created inside the cloned program.
        for idx, output in enumerate(outputs):
            pir_outputs.append(op_map[output])
            np_type = dtype_to_np_dtype(output.dtype)
            grad_in_var = paddle.static.data(
                'grad_in_' + str(idx), output.shape, dtype=np_type
            )
            grads_in_init.append(grad_in_var)
            grad_in_np = np.random.random(size=output.shape).astype(np_type)
            feeds.update({'grad_in_' + str(idx): grad_in_np})
            cotangents.append(grad_in_np)
        feeds, pre_outputs, d_inputs, last_grads_in = _get_static_vjp_program(
            pir_inputs, pir_outputs, feeds, grads_in_init, order
        )
    if not d_inputs:
        warning(f"{func.__name__} order-{order} grad will return None")
    paddle.disable_static()
    return program, pir_inputs, d_inputs, pre_outputs, feeds, cotangents
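

# Illustrative note (editor's addition): the feed dict built above follows a
# fixed naming scheme. Assuming func(xs) = [xs[0] * xs[0]] with one float32
# input and order=1, the returned feeds would contain 'input_0' (the input
# value) and 'grad_in_0' (a random cotangent); each extra differentiation
# level adds an all-ones 'dy_{idx}_{order}' entry, as built below.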


def _get_static_vjp_program(inputs, outputs, feeds, grads_in, order):
    def _require_grads(vars):
        for var in vars:
            var.stop_gradient = False
            var.persistable = True

    inputs = _as_list(inputs)
    outputs = _as_list(outputs)
    _require_grads(inputs)
    _require_grads(outputs)
    _require_grads(grads_in)
    d_inputs = paddle.base.gradients(outputs, inputs, grads_in)
    d_inputs = [d_input for d_input in d_inputs if d_input is not None]
    _require_grads(d_inputs)

    if order > 1:
        ddys = []
        for idx, d_input in enumerate(d_inputs):
            np_type = dtype_to_np_dtype(d_input.dtype)
            ddy = paddle.static.data(
                name=f'dy_{idx}_{order}',
                shape=d_input.shape,
                dtype=np_type,
            )
            ones = np.ones(d_input.shape, dtype=np_type)
            feeds.update({f'dy_{idx}_{order}': ones})
            ddys.append(ddy)
        _require_grads(ddys)
        return _get_static_vjp_program(inputs, d_inputs, feeds, ddys, order - 1)
    return feeds, outputs, d_inputs, grads_in
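

# Illustrative note (editor's addition): each recursion level wires a fresh
# all-ones cotangent into paddle.base.gradients, so an order-n request lowers
# to n chained first-order VJPs, roughly:
#   d1 = gradients(outputs, inputs, grads_in)   # random grads_in
#   d2 = gradients(d1, inputs, ones_like(d1))   # fed via 'dy_{idx}_{order}'
#   ...
# Only the outermost cotangent varies between runs.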


def check_vjp(func, args, order=2, atol=None, rtol=None, eps=EPS):
    args = _as_list(args)
    np_type = dtype_to_np_dtype(args[0].dtype)
    atol = atol if atol else default_gradient_tolerance[np_type]
    rtol = rtol if rtol else default_gradient_tolerance[np_type]

    (
        program,
        inputs,
        fetch_list,
        outputs,
        feeds,
        cotangents,
    ) = get_static_vjp_program(func, args, order)
    numeric_jacobian = compute_numerical_jacobian(
        program, inputs, outputs, feeds, eps
    )
    cotangents = list(map(paddle.to_tensor, cotangents))
    eager_vjps = get_eager_vjp(func, args, cotangents, order)
    static_vjps_np = get_static_vjp(program, feeds, fetch_list)
    eager_vjps_np = []
    for eager_vjp in eager_vjps:
        eager_vjps_np.append(eager_vjp.numpy())
    inputs_length = len(numeric_jacobian)
    numeric_vjps = []
    for x_idx in range(inputs_length):
        jacobians = _as_list(numeric_jacobian[x_idx])
        dx_idx = None
        v = np.ones(static_vjps_np[x_idx].shape).astype(np_type).flatten()
        for y_idx in range(len(jacobians)):
            if dx_idx is None:
                dx_idx = np.dot(v, jacobians[y_idx])
            else:
                dx_idx += np.dot(v, jacobians[y_idx])
        numeric_vjps.append(dx_idx)
    eager_vjps_np = list(map(np.ndarray.flatten, eager_vjps_np))
    static_vjps_np = list(map(np.ndarray.flatten, static_vjps_np))

    np.testing.assert_allclose(
        numeric_vjps,
        eager_vjps_np,
        atol=atol,
        rtol=rtol,
        err_msg="eager vjps are not close to numeric vjps",
    )
    np.testing.assert_allclose(
        numeric_vjps,
        static_vjps_np,
        atol=atol,
        rtol=rtol,
        err_msg="static vjps are not close to numeric vjps",
    )
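

# Illustrative usage sketch (editor's addition): checking first- and
# second-order VJPs of an elementwise function against numerical central
# differences. `my_func` is a hypothetical test target, not part of this
# commit.
def _example_check_vjp():
    def my_func(xs):
        return [paddle.tanh(xs[0])]

    x = paddle.to_tensor(np.random.rand(3, 4).astype(np.float64))
    # Raises if eager, static, and numeric VJPs disagree beyond tolerance.
    check_vjp(my_func, [x], order=2)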
