 from paddle import base
 from paddle.base import core
 from paddle.base.backward import _as_list
+from paddle.pir_utils import test_with_pir_api


 @skip_check_grad_ci(
@@ -68,33 +69,38 @@ def test_check_grad(self):
         for p in places:
             self.func(p)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         # use small size since Jacobian gradients is time consuming
         root_data = self.root_data[..., :3, :3]
-        prog = base.Program()
-        with base.program_guard(prog):
-            root = paddle.create_parameter(
-                dtype=root_data.dtype, shape=root_data.shape
-            )
+        prog = paddle.static.Program()
+        with paddle.static.program_guard(prog):
+            if paddle.framework.in_pir_mode():
+                root = paddle.static.data(
+                    dtype=root_data.dtype, shape=root_data.shape, name="root"
+                )
+            else:
+                root = paddle.create_parameter(
+                    dtype=root_data.dtype, shape=root_data.shape
+                )
+            root.stop_gradient = False
+            root.persistable = True
             root_t = paddle.transpose(root, self.trans_dims)
             x = paddle.matmul(x=root, y=root_t) + 1e-05
             out = paddle.cholesky(x, upper=self.attrs["upper"])
             # check input arguments
             root = _as_list(root)
             out = _as_list(out)

-            for v in root:
-                v.stop_gradient = False
-                v.persistable = True
             for u in out:
                 u.stop_gradient = False
                 u.persistable = True

             # init variable in startup program
             scope = base.executor.global_scope()
             exe = base.Executor(place)
-            exe.run(base.default_startup_program())
+            exe.run(paddle.static.default_startup_program())

             x_init = _as_list(root_data)
             # init inputs if x_init is not None
@@ -106,10 +112,33 @@ def func(self, place):
                     )
                 # init variable in main program
                 for var, arr in zip(root, x_init):
-                    assert var.shape == arr.shape
+                    assert tuple(var.shape) == tuple(arr.shape)
                 feeds = {k.name: v for k, v in zip(root, x_init)}
                 exe.run(prog, feed=feeds, scope=scope)
-            grad_check(x=root, y=out, x_init=x_init, place=place, program=prog)
+            fetch_list = None
+            if paddle.framework.in_pir_mode():
+                dys = []
+                for i in range(len(out)):
+                    yi = out[i]
+                    dy = paddle.static.data(
+                        name='dys_%s' % i,
+                        shape=yi.shape,
+                        dtype=root_data.dtype,
+                    )
+                    dy.stop_gradient = False
+                    dy.persistable = True
+                    value = np.zeros(yi.shape, dtype=root_data.dtype)
+                    feeds.update({'dys_%s' % i: value})
+                    dys.append(dy)
+                fetch_list = base.gradients(out, root, dys)
+            grad_check(
+                x=root,
+                y=out,
+                fetch_list=fetch_list,
+                feeds=feeds,
+                place=place,
+                program=prog,
+            )

     def init_config(self):
         self._upper = True
@@ -144,8 +173,11 @@ def setUp(self):
         if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
             self.places.append(base.CUDAPlace(0))

+    @test_with_pir_api
     def check_static_result(self, place, with_out=False):
-        with base.program_guard(base.Program(), base.Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             input = paddle.static.data(
                 name="input", shape=[4, 4], dtype="float64"
             )
@@ -156,7 +188,6 @@ def check_static_result(self, place, with_out=False):
             exe = base.Executor(place)
             try:
                 fetches = exe.run(
-                    base.default_main_program(),
                     feed={"input": input_np},
                     fetch_list=[result],
                 )
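
Note on the new PIR branch above: for each output the test feeds an explicit zero-valued gradient placeholder (`dys_*`), precomputes the gradients with `base.gradients(out, root, dys)`, and hands the result to `grad_check` via `fetch_list`. The snippet below is a minimal, self-contained sketch of that explicit target-gradient pattern using the public `paddle.static.gradients` API; it is not part of the patch, and the names (`root`, `dy`) and the 3x3 shape are illustrative assumptions only.

```python
# Minimal sketch (assumption: not part of the patch) of feeding an explicit
# gradient output and fetching the resulting input gradient in static mode.
import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    root = paddle.static.data(name="root", shape=[3, 3], dtype="float64")
    root.stop_gradient = False
    # Build a symmetric positive-definite input, as the test does.
    x = paddle.matmul(root, paddle.transpose(root, [1, 0])) + 1e-05
    out = paddle.cholesky(x)
    # Explicit gradient w.r.t. `out`, fed from numpy (zeros, as in the dys_* feed).
    dy = paddle.static.data(name="dy", shape=[3, 3], dtype="float64")
    grads = paddle.static.gradients([out], [root], target_gradients=[dy])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
res = exe.run(
    main,
    feed={
        "root": np.eye(3, dtype="float64"),
        "dy": np.zeros((3, 3), dtype="float64"),
    },
    fetch_list=grads,
)
print(res[0].shape)  # (3, 3): gradient of `out` w.r.t. `root`
```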