From 2e3f9bc2046e9787eb815bf762b527e8c6cc905e Mon Sep 17 00:00:00 2001 From: MrChengmo Date: Fri, 5 Mar 2021 12:06:04 +0800 Subject: [PATCH 1/2] fix zip py3 bug --- .../fleet/data_generator/data_generator.py | 26 ++++++++++++------- .../tests/unittests/test_data_generator.py | 20 ++++++++++++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/data_generator/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py index 669d2ea24a0c78..9d743fc38bf398 100644 --- a/python/paddle/distributed/fleet/data_generator/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -32,11 +32,11 @@ def set_batch(self, batch_size): ''' Set batch size of current DataGenerator This is necessary only if a user wants to define generator_batch - + Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -52,7 +52,7 @@ def local_iter(): yield ("words", s[1].extend([s[1][0]])) mydata = MyData() mydata.set_batch(128) - + ''' self.batch_size_ = batch_size @@ -63,7 +63,7 @@ def run_from_memory(self): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -100,9 +100,9 @@ def run_from_stdin(self): generated. Example: - + .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -161,7 +161,7 @@ def generate_sample(self, line): The data format is list or tuple: [(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...) - + For example: [("words", [1926, 08, 17]), ("label", [1])] or (("words", [1926, 08, 17]), ("label", [1])) @@ -174,7 +174,7 @@ def generate_sample(self, line): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -206,7 +206,7 @@ def generate_batch(self, samples): Example: .. code-block:: python - + import paddle.distributed.fleet.data_generator as dg class MyData(dg.DataGenerator): @@ -259,6 +259,9 @@ def _gen_str(self, line): Returns: Return a string data that can be read directly by the MultiSlotDataFeed. ''' + if sys.version > '3' and isinstance(line, zip): + line = list(line) + if not isinstance(line, list) and not isinstance(line, tuple): raise ValueError( "the output of process() must be in list or tuple type" @@ -289,7 +292,7 @@ def _gen_str(self, line): >>> [ids_num id1 id2 ...] ... The proto_info will be in this format: >>> [(name, type), ...] - + For example, if the input is like this: >>> [("words", [1926, 08, 17]), ("label", [1])] >>> or (("words", [1926, 08, 17]), ("label", [1])) @@ -304,6 +307,9 @@ def _gen_str(self, line): Returns: Return a string data that can be read directly by the MultiSlotDataFeed. ''' + if sys.version > '3' and isinstance(line, zip): + line = list(line) + if not isinstance(line, list) and not isinstance(line, tuple): raise ValueError( "the output of process() must be in list or tuple type" diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 6381cb36402636..7cf7439ddc86bf 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -95,6 +95,19 @@ def data_iter(): return data_iter +class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [[1, 2, 3, 4], [0]] + yield zip(feature_name, data) + + return data_iter + + class TestMultiSlotDataGenerator(unittest.TestCase): def test_MultiSlotDataGenerator_basic(self): my_ms_dg = MyMultiSlotDataGenerator() @@ -149,5 +162,12 @@ def test_MultiSlotDataGenerator_error(self): my_ms_dg.run_from_memory() +class TestMultiSlotDataGeneratorZip(unittest.TestCase): + def test_MultiSlotDataGenerator_zip(self): + my_ms_dg = MyMultiSlotDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + if __name__ == '__main__': unittest.main() From 03c2458a93b84f46b74faed1864c072a2e1da411 Mon Sep 17 00:00:00 2001 From: MrChengmo Date: Thu, 18 Mar 2021 19:03:08 +0800 Subject: [PATCH 2/2] add unittest --- .../tests/unittests/test_data_generator.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_data_generator.py b/python/paddle/fluid/tests/unittests/test_data_generator.py index 7cf7439ddc86bf..69d8e01fd464af 100644 --- a/python/paddle/fluid/tests/unittests/test_data_generator.py +++ b/python/paddle/fluid/tests/unittests/test_data_generator.py @@ -95,6 +95,19 @@ def data_iter(): return data_iter +class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator): + def generate_sample(self, line): + def data_iter(): + for i in range(40): + if i == 1: + yield None + feature_name = ["words", "label"] + data = [["1", "2", "3", "4"], ["0"]] + yield zip(feature_name, data) + + return data_iter + + class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): @@ -162,6 +175,13 @@ def test_MultiSlotDataGenerator_error(self): my_ms_dg.run_from_memory() +class TestMultiSlotStringDataGeneratorZip(unittest.TestCase): + def test_MultiSlotStringDataGenerator_zip(self): + my_ms_dg = MyMultiSlotStringDataGenerator_zip() + my_ms_dg.set_batch(1) + my_ms_dg.run_from_memory() + + class TestMultiSlotDataGeneratorZip(unittest.TestCase): def test_MultiSlotDataGenerator_zip(self): my_ms_dg = MyMultiSlotDataGenerator_zip()