2323from dygraph_to_static_utils import (
2424 Dy2StTestBase ,
2525 enable_to_static_guard ,
26+ test_legacy_and_pt_and_pir ,
2627 test_sot_only ,
2728)
2829from predictor_utils import PredictorTools
2930
3031import paddle
3132from paddle import base
3233from paddle .base import core
34+ from paddle .base .framework import unique_name
35+ from paddle .framework import use_pir_api
3336from paddle .jit .translated_layer import INFER_MODEL_SUFFIX , INFER_PARAMS_SUFFIX
3437
35- place = base .CUDAPlace (0 ) if base .is_compiled_with_cuda () else base .CPUPlace ()
38+ place = (
39+ paddle .CUDAPlace (0 ) if paddle .is_compiled_with_cuda () else paddle .CPUPlace ()
40+ )
3641SEED = 2020
3742STEP_NUM = 10
3843PRINT_STEP = 2
@@ -95,7 +100,7 @@ def tearDown(self):
95100 self .temp_dir .cleanup ()
96101
97102 def train (self , bert_config , data_reader , to_static ):
98- with base . dygraph . guard (place ):
103+ with unique_name . guard ():
99104 base .default_main_program ().random_seed = SEED
100105 base .default_startup_program ().random_seed = SEED
101106
@@ -158,7 +163,9 @@ def train(self, bert_config, data_reader, to_static):
158163 step_idx += 1
159164 if step_idx == STEP_NUM :
160165 if to_static :
161- paddle .jit .save (bert , self .model_save_prefix )
166+ # TODO(pir-save-load): Fix this after we support save/load in PIR
167+ if not use_pir_api ():
168+ paddle .jit .save (bert , self .model_save_prefix )
162169 else :
163170 paddle .save (
164171 bert .state_dict (),
@@ -172,8 +179,7 @@ def train_dygraph(self, bert_config, data_reader):
172179 return self .train (bert_config , data_reader , False )
173180
174181 def train_static (self , bert_config , data_reader ):
175- with enable_to_static_guard (True ):
176- return self .train (bert_config , data_reader , True )
182+ return self .train (bert_config , data_reader , True )
177183
178184 def predict_static (self , data ):
179185 paddle .enable_static ()
@@ -195,11 +201,12 @@ def predict_static(self, data):
195201 fetch_list = fetch_targets ,
196202 )
197203
204+ paddle .disable_static ()
198205 return pred_res
199206
200207 def predict_dygraph (self , bert_config , data ):
201208 with enable_to_static_guard (False ):
202- with base . dygraph . guard (place ):
209+ with unique_name . guard ():
203210 bert = PretrainModelLayer (
204211 config = bert_config , weight_sharing = False , use_fp16 = False
205212 )
@@ -210,7 +217,7 @@ def predict_dygraph(self, bert_config, data):
210217 bert .set_dict (model_dict )
211218 bert .eval ()
212219
213- input_vars = [base . dygraph . to_variable (x ) for x in data ]
220+ input_vars = [paddle . to_tensor (x ) for x in data ]
214221 (
215222 src_ids ,
216223 pos_ids ,
@@ -234,31 +241,30 @@ def predict_dygraph(self, bert_config, data):
234241 return pred_res
235242
236243 def predict_dygraph_jit (self , data ):
237- with base .dygraph .guard (place ):
238- bert = paddle .jit .load (self .model_save_prefix )
239- bert .eval ()
240-
241- (
242- src_ids ,
243- pos_ids ,
244- sent_ids ,
245- input_mask ,
246- mask_label ,
247- mask_pos ,
248- labels ,
249- ) = data
250- pred_res = bert (
251- src_ids ,
252- pos_ids ,
253- sent_ids ,
254- input_mask ,
255- mask_label ,
256- mask_pos ,
257- labels ,
258- )
259- pred_res = [var .numpy () for var in pred_res ]
244+ bert = paddle .jit .load (self .model_save_prefix )
245+ bert .eval ()
246+
247+ (
248+ src_ids ,
249+ pos_ids ,
250+ sent_ids ,
251+ input_mask ,
252+ mask_label ,
253+ mask_pos ,
254+ labels ,
255+ ) = data
256+ pred_res = bert (
257+ src_ids ,
258+ pos_ids ,
259+ sent_ids ,
260+ input_mask ,
261+ mask_label ,
262+ mask_pos ,
263+ labels ,
264+ )
265+ pred_res = [var .numpy () for var in pred_res ]
260266
261- return pred_res
267+ return pred_res
262268
263269 def predict_analysis_inference (self , data ):
264270 output = PredictorTools (
@@ -267,6 +273,7 @@ def predict_analysis_inference(self, data):
267273 out = output ()
268274 return out
269275
276+ @test_legacy_and_pt_and_pir
270277 def test_train (self ):
271278 static_loss , static_ppl = self .train_static (
272279 self .bert_config , self .data_reader
@@ -280,6 +287,7 @@ def test_train(self):
280287 self .verify_predict ()
281288
282289 @test_sot_only
290+ @test_legacy_and_pt_and_pir
283291 def test_train_composite (self ):
284292 core ._set_prim_backward_enabled (True )
285293 # core._add_skip_comp_ops("layer_norm")
@@ -297,43 +305,45 @@ def test_train_composite(self):
297305 def verify_predict (self ):
298306 for data in self .data_reader .data_generator ()():
299307 dygraph_pred_res = self .predict_dygraph (self .bert_config , data )
300- static_pred_res = self .predict_static (data )
301- dygraph_jit_pred_res = self .predict_dygraph_jit (data )
302- predictor_pred_res = self .predict_analysis_inference (data )
303-
304- for dy_res , st_res , dy_jit_res , predictor_res in zip (
305- dygraph_pred_res ,
306- static_pred_res ,
307- dygraph_jit_pred_res ,
308- predictor_pred_res ,
309- ):
310- np .testing .assert_allclose (
311- st_res ,
312- dy_res ,
313- rtol = 1e-05 ,
314- err_msg = 'dygraph_res: {},\n static_res: {}' .format (
315- dy_res [~ np .isclose (st_res , dy_res )],
316- st_res [~ np .isclose (st_res , dy_res )],
317- ),
318- )
319- np .testing .assert_allclose (
320- st_res ,
321- dy_jit_res ,
322- rtol = 1e-05 ,
323- err_msg = 'dygraph_jit_res: {},\n static_res: {}' .format (
324- dy_jit_res [~ np .isclose (st_res , dy_jit_res )],
325- st_res [~ np .isclose (st_res , dy_jit_res )],
326- ),
327- )
328- np .testing .assert_allclose (
329- st_res ,
330- predictor_res ,
331- rtol = 1e-05 ,
332- err_msg = 'dygraph_jit_res: {},\n static_res: {}' .format (
333- predictor_res [~ np .isclose (st_res , predictor_res )],
334- st_res [~ np .isclose (st_res , predictor_res )],
335- ),
336- )
308+ # TODO(pir-save-load): Fix this after we support save/load in PIR
309+ if not use_pir_api ():
310+ static_pred_res = self .predict_static (data )
311+ dygraph_jit_pred_res = self .predict_dygraph_jit (data )
312+ predictor_pred_res = self .predict_analysis_inference (data )
313+
314+ for dy_res , st_res , dy_jit_res , predictor_res in zip (
315+ dygraph_pred_res ,
316+ static_pred_res ,
317+ dygraph_jit_pred_res ,
318+ predictor_pred_res ,
319+ ):
320+ np .testing .assert_allclose (
321+ st_res ,
322+ dy_res ,
323+ rtol = 1e-05 ,
324+ err_msg = 'dygraph_res: {},\n static_res: {}' .format (
325+ dy_res [~ np .isclose (st_res , dy_res )],
326+ st_res [~ np .isclose (st_res , dy_res )],
327+ ),
328+ )
329+ np .testing .assert_allclose (
330+ st_res ,
331+ dy_jit_res ,
332+ rtol = 1e-05 ,
333+ err_msg = 'dygraph_jit_res: {},\n static_res: {}' .format (
334+ dy_jit_res [~ np .isclose (st_res , dy_jit_res )],
335+ st_res [~ np .isclose (st_res , dy_jit_res )],
336+ ),
337+ )
338+ np .testing .assert_allclose (
339+ st_res ,
340+ predictor_res ,
341+ rtol = 1e-05 ,
342+ err_msg = 'dygraph_jit_res: {},\n static_res: {}' .format (
343+ predictor_res [~ np .isclose (st_res , predictor_res )],
344+ st_res [~ np .isclose (st_res , predictor_res )],
345+ ),
346+ )
337347 break
338348
339349
0 commit comments