Skip to content

Commit f58cb01

Browse files
authored
【Paddle.Fleet】fix dataset zip py3 bug (#31441)
* fix zip py3 bug
1 parent bf09dcb commit f58cb01

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

python/paddle/distributed/fleet/data_generator/data_generator.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ def set_batch(self, batch_size):
3232
'''
3333
Set batch size of current DataGenerator
3434
This is necessary only if a user wants to define generator_batch
35-
35+
3636
Example:
3737
3838
.. code-block:: python
39-
39+
4040
import paddle.distributed.fleet.data_generator as dg
4141
class MyData(dg.DataGenerator):
4242
@@ -52,7 +52,7 @@ def local_iter():
5252
yield ("words", s[1].extend([s[1][0]]))
5353
mydata = MyData()
5454
mydata.set_batch(128)
55-
55+
5656
'''
5757
self.batch_size_ = batch_size
5858

@@ -63,7 +63,7 @@ def run_from_memory(self):
6363
6464
Example:
6565
.. code-block:: python
66-
66+
6767
import paddle.distributed.fleet.data_generator as dg
6868
class MyData(dg.DataGenerator):
6969
@@ -100,9 +100,9 @@ def run_from_stdin(self):
100100
generated.
101101
102102
Example:
103-
103+
104104
.. code-block:: python
105-
105+
106106
import paddle.distributed.fleet.data_generator as dg
107107
class MyData(dg.DataGenerator):
108108
@@ -161,7 +161,7 @@ def generate_sample(self, line):
161161
The data format is list or tuple:
162162
[(name, [feasign, ...]), ...]
163163
or ((name, [feasign, ...]), ...)
164-
164+
165165
For example:
166166
[("words", [1926, 08, 17]), ("label", [1])]
167167
or (("words", [1926, 08, 17]), ("label", [1]))
@@ -174,7 +174,7 @@ def generate_sample(self, line):
174174
Example:
175175
176176
.. code-block:: python
177-
177+
178178
import paddle.distributed.fleet.data_generator as dg
179179
class MyData(dg.DataGenerator):
180180
@@ -206,7 +206,7 @@ def generate_batch(self, samples):
206206
Example:
207207
208208
.. code-block:: python
209-
209+
210210
import paddle.distributed.fleet.data_generator as dg
211211
class MyData(dg.DataGenerator):
212212
@@ -259,6 +259,9 @@ def _gen_str(self, line):
259259
Returns:
260260
Return a string data that can be read directly by the MultiSlotDataFeed.
261261
'''
262+
if sys.version > '3' and isinstance(line, zip):
263+
line = list(line)
264+
262265
if not isinstance(line, list) and not isinstance(line, tuple):
263266
raise ValueError(
264267
"the output of process() must be in list or tuple type"
@@ -289,7 +292,7 @@ def _gen_str(self, line):
289292
>>> [ids_num id1 id2 ...] ...
290293
The proto_info will be in this format:
291294
>>> [(name, type), ...]
292-
295+
293296
For example, if the input is like this:
294297
>>> [("words", [1926, 08, 17]), ("label", [1])]
295298
>>> or (("words", [1926, 08, 17]), ("label", [1]))
@@ -304,6 +307,9 @@ def _gen_str(self, line):
304307
Returns:
305308
Return a string data that can be read directly by the MultiSlotDataFeed.
306309
'''
310+
if sys.version > '3' and isinstance(line, zip):
311+
line = list(line)
312+
307313
if not isinstance(line, list) and not isinstance(line, tuple):
308314
raise ValueError(
309315
"the output of process() must be in list or tuple type"

python/paddle/fluid/tests/unittests/test_data_generator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,32 @@ def data_iter():
9595
return data_iter
9696

9797

98+
class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
    """String-slot generator whose samples are ``zip`` objects.

    Exercises the path where ``generate_sample`` hands back a ``zip``
    rather than a list or tuple.
    """

    def generate_sample(self, line):
        """Return a reader callable; ``line`` is not used.

        The reader runs 40 rounds. Each round yields a ``zip`` pairing the
        slot names ``"words"`` and ``"label"`` with their string feasigns.
        On the round with index 1 a ``None`` is yielded first.
        """

        def data_iter():
            for round_idx in range(40):
                if round_idx == 1:
                    yield None
                # Fresh lists are built on every round, as in the original.
                yield zip(["words", "label"], [["1", "2", "3", "4"], ["0"]])

        return data_iter
109+
110+
111+
class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
    """Integer-slot generator whose samples are ``zip`` objects.

    Exercises the path where ``generate_sample`` hands back a ``zip``
    rather than a list or tuple.
    """

    def generate_sample(self, line):
        """Return a reader callable; ``line`` is not used.

        The reader runs 40 rounds. Each round yields a ``zip`` pairing the
        slot names ``"words"`` and ``"label"`` with their integer feasigns.
        On the round with index 1 a ``None`` is yielded first.
        """

        def data_iter():
            for round_idx in range(40):
                if round_idx == 1:
                    yield None
                # Fresh lists are built on every round, as in the original.
                yield zip(["words", "label"], [[1, 2, 3, 4], [0]])

        return data_iter
122+
123+
98124
class TestMultiSlotDataGenerator(unittest.TestCase):
99125
def test_MultiSlotDataGenerator_basic(self):
100126
my_ms_dg = MyMultiSlotDataGenerator()
@@ -149,5 +175,19 @@ def test_MultiSlotDataGenerator_error(self):
149175
my_ms_dg.run_from_memory()
150176

151177

178+
class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
    """Smoke test for the string-slot generator that yields ``zip`` samples."""

    def test_MultiSlotStringDataGenerator_zip(self):
        """Run the zip-yielding generator from memory with a batch size of 1."""
        generator = MyMultiSlotStringDataGenerator_zip()
        generator.set_batch(1)
        # No assertions: the test passes if the run completes without raising.
        generator.run_from_memory()
183+
184+
185+
class TestMultiSlotDataGeneratorZip(unittest.TestCase):
    """Smoke test for the integer-slot generator that yields ``zip`` samples."""

    def test_MultiSlotDataGenerator_zip(self):
        """Run the zip-yielding generator from memory with a batch size of 1."""
        generator = MyMultiSlotDataGenerator_zip()
        generator.set_batch(1)
        # No assertions: the test passes if the run completes without raising.
        generator.run_from_memory()
190+
191+
152192
if __name__ == '__main__':
153193
unittest.main()

0 commit comments

Comments
 (0)