Skip to content

Commit 7c0c2e8

Browse files
authored
[Dy2St][NO.2] pir dy2st unittest fix test_bert - Part 1 (#60164)
1 parent 338f2da commit 7c0c2e8

File tree

1 file changed

+78
-68
lines changed

1 file changed

+78
-68
lines changed

test/dygraph_to_static/test_bert.py

Lines changed: 78 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,21 @@
2323
from dygraph_to_static_utils import (
2424
Dy2StTestBase,
2525
enable_to_static_guard,
26+
test_legacy_and_pt_and_pir,
2627
test_sot_only,
2728
)
2829
from predictor_utils import PredictorTools
2930

3031
import paddle
3132
from paddle import base
3233
from paddle.base import core
34+
from paddle.base.framework import unique_name
35+
from paddle.framework import use_pir_api
3336
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
3437

35-
place = base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
38+
place = (
39+
paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
40+
)
3641
SEED = 2020
3742
STEP_NUM = 10
3843
PRINT_STEP = 2
@@ -95,7 +100,7 @@ def tearDown(self):
95100
self.temp_dir.cleanup()
96101

97102
def train(self, bert_config, data_reader, to_static):
98-
with base.dygraph.guard(place):
103+
with unique_name.guard():
99104
base.default_main_program().random_seed = SEED
100105
base.default_startup_program().random_seed = SEED
101106

@@ -158,7 +163,9 @@ def train(self, bert_config, data_reader, to_static):
158163
step_idx += 1
159164
if step_idx == STEP_NUM:
160165
if to_static:
161-
paddle.jit.save(bert, self.model_save_prefix)
166+
# TODO(pir-save-load): Fix this after we support save/load in PIR
167+
if not use_pir_api():
168+
paddle.jit.save(bert, self.model_save_prefix)
162169
else:
163170
paddle.save(
164171
bert.state_dict(),
@@ -172,8 +179,7 @@ def train_dygraph(self, bert_config, data_reader):
172179
return self.train(bert_config, data_reader, False)
173180

174181
def train_static(self, bert_config, data_reader):
175-
with enable_to_static_guard(True):
176-
return self.train(bert_config, data_reader, True)
182+
return self.train(bert_config, data_reader, True)
177183

178184
def predict_static(self, data):
179185
paddle.enable_static()
@@ -195,11 +201,12 @@ def predict_static(self, data):
195201
fetch_list=fetch_targets,
196202
)
197203

204+
paddle.disable_static()
198205
return pred_res
199206

200207
def predict_dygraph(self, bert_config, data):
201208
with enable_to_static_guard(False):
202-
with base.dygraph.guard(place):
209+
with unique_name.guard():
203210
bert = PretrainModelLayer(
204211
config=bert_config, weight_sharing=False, use_fp16=False
205212
)
@@ -210,7 +217,7 @@ def predict_dygraph(self, bert_config, data):
210217
bert.set_dict(model_dict)
211218
bert.eval()
212219

213-
input_vars = [base.dygraph.to_variable(x) for x in data]
220+
input_vars = [paddle.to_tensor(x) for x in data]
214221
(
215222
src_ids,
216223
pos_ids,
@@ -234,31 +241,30 @@ def predict_dygraph(self, bert_config, data):
234241
return pred_res
235242

236243
def predict_dygraph_jit(self, data):
237-
with base.dygraph.guard(place):
238-
bert = paddle.jit.load(self.model_save_prefix)
239-
bert.eval()
240-
241-
(
242-
src_ids,
243-
pos_ids,
244-
sent_ids,
245-
input_mask,
246-
mask_label,
247-
mask_pos,
248-
labels,
249-
) = data
250-
pred_res = bert(
251-
src_ids,
252-
pos_ids,
253-
sent_ids,
254-
input_mask,
255-
mask_label,
256-
mask_pos,
257-
labels,
258-
)
259-
pred_res = [var.numpy() for var in pred_res]
244+
bert = paddle.jit.load(self.model_save_prefix)
245+
bert.eval()
246+
247+
(
248+
src_ids,
249+
pos_ids,
250+
sent_ids,
251+
input_mask,
252+
mask_label,
253+
mask_pos,
254+
labels,
255+
) = data
256+
pred_res = bert(
257+
src_ids,
258+
pos_ids,
259+
sent_ids,
260+
input_mask,
261+
mask_label,
262+
mask_pos,
263+
labels,
264+
)
265+
pred_res = [var.numpy() for var in pred_res]
260266

261-
return pred_res
267+
return pred_res
262268

263269
def predict_analysis_inference(self, data):
264270
output = PredictorTools(
@@ -267,6 +273,7 @@ def predict_analysis_inference(self, data):
267273
out = output()
268274
return out
269275

276+
@test_legacy_and_pt_and_pir
270277
def test_train(self):
271278
static_loss, static_ppl = self.train_static(
272279
self.bert_config, self.data_reader
@@ -280,6 +287,7 @@ def test_train(self):
280287
self.verify_predict()
281288

282289
@test_sot_only
290+
@test_legacy_and_pt_and_pir
283291
def test_train_composite(self):
284292
core._set_prim_backward_enabled(True)
285293
# core._add_skip_comp_ops("layer_norm")
@@ -297,43 +305,45 @@ def test_train_composite(self):
297305
def verify_predict(self):
298306
for data in self.data_reader.data_generator()():
299307
dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
300-
static_pred_res = self.predict_static(data)
301-
dygraph_jit_pred_res = self.predict_dygraph_jit(data)
302-
predictor_pred_res = self.predict_analysis_inference(data)
303-
304-
for dy_res, st_res, dy_jit_res, predictor_res in zip(
305-
dygraph_pred_res,
306-
static_pred_res,
307-
dygraph_jit_pred_res,
308-
predictor_pred_res,
309-
):
310-
np.testing.assert_allclose(
311-
st_res,
312-
dy_res,
313-
rtol=1e-05,
314-
err_msg='dygraph_res: {},\n static_res: {}'.format(
315-
dy_res[~np.isclose(st_res, dy_res)],
316-
st_res[~np.isclose(st_res, dy_res)],
317-
),
318-
)
319-
np.testing.assert_allclose(
320-
st_res,
321-
dy_jit_res,
322-
rtol=1e-05,
323-
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
324-
dy_jit_res[~np.isclose(st_res, dy_jit_res)],
325-
st_res[~np.isclose(st_res, dy_jit_res)],
326-
),
327-
)
328-
np.testing.assert_allclose(
329-
st_res,
330-
predictor_res,
331-
rtol=1e-05,
332-
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
333-
predictor_res[~np.isclose(st_res, predictor_res)],
334-
st_res[~np.isclose(st_res, predictor_res)],
335-
),
336-
)
308+
# TODO(pir-save-load): Fix this after we support save/load in PIR
309+
if not use_pir_api():
310+
static_pred_res = self.predict_static(data)
311+
dygraph_jit_pred_res = self.predict_dygraph_jit(data)
312+
predictor_pred_res = self.predict_analysis_inference(data)
313+
314+
for dy_res, st_res, dy_jit_res, predictor_res in zip(
315+
dygraph_pred_res,
316+
static_pred_res,
317+
dygraph_jit_pred_res,
318+
predictor_pred_res,
319+
):
320+
np.testing.assert_allclose(
321+
st_res,
322+
dy_res,
323+
rtol=1e-05,
324+
err_msg='dygraph_res: {},\n static_res: {}'.format(
325+
dy_res[~np.isclose(st_res, dy_res)],
326+
st_res[~np.isclose(st_res, dy_res)],
327+
),
328+
)
329+
np.testing.assert_allclose(
330+
st_res,
331+
dy_jit_res,
332+
rtol=1e-05,
333+
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
334+
dy_jit_res[~np.isclose(st_res, dy_jit_res)],
335+
st_res[~np.isclose(st_res, dy_jit_res)],
336+
),
337+
)
338+
np.testing.assert_allclose(
339+
st_res,
340+
predictor_res,
341+
rtol=1e-05,
342+
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
343+
predictor_res[~np.isclose(st_res, predictor_res)],
344+
st_res[~np.isclose(st_res, predictor_res)],
345+
),
346+
)
337347
break
338348

339349

0 commit comments

Comments
 (0)