Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions scripts/machine_translation/train_gnmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,9 @@ def train():
valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
.format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
test_loss, test_translation_out = evaluate(test_data_loader)
test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
.format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
dataprocessor.write_sentences(valid_translation_out,
os.path.join(args.save_dir,
'epoch{:d}_valid_out.txt').format(epoch_id))
dataprocessor.write_sentences(test_translation_out,
os.path.join(args.save_dir,
'epoch{:d}_test_out.txt').format(epoch_id))
if valid_bleu_score > best_valid_bleu:
best_valid_bleu = valid_bleu_score
save_path = os.path.join(args.save_dir, 'valid_best.params')
Expand Down
10 changes: 0 additions & 10 deletions scripts/machine_translation/train_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,9 @@ def train():
bpe=bpe)
logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
.format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
.format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
dataprocessor.write_sentences(valid_translation_out,
os.path.join(args.save_dir,
'epoch{:d}_valid_out.txt').format(epoch_id))
dataprocessor.write_sentences(test_translation_out,
os.path.join(args.save_dir,
'epoch{:d}_test_out.txt').format(epoch_id))
if valid_bleu_score > best_valid_bleu:
best_valid_bleu = valid_bleu_score
save_path = os.path.join(args.save_dir, 'valid_best.params')
Expand Down