2222
2323os .environ ["CUDA_VISIBLE_DEVICES" ] = "2"
2424
25- from dygraph_to_static_utils import Dy2StTestBase
25+ from dygraph_to_static_utils import Dy2StTestBase , enable_to_static_guard
2626
2727import paddle
2828from paddle import _legacy_C_ops , base
@@ -531,7 +531,6 @@ def setUp(self):
531531 self .dy_param_path = os .path .join (self .temp_dir .name , 'lac_dy_param' )
532532
533533 def train (self , args , to_static ):
534- paddle .jit .enable_to_static (to_static )
535534 place = (
536535 base .CUDAPlace (0 )
537536 if base .is_compiled_with_cuda ()
@@ -580,7 +579,9 @@ def train(self, args, to_static):
580579 num_label_chunks ,
581580 num_correct_chunks ,
582581 ) = chunk_eval (
583- input = crf_decode , label = targets , seq_length = length
582+ input = crf_decode ,
583+ label = targets ,
584+ seq_length = length ,
584585 )
585586 outputs = [avg_cost , precision , recall , f1_score ]
586587 avg_cost , precision , recall , f1_score = (
@@ -619,9 +620,13 @@ def train(self, args, to_static):
619620
620621 return np .array (loss_data )
621622
623+ def _train (self , to_static : bool ):
624+ with enable_to_static_guard (to_static ):
625+ self .train (self .args , to_static )
626+
622627 def test_train (self ):
623- st_out = self .train ( self . args , to_static = True )
624- dy_out = self .train ( self . args , to_static = False )
628+ st_out = self ._train ( to_static = True )
629+ dy_out = self ._train ( to_static = False )
625630 np .testing .assert_allclose (
626631 dy_out ,
627632 st_out ,
@@ -645,19 +650,21 @@ def verify_predict(self):
645650
646651 def predict_dygraph (self , batch ):
647652 words , targets , length = batch
648- paddle .jit .enable_to_static (False )
649- with base .dygraph .guard (self .place ):
650- model = LexNet (self .args )
651- # load dygraph trained parameters
652- model_dict = paddle .load (self .dy_param_path + ".pdparams" )
653- model .set_dict (model_dict )
654- model .eval ()
655-
656- _ , pred_res = model (
657- to_variable (words ), to_variable (targets ), to_variable (length )
658- )
653+ with enable_to_static_guard (False ):
654+ with base .dygraph .guard (self .place ):
655+ model = LexNet (self .args )
656+ # load dygraph trained parameters
657+ model_dict = paddle .load (self .dy_param_path + ".pdparams" )
658+ model .set_dict (model_dict )
659+ model .eval ()
660+
661+ _ , pred_res = model (
662+ to_variable (words ),
663+ to_variable (targets ),
664+ to_variable (length ),
665+ )
659666
660- return pred_res .numpy ()
667+ return pred_res .numpy ()
661668
662669 def predict_static (self , batch ):
663670 """
0 commit comments