Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,6 @@
# Size of the preloaded record batch in `Dataset.__iter__`
ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10

# Pickling tables works only for small tables (<4GiB)
# For big tables, we write them on disk instead
MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30

# Max shard size in bytes (e.g. to shard parquet datasets in push_to_hub or download_and_prepare)
MAX_SHARD_SIZE = "500MB"

Expand Down Expand Up @@ -237,3 +233,6 @@

# Maximum number of uploaded files per commit
UPLOADS_MAX_NUMBER_PER_COMMIT = 50

# Backward compatibility
MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30
37 changes: 0 additions & 37 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import os
import tempfile
import warnings
from functools import partial
from itertools import groupby
Expand Down Expand Up @@ -67,16 +66,6 @@ def _memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
return pa_table


def _write_table_to_file(table: pa.Table, filename: str) -> int:
with open(filename, "wb") as sink:
writer = pa.RecordBatchStreamWriter(sink=sink, schema=table.schema)
batches: List[pa.RecordBatch] = table.to_batches()
for batch in batches:
writer.write_batch(batch)
writer.close()
return sum(batch.nbytes for batch in batches)


def _deepcopy(x, memo: dict):
"""deepcopy a regular class instance"""
cls = x.__class__
Expand Down Expand Up @@ -187,32 +176,6 @@ def __deepcopy__(self, memo: dict):
memo[id(self._batches)] = list(self._batches)
return _deepcopy(self, memo)

def __getstate__(self):
# We can't pickle objects that are bigger than 4GiB, or it causes OverflowError
# So we write the table on disk instead
if self.table.nbytes >= config.MAX_TABLE_NBYTES_FOR_PICKLING:
table = self.table
with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".arrow") as tmp_file:
filename = tmp_file.name
logger.debug(
f"Attempting to pickle a table bigger than 4GiB. Writing it on the disk instead at {filename}"
)
_write_table_to_file(table=table, filename=filename)
return {"path": filename}
else:
return {"table": self.table}

def __setstate__(self, state):
if "path" in state:
filename = state["path"]
logger.debug(f"Unpickling a big table from the disk at {filename}")
table = _in_memory_arrow_table_from_file(filename)
logger.debug(f"Removing temporary table file at {filename}")
os.remove(filename)
else:
table = state["table"]
Table.__init__(self, table)

def validate(self, *args, **kwargs):
"""
Perform validation checks. An exception is raised if validation fails.
Expand Down