Skip to content

Commit aef9d6d

Browse files
cholesky and cholesky_solve tests (#60726)
1 parent 0ac9c29 commit aef9d6d

File tree

3 files changed

+58
-20
lines changed

3 files changed

+58
-20
lines changed

test/legacy_test/gradient_checker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def _compute_analytical_jacobian_pir(
353353
def grad_check(
354354
x,
355355
y,
356-
x_init=None,
356+
fetch_list=None,
357357
feeds=None,
358358
place=None,
359359
program=None,
@@ -403,12 +403,12 @@ def fail_test(msg):
403403
for i in range(len(y)):
404404
analytical.append(
405405
_compute_analytical_jacobian_pir(
406-
program, x, i, y, x_init, feeds, place
406+
program, x, i, y, fetch_list, feeds, place
407407
)
408408
)
409409
numerical = [
410410
_compute_numerical_jacobian_pir(
411-
program, xi, y, x_init, feeds, place, eps
411+
program, xi, y, fetch_list, feeds, place, eps
412412
)
413413
for xi in x
414414
]

test/legacy_test/test_cholesky_op.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from paddle import base
2424
from paddle.base import core
2525
from paddle.base.backward import _as_list
26+
from paddle.pir_utils import test_with_pir_api
2627

2728

2829
@skip_check_grad_ci(
@@ -68,33 +69,38 @@ def test_check_grad(self):
6869
for p in places:
6970
self.func(p)
7071

72+
@test_with_pir_api
7173
@prog_scope()
7274
def func(self, place):
7375
# use small size since Jacobian gradients is time consuming
7476
root_data = self.root_data[..., :3, :3]
75-
prog = base.Program()
76-
with base.program_guard(prog):
77-
root = paddle.create_parameter(
78-
dtype=root_data.dtype, shape=root_data.shape
79-
)
77+
prog = paddle.static.Program()
78+
with paddle.static.program_guard(prog):
79+
if paddle.framework.in_pir_mode():
80+
root = paddle.static.data(
81+
dtype=root_data.dtype, shape=root_data.shape, name="root"
82+
)
83+
else:
84+
root = paddle.create_parameter(
85+
dtype=root_data.dtype, shape=root_data.shape
86+
)
87+
root.stop_gradient = False
88+
root.persistable = True
8089
root_t = paddle.transpose(root, self.trans_dims)
8190
x = paddle.matmul(x=root, y=root_t) + 1e-05
8291
out = paddle.cholesky(x, upper=self.attrs["upper"])
8392
# check input arguments
8493
root = _as_list(root)
8594
out = _as_list(out)
8695

87-
for v in root:
88-
v.stop_gradient = False
89-
v.persistable = True
9096
for u in out:
9197
u.stop_gradient = False
9298
u.persistable = True
9399

94100
# init variable in startup program
95101
scope = base.executor.global_scope()
96102
exe = base.Executor(place)
97-
exe.run(base.default_startup_program())
103+
exe.run(paddle.static.default_startup_program())
98104

99105
x_init = _as_list(root_data)
100106
# init inputs if x_init is not None
@@ -106,10 +112,33 @@ def func(self, place):
106112
)
107113
# init variable in main program
108114
for var, arr in zip(root, x_init):
109-
assert var.shape == arr.shape
115+
assert tuple(var.shape) == tuple(arr.shape)
110116
feeds = {k.name: v for k, v in zip(root, x_init)}
111117
exe.run(prog, feed=feeds, scope=scope)
112-
grad_check(x=root, y=out, x_init=x_init, place=place, program=prog)
118+
fetch_list = None
119+
if paddle.framework.in_pir_mode():
120+
dys = []
121+
for i in range(len(out)):
122+
yi = out[i]
123+
dy = paddle.static.data(
124+
name='dys_%s' % i,
125+
shape=yi.shape,
126+
dtype=root_data.dtype,
127+
)
128+
dy.stop_gradient = False
129+
dy.persistable = True
130+
value = np.zeros(yi.shape, dtype=root_data.dtype)
131+
feeds.update({'dys_%s' % i: value})
132+
dys.append(dy)
133+
fetch_list = base.gradients(out, root, dys)
134+
grad_check(
135+
x=root,
136+
y=out,
137+
fetch_list=fetch_list,
138+
feeds=feeds,
139+
place=place,
140+
program=prog,
141+
)
113142

114143
def init_config(self):
115144
self._upper = True
@@ -144,8 +173,11 @@ def setUp(self):
144173
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
145174
self.places.append(base.CUDAPlace(0))
146175

176+
@test_with_pir_api
147177
def check_static_result(self, place, with_out=False):
148-
with base.program_guard(base.Program(), base.Program()):
178+
with paddle.static.program_guard(
179+
paddle.static.Program(), paddle.static.Program()
180+
):
149181
input = paddle.static.data(
150182
name="input", shape=[4, 4], dtype="float64"
151183
)
@@ -156,7 +188,6 @@ def check_static_result(self, place, with_out=False):
156188
exe = base.Executor(place)
157189
try:
158190
fetches = exe.run(
159-
base.default_main_program(),
160191
feed={"input": input_np},
161192
fetch_list=[result],
162193
)

test/legacy_test/test_cholesky_solve_op.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import paddle
2626
from paddle import base
2727
from paddle.base import Program, core, program_guard
28+
from paddle.pir_utils import test_with_pir_api
2829

2930
paddle.enable_static()
3031

@@ -143,7 +144,7 @@ def test_check_output(self):
143144

144145
# check Op grad
145146
def test_check_grad_normal(self):
146-
self.check_grad(['Y'], 'Out', max_relative_error=0.01)
147+
self.check_grad(['Y'], 'Out', max_relative_error=0.01, check_pir=True)
147148

148149

149150
# test condition: 3D(broadcast) + 3D, upper=True
@@ -169,9 +170,12 @@ def setUp(self):
169170
if core.is_compiled_with_cuda():
170171
self.place.append(paddle.CUDAPlace(0))
171172

173+
@test_with_pir_api
172174
def check_static_result(self, place):
173175
paddle.enable_static()
174-
with base.program_guard(base.Program(), base.Program()):
176+
with paddle.static.program_guard(
177+
paddle.static.Program(), paddle.static.Program()
178+
):
175179
x = paddle.static.data(name="x", shape=[10, 2], dtype=self.dtype)
176180
y = paddle.static.data(name="y", shape=[10, 10], dtype=self.dtype)
177181
z = paddle.linalg.cholesky_solve(x, y, upper=self.upper)
@@ -187,7 +191,6 @@ def check_static_result(self, place):
187191

188192
exe = base.Executor(place)
189193
fetches = exe.run(
190-
base.default_main_program(),
191194
feed={"x": x_np, "y": umat},
192195
fetch_list=[z],
193196
)
@@ -239,7 +242,7 @@ def run(place):
239242

240243
# test condition out of bounds
241244
class TestCholeskySolveOpError(unittest.TestCase):
242-
def test_errors(self):
245+
def test_errors_1(self):
243246
paddle.enable_static()
244247
with program_guard(Program(), Program()):
245248
# The input type of solve_op must be Variable.
@@ -251,6 +254,10 @@ def test_errors(self):
251254
)
252255
self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x1, y1)
253256

257+
@test_with_pir_api
258+
def test_errors_2(self):
259+
paddle.enable_static()
260+
with program_guard(Program(), Program()):
254261
# The data type of input must be float32 or float64.
255262
x2 = paddle.static.data(name="x2", shape=[30, 30], dtype="bool")
256263
y2 = paddle.static.data(name="y2", shape=[30, 10], dtype="bool")

0 commit comments

Comments (0)