Skip to content

Commit 389da7d

Browse files
authored
[Dy2St][tests][46-49, 16] paddle.jit.enable_to_static->enable_to_static_guard (#59730)
1 parent a38ab8b commit 389da7d

File tree

5 files changed

+199
-207
lines changed

5 files changed

+199
-207
lines changed

test/dygraph_to_static/test_lac.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
2424

25-
from dygraph_to_static_utils import Dy2StTestBase
25+
from dygraph_to_static_utils import Dy2StTestBase, enable_to_static_guard
2626

2727
import paddle
2828
from paddle import _legacy_C_ops, base
@@ -531,7 +531,6 @@ def setUp(self):
531531
self.dy_param_path = os.path.join(self.temp_dir.name, 'lac_dy_param')
532532

533533
def train(self, args, to_static):
534-
paddle.jit.enable_to_static(to_static)
535534
place = (
536535
base.CUDAPlace(0)
537536
if base.is_compiled_with_cuda()
@@ -580,7 +579,9 @@ def train(self, args, to_static):
580579
num_label_chunks,
581580
num_correct_chunks,
582581
) = chunk_eval(
583-
input=crf_decode, label=targets, seq_length=length
582+
input=crf_decode,
583+
label=targets,
584+
seq_length=length,
584585
)
585586
outputs = [avg_cost, precision, recall, f1_score]
586587
avg_cost, precision, recall, f1_score = (
@@ -619,9 +620,13 @@ def train(self, args, to_static):
619620

620621
return np.array(loss_data)
621622

623+
def _train(self, to_static: bool):
624+
with enable_to_static_guard(to_static):
625+
self.train(self.args, to_static)
626+
622627
def test_train(self):
623-
st_out = self.train(self.args, to_static=True)
624-
dy_out = self.train(self.args, to_static=False)
628+
st_out = self._train(to_static=True)
629+
dy_out = self._train(to_static=False)
625630
np.testing.assert_allclose(
626631
dy_out,
627632
st_out,
@@ -645,19 +650,21 @@ def verify_predict(self):
645650

646651
def predict_dygraph(self, batch):
647652
words, targets, length = batch
648-
paddle.jit.enable_to_static(False)
649-
with base.dygraph.guard(self.place):
650-
model = LexNet(self.args)
651-
# load dygraph trained parameters
652-
model_dict = paddle.load(self.dy_param_path + ".pdparams")
653-
model.set_dict(model_dict)
654-
model.eval()
655-
656-
_, pred_res = model(
657-
to_variable(words), to_variable(targets), to_variable(length)
658-
)
653+
with enable_to_static_guard(False):
654+
with base.dygraph.guard(self.place):
655+
model = LexNet(self.args)
656+
# load dygraph trained parameters
657+
model_dict = paddle.load(self.dy_param_path + ".pdparams")
658+
model.set_dict(model_dict)
659+
model.eval()
660+
661+
_, pred_res = model(
662+
to_variable(words),
663+
to_variable(targets),
664+
to_variable(length),
665+
)
659666

660-
return pred_res.numpy()
667+
return pred_res.numpy()
661668

662669
def predict_static(self, batch):
663670
"""

0 commit comments

Comments
 (0)