From c98dc7db53493785f0ee63d0fe16cb08bc27e4d4 Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Wed, 14 Feb 2024 20:46:16 -0500 Subject: [PATCH 1/4] Undo the changes in `arrow_writer.py` from #6636 See #6663. --- src/datasets/arrow_writer.py | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 711215a3bbc..b550d45d55f 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -428,18 +428,12 @@ def write_examples_on_file(self): return # order the columns properly - if self.schema: - schema_cols = set(self.schema.names) - common_cols, extra_cols = [], [] - for col in self.current_examples[0][0]: - if col in schema_cols: - common_cols.append(col) - else: - extra_cols.append(col) - cols = common_cols + extra_cols - else: - cols = list(self.current_examples[0][0]) - + cols = ( + [col for col in self.schema.names if col in self.current_examples[0][0]] + + [col for col in self.current_examples[0][0].keys() if col not in self.schema.names] + if self.schema + else self.current_examples[0][0].keys() + ) batch_examples = {} for col in cols: # We use row[0][col] since current_examples contains (example, key) tuples. 
@@ -549,17 +543,12 @@ def write_batch( try_features = self._features if self.pa_writer is None and self.update_features else None arrays = [] inferred_features = Features() - if self.schema: - schema_cols = set(self.schema.names) - common_cols, extra_cols = [], [] - for col in batch_examples: - if col in schema_cols: - common_cols.append(col) - else: - extra_cols.append(col) - cols = common_cols + extra_cols - else: - cols = list(batch_examples) + cols = ( + [col for col in self.schema.names if col in batch_examples] + + [col for col in batch_examples.keys() if col not in self.schema.names] + if self.schema + else batch_examples.keys() + ) for col in cols: col_values = batch_examples[col] col_type = features[col] if features else None From 4612879925ebb3214055fb3e0ddc0867f10d00c9 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 16 Feb 2024 02:14:49 +0100 Subject: [PATCH 2/4] Add test --- tests/test_arrow_writer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index b2483509b0c..a98b5a9c42f 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -87,7 +87,13 @@ def _check_output(output, expected_num_chunks: int): @pytest.mark.parametrize("writer_batch_size", [None, 1, 10]) @pytest.mark.parametrize( - "fields", [None, {"col_1": pa.string(), "col_2": pa.int64()}, {"col_1": pa.string(), "col_2": pa.int32()}] + "fields", + [ + None, + {"col_1": pa.string(), "col_2": pa.int64()}, + {"col_1": pa.string(), "col_2": pa.int32()}, + {"col_2": pa.int64(), "col_1": pa.string()}, + ], ) def test_write(fields, writer_batch_size): output = pa.BufferOutputStream() From 0160c134902a08de88cb9438d9428b4dc78e4dbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Fri, 16 Feb 2024 02:15:17 +0100 Subject: [PATCH 3/4] Apply suggestions from code review --- src/datasets/arrow_writer.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 
deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index b550d45d55f..ec1693c127f 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -428,12 +428,14 @@ def write_examples_on_file(self): return # order the columns properly - cols = ( - [col for col in self.schema.names if col in self.current_examples[0][0]] - + [col for col in self.current_examples[0][0].keys() if col not in self.schema.names] - if self.schema - else self.current_examples[0][0].keys() - ) + if self.schema: + schema_cols = set(self.schema.names) + examples_cols = self.current_examples[0][0].keys() # .keys() preserves the order (unlike set) + common_cols = [col for col in self.schema.names if col in examples_cols] + extra_cols = [col for col in examples_cols if col not in schema_cols] + cols = common_cols + extra_cols + else: + cols = list(self.current_examples[0][0]) batch_examples = {} for col in cols: # We use row[0][col] since current_examples contains (example, key) tuples. 
@@ -543,12 +545,14 @@ def write_batch( try_features = self._features if self.pa_writer is None and self.update_features else None arrays = [] inferred_features = Features() - cols = ( - [col for col in self.schema.names if col in batch_examples] - + [col for col in batch_examples.keys() if col not in self.schema.names] - if self.schema - else batch_examples.keys() - ) + if self.schema: + schema_cols = set(self.schema.names) + batch_cols = batch_examples.keys() # .keys() preserves the order (unlike set) + common_cols = [col for col in self.schema.names if col in batch_cols] + extra_cols = [col for col in batch_cols if col not in schema_cols] + cols = common_cols + extra_cols + else: + cols = list(batch_examples) for col in cols: col_values = batch_examples[col] col_type = features[col] if features else None From 23a78809d66c6be600d5a3b66ec1d1a7367c2dab Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 16 Feb 2024 02:17:34 +0100 Subject: [PATCH 4/4] Nits --- src/datasets/arrow_writer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index ec1693c127f..30a6ed774c6 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -426,8 +426,7 @@ def write_examples_on_file(self): """Write stored examples from the write-pool of examples. 
It makes a table out of the examples and write it.""" if not self.current_examples: return - - # order the columns properly + # preserve the order of the columns if self.schema: schema_cols = set(self.schema.names) examples_cols = self.current_examples[0][0].keys() # .keys() preserves the order (unlike set) @@ -545,7 +544,8 @@ def write_batch( try_features = self._features if self.pa_writer is None and self.update_features else None arrays = [] inferred_features = Features() - if self.schema: + # preserve the order of the columns + if self.schema: schema_cols = set(self.schema.names) batch_cols = batch_examples.keys() # .keys() preserves the order (unlike set) common_cols = [col for col in self.schema.names if col in batch_cols]