
Commit d466171

Merge pull request PaddlePaddle#33 from smallv0221/yxp0209
add seq2seq experimental dataset and fix DureaderYesNo bug
2 parents: 54a6f7c + bb96d3d

File tree

3 files changed (+96, -3 lines)


examples/machine_reading_comprehension/DuReader-yesno/run_du.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1):
     qas_id = example[-1]
     example = example[:-2]
     # tokenize raw text
-    tokens_raw = [tokenizer(l) for l in example]
+    tokens_raw = [tokenizer.tokenize(l) for l in example]
     # truncate to the truncate_length,
     tokens_trun = _truncate_seqs(tokens_raw, max_seq_length)
     # concate the sequences with special tokens
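
The one-line fix switches from calling the tokenizer object directly to calling its tokenize() method, which returns the list of token strings that the subsequent _truncate_seqs and concatenation steps operate on. A minimal sketch of the distinction, assuming a PaddleNLP BertTokenizer purely for illustration (this hunk does not show which tokenizer run_du.py actually constructs):

    # Illustration only: the tokenizer class and pretrained name are assumptions,
    # not taken from this commit.
    from paddlenlp.transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

    example = ["Is the answer yes?", "Some context paragraph ..."]
    # tokenize() returns plain subword strings, which is what the truncation
    # and concatenation helpers in run_du.py expect to receive.
    tokens_raw = [tokenizer.tokenize(segment) for segment in example]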

paddlenlp/datasets/experimental/dataset.py

Lines changed: 0 additions & 2 deletions
@@ -361,8 +361,6 @@ def read(self, root):
         In this case your implementation of `_read()` must also be lazy
         (that is, not load all examples into memory at once).
         """
-        if not isinstance(root, str):
-            root = str(root)
 
         if self.lazy:
             label_list = self.get_labels()
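
Dropping the str() coercion lets read() pass root through unchanged, so a builder's _read() can receive non-string inputs; the new seq2seq dataset below appears to rely on this, since its _get_data() returns a (src, tgt) tuple of filenames. Separately, as the docstring notes, a lazy builder's _read() should be a generator. A minimal sketch of such a generator, using a hypothetical one-example-per-line file format:

    def _read(self, filename):
        # Yield examples one at a time so a lazy dataset never has to hold
        # the whole file in memory.
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    yield {"text": line}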
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import warnings

from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder

__all__ = ['WMT14ende']


class WMT14ende(DatasetBuilder):
    URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
    MD5 = None
    META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
                                                     'src_md5', 'tgt_md5'))
    SPLITS = {
        'train': META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "train.tok.clean.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "train.tok.clean.bpe.33708.de"),
            "c7c0b77e672fc69f20be182ae37ff62c", None),
        'dev': META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2013.tok.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2013.tok.bpe.33708.de"),
            "aa4228a4bedb6c45d67525fbfbcee75e",
            "9b1eeaff43a6d5e78a381a9b03170501"),
        'test': META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2014.tok.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2014.tok.bpe.33708.de"),
            "c9403eacf623c6e2d9e5a1155bdff0b5",
            "0058855b55e37c4acfcb8cffecba1050"),
        'dev-eval': META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2013.tok.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2013.tok.de"),
            "d74712eb35578aec022265c439831b0e",
            "6ff76ced35b70e63a61ecec77a1c418f"),
        'test-eval': META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2014.tok.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2014.tok.de"),
            "8cce2028e4ca3d4cc039dfd33adbfb43",
            "a1b1f4c47f487253e1ac88947b68b3b8")
    }

    def _get_data(self, mode, **kwargs):
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
            mode]
        src_fullname = os.path.join(default_root, src_filename)
        tgt_fullname = os.path.join(default_root, tgt_filename)

        if (not os.path.exists(src_fullname) or
            (src_data_hash and not md5file(src_fullname) == src_data_hash)) or (
                not os.path.exists(tgt_fullname) or
                (tgt_data_hash and not md5file(tgt_fullname) == tgt_data_hash)):
            get_path_from_url(self.URL, default_root, self.MD5)

        return src_fullname, tgt_fullname

    def _read(self, filename):
        src_filename, tgt_filename = filename
        with open(src_filename, 'r', encoding='utf-8') as src_f:
            with open(tgt_filename, 'r', encoding='utf-8') as tgt_f:
                for src_line, tgt_line in zip(src_f, tgt_f):
                    src_line = src_line.strip()
                    tgt_line = tgt_line.strip()
                    if not src_line or not tgt_line:
                        break
                    yield {"source": src_line, "target": tgt_line}
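
A hypothetical usage sketch for the new builder. The import path and the no-argument constructor are assumptions for illustration; only _get_data() and _read() come from the code added in this commit, and the experimental module's public loading entry point is not shown here:

    # Assumed import path and constructor; adjust to the experimental package layout.
    from paddlenlp.datasets.experimental import WMT14ende

    builder = WMT14ende()
    # _get_data() downloads and unpacks the archive if the files or MD5 checksums are missing.
    src_file, tgt_file = builder._get_data("dev")
    for example in builder._read((src_file, tgt_file)):
        # Each example is a dict pairing BPE-tokenized source and target sentences.
        print(example["source"], "=>", example["target"])
        break

In practice the files would be consumed through the builder's public read() path shown in dataset.py rather than by calling these private helpers directly.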
