Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,20 +426,15 @@ def write_examples_on_file(self):
"""Write stored examples from the write-pool of examples. It makes a table out of the examples and write it."""
if not self.current_examples:
return

# order the columns properly
# preserve the order of the columns
if self.schema:
schema_cols = set(self.schema.names)
common_cols, extra_cols = [], []
for col in self.current_examples[0][0]:
if col in schema_cols:
common_cols.append(col)
else:
extra_cols.append(col)
examples_cols = self.current_examples[0][0].keys() # .keys() preserves the order (unlike set)
common_cols = [col for col in self.schema.names if col in examples_cols]
extra_cols = [col for col in examples_cols if col not in schema_cols]
cols = common_cols + extra_cols
else:
cols = list(self.current_examples[0][0])

batch_examples = {}
for col in cols:
# We use row[0][col] since current_examples contains (example, key) tuples.
Expand Down Expand Up @@ -549,14 +544,12 @@ def write_batch(
try_features = self._features if self.pa_writer is None and self.update_features else None
arrays = []
inferred_features = Features()
# preserve the order of the columns
if self.schema:
schema_cols = set(self.schema.names)
common_cols, extra_cols = [], []
for col in batch_examples:
if col in schema_cols:
common_cols.append(col)
else:
extra_cols.append(col)
batch_cols = batch_examples.keys() # .keys() preserves the order (unlike set)
common_cols = [col for col in self.schema.names if col in batch_cols]
extra_cols = [col for col in batch_cols if col not in schema_cols]
cols = common_cols + extra_cols
else:
cols = list(batch_examples)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we should really avoid this extra copy, especially if the inner iterable is large.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This negligible optimization (write_examples_on_file and write_batch are only called once every writer_batch_size examples) goes against good code practices.

We wouldn't use Python for this project if we wanted to optimize every aspect of the API.

Copy link
Contributor Author

@bryant1410 bryant1410 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it's negligible. #6636's OP stated:

I work with bioinformatics data and often these tables have thousands and even tens of thousands of features.

We'd create a list of tens of thousands of strings for every batch, for every processing step (e.g., a map).

And it's easy to remove (just cols = batch_examples, instead of copying it into a list).

Among other things, this library is about large data processing efficiency, so I think it'd be nice to consider it.

Expand Down
8 changes: 7 additions & 1 deletion tests/test_arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ def _check_output(output, expected_num_chunks: int):

@pytest.mark.parametrize("writer_batch_size", [None, 1, 10])
@pytest.mark.parametrize(
"fields", [None, {"col_1": pa.string(), "col_2": pa.int64()}, {"col_1": pa.string(), "col_2": pa.int32()}]
"fields",
[
None,
{"col_1": pa.string(), "col_2": pa.int64()},
{"col_1": pa.string(), "col_2": pa.int32()},
{"col_2": pa.int64(), "col_1": pa.string()},
],
)
def test_write(fields, writer_batch_size):
output = pa.BufferOutputStream()
Expand Down