src/datasets/arrow_dataset.py: 6 additions & 1 deletion

@@ -2849,6 +2849,7 @@ def to_csv(
         self,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ) -> int:
         """Exports the dataset to csv
@@ -2857,6 +2858,10 @@ def to_csv(
             path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
             batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (:obj:`int`, optional): Number of processes for multiprocessing. Disabled by
+                default. When enabled, ``batch_size`` still defaults to
+                :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`, but you can raise it to 5x or 10x
+                the default value if you have sufficient compute power.
             to_csv_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_csv`

         Returns:
@@ -2865,7 +2870,7 @@ def to_csv(
         # Dynamic import to avoid circular dependency
         from .io.csv import CsvDatasetWriter

-        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_csv_kwargs).write()
+        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_csv_kwargs).write()

     def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
         """Returns the dataset as a Python dict. Can also return a generator for large datasets.
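For orientation, here is a minimal usage sketch of the new argument. It is not part of the PR; the dataset name and output path are illustrative only:

    from datasets import config, load_dataset

    ds = load_dataset("imdb", split="train")  # illustrative dataset choice

    # Default behavior: a single process converts and writes one batch at a time.
    ds.to_csv("imdb_train.csv")

    # With this PR: 4 worker processes convert batches in parallel. The docstring
    # suggests raising batch_size to 5x-10x the default when compute allows.
    ds.to_csv("imdb_train.csv", num_proc=4, batch_size=10 * config.DEFAULT_MAX_BATCH_SIZE)
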
src/datasets/io/csv.py: 50 additions & 20 deletions

@@ -1,9 +1,11 @@
+import multiprocessing
 import os
 from typing import BinaryIO, Optional, Union

-from .. import Dataset, Features, NamedSplit, config
+from .. import Dataset, Features, NamedSplit, config, utils
 from ..formatting import query_table
 from ..packaged_modules.csv.csv import Csv
+from ..utils import logging
 from ..utils.typing import NestedDataStructureLike, PathLike
 from .abc import AbstractDatasetReader
@@ -58,41 +60,69 @@ def __init__(
         dataset: Dataset,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ):
+        assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
         self.dataset = dataset
         self.path_or_buf = path_or_buf
-        self.batch_size = batch_size
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.encoding = "utf-8"
         self.to_csv_kwargs = to_csv_kwargs
     def write(self) -> int:
-        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        _ = self.to_csv_kwargs.pop("path_or_buf", None)

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_csv_kwargs)
+                written = self._write(file_obj=buffer, **self.to_csv_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_csv_kwargs)
+            written = self._write(file_obj=self.path_or_buf, **self.to_csv_kwargs)
         return written

-    def _write(
-        self, file_obj: BinaryIO, batch_size: int, header: bool = True, encoding: str = "utf-8", **to_csv_kwargs
-    ) -> int:
+    def _batch_csv(self, args):
+        offset, header, to_csv_kwargs = args
+
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        csv_str = batch.to_pandas().to_csv(
+            path_or_buf=None, header=header if (offset == 0) else False, **to_csv_kwargs
+        )
+        return csv_str.encode(self.encoding)
+
+    def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
         """Writes the pyarrow table as CSV to a binary file handle.

         Caller is responsible for opening and closing the handle.
         """
         written = 0
         _ = to_csv_kwargs.pop("path_or_buf", None)

-        for offset in range(0, len(self.dataset), batch_size):
-            batch = query_table(
-                table=self.dataset._data,
-                key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices if self.dataset._indices is not None else None,
-            )
-            csv_str = batch.to_pandas().to_csv(
-                path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs
-            )
-            written += file_obj.write(csv_str.encode(encoding))
-
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in utils.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating CSV from Arrow format",
+            ):
+                csv_str = self._batch_csv((offset, header, to_csv_kwargs))
+                written += file_obj.write(csv_str)
+
+        else:
+            with multiprocessing.Pool(self.num_proc) as pool:
+                for csv_str in utils.tqdm(
+                    pool.imap(
+                        self._batch_csv,
+                        [(offset, header, to_csv_kwargs) for offset in range(0, len(self.dataset), self.batch_size)],
+                    ),
+                    total=(len(self.dataset) // self.batch_size) + 1,
+                    unit="ba",
+                    disable=bool(logging.get_verbosity() == logging.NOTSET),
+                    desc="Creating CSV from Arrow format",
+                ):
+                    written += file_obj.write(csv_str)

         return written
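
Two observations on the implementation, neither of which is part of the PR itself. First, `pool.imap` (unlike `imap_unordered`) yields results in submission order even when workers finish out of order, which is what keeps the header batch (`offset == 0`) first and preserves row order in the written CSV. A minimal standalone sketch of that guarantee:

    import multiprocessing


    def square(i):
        # Stand-in for _batch_csv: any picklable per-item function works here.
        return i * i


    if __name__ == "__main__":
        with multiprocessing.Pool(4) as pool:
            # imap yields results in submission order, regardless of which
            # worker finishes first; imap_unordered would not guarantee this.
            print(list(pool.imap(square, range(8))))  # [0, 1, 4, 9, 16, 25, 36, 49]

Second, the progress-bar estimate total=(len(self.dataset) // self.batch_size) + 1 overshoots by one batch when the dataset length is an exact multiple of the batch size; ceiling division, e.g. (len(self.dataset) + self.batch_size - 1) // self.batch_size, would match the actual batch count. This affects only the displayed total, not the bytes written.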