diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 711215a3bbc..30a6ed774c6 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -426,20 +426,15 @@ def write_examples_on_file(self): """Write stored examples from the write-pool of examples. It makes a table out of the examples and write it.""" if not self.current_examples: return - - # order the columns properly + # preserve the order of the columns if self.schema: schema_cols = set(self.schema.names) - common_cols, extra_cols = [], [] - for col in self.current_examples[0][0]: - if col in schema_cols: - common_cols.append(col) - else: - extra_cols.append(col) + examples_cols = self.current_examples[0][0].keys() # .keys() preserves the order (unlike set) + common_cols = [col for col in self.schema.names if col in examples_cols] + extra_cols = [col for col in examples_cols if col not in schema_cols] cols = common_cols + extra_cols else: cols = list(self.current_examples[0][0]) - batch_examples = {} for col in cols: # We use row[0][col] since current_examples contains (example, key) tuples. 
@@ -549,14 +544,12 @@ def write_batch( try_features = self._features if self.pa_writer is None and self.update_features else None arrays = [] inferred_features = Features() + # preserve the order of the columns if self.schema: schema_cols = set(self.schema.names) - common_cols, extra_cols = [], [] - for col in batch_examples: - if col in schema_cols: - common_cols.append(col) - else: - extra_cols.append(col) + batch_cols = batch_examples.keys() # .keys() preserves the order (unlike set) + common_cols = [col for col in self.schema.names if col in batch_cols] + extra_cols = [col for col in batch_cols if col not in schema_cols] cols = common_cols + extra_cols else: cols = list(batch_examples) diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index b2483509b0c..a98b5a9c42f 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -87,7 +87,13 @@ def _check_output(output, expected_num_chunks: int): @pytest.mark.parametrize("writer_batch_size", [None, 1, 10]) @pytest.mark.parametrize( - "fields", [None, {"col_1": pa.string(), "col_2": pa.int64()}, {"col_1": pa.string(), "col_2": pa.int32()}] + "fields",
+ [
+ None,
+ {"col_1": pa.string(), "col_2": pa.int64()},
+ {"col_1": pa.string(), "col_2": pa.int32()},
+ {"col_2": pa.int64(), "col_1": pa.string()},
+ ],
 ) def test_write(fields, writer_batch_size):
output = pa.BufferOutputStream()