7 changes: 6 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -3187,6 +3187,7 @@ def to_csv(
self,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_csv_kwargs,
) -> int:
"""Exports the dataset to csv
@@ -3195,6 +3196,10 @@ def to_csv(
path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
num_proc (:obj:`int`, optional): Number of processes to use for multiprocessing. Multiprocessing is
disabled by default. When it is enabled, ``batch_size`` still defaults to
:obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`, but you can raise it to 5x or 10x the default
value if you have sufficient compute power.
to_csv_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_csv`

Returns:
@@ -3203,7 +3208,7 @@
# Dynamic import to avoid circular dependency
from .io.csv import CsvDatasetWriter

return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_csv_kwargs).write()
return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_csv_kwargs).write()

def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
"""Returns the dataset as a Python dict. Can also return a generator for large datasets.
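
For reference, a minimal usage sketch of the new ``num_proc`` argument. The dataset, file names and parameter values below are illustrative, not part of this diff:

from datasets import load_dataset

# Illustrative dataset; any Dataset instance works the same way.
dataset = load_dataset("csv", data_files="my_data.csv", split="train")

# Default behaviour: a single process, batches of config.DEFAULT_MAX_BATCH_SIZE.
dataset.to_csv("export.csv", index=False)

# With this change: convert batches to CSV in 4 worker processes.
# A larger batch_size (e.g. 5-10x the default) can help when memory allows.
dataset.to_csv("export_parallel.csv", index=False, num_proc=4, batch_size=100_000)
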
70 changes: 50 additions & 20 deletions src/datasets/io/csv.py
@@ -1,9 +1,11 @@
import multiprocessing
import os
from typing import BinaryIO, Optional, Union

from .. import Dataset, Features, NamedSplit, config
from .. import Dataset, Features, NamedSplit, config, utils
from ..formatting import query_table
from ..packaged_modules.csv.csv import Csv
from ..utils import logging
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader

@@ -58,41 +60,69 @@ def __init__(
dataset: Dataset,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_csv_kwargs,
):
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
self.dataset = dataset
self.path_or_buf = path_or_buf
self.batch_size = batch_size
self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
self.num_proc = num_proc
self.encoding = "utf-8"
self.to_csv_kwargs = to_csv_kwargs

def write(self) -> int:
batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
_ = self.to_csv_kwargs.pop("path_or_buf", None)

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with open(self.path_or_buf, "wb+") as buffer:
written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_csv_kwargs)
written = self._write(file_obj=buffer, **self.to_csv_kwargs)
else:
written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_csv_kwargs)
written = self._write(file_obj=self.path_or_buf, **self.to_csv_kwargs)
return written

def _write(
self, file_obj: BinaryIO, batch_size: int, header: bool = True, encoding: str = "utf-8", **to_csv_kwargs
) -> int:
def _batch_csv(self, args):
offset, header, to_csv_kwargs = args

batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)
csv_str = batch.to_pandas().to_csv(
path_or_buf=None, header=header if (offset == 0) else False, **to_csv_kwargs
)
return csv_str.encode(self.encoding)

def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
"""Writes the pyarrow table as CSV to a binary file handle.

Caller is responsible for opening and closing the handle.
"""
written = 0
_ = to_csv_kwargs.pop("path_or_buf", None)

for offset in range(0, len(self.dataset), batch_size):
batch = query_table(
table=self.dataset._data,
key=slice(offset, offset + batch_size),
indices=self.dataset._indices if self.dataset._indices is not None else None,
)
csv_str = batch.to_pandas().to_csv(
path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs
)
written += file_obj.write(csv_str.encode(encoding))

if self.num_proc is None or self.num_proc == 1:
for offset in utils.tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=bool(logging.get_verbosity() == logging.NOTSET),
desc="Creating CSV from Arrow format",
):
csv_str = self._batch_csv((offset, header, to_csv_kwargs))
written += file_obj.write(csv_str)

else:
with multiprocessing.Pool(self.num_proc) as pool:
for csv_str in utils.tqdm(
pool.imap(
self._batch_csv,
[(offset, header, to_csv_kwargs) for offset in range(0, len(self.dataset), self.batch_size)],
),
total=(len(self.dataset) // self.batch_size) + 1,
unit="ba",
disable=bool(logging.get_verbosity() == logging.NOTSET),
desc="Creating CSV from Arrow format",
):
written += file_obj.write(csv_str)

return written
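
The key pattern in ``_write`` above is an ordered ``multiprocessing.Pool.imap`` over batch offsets: chunks come back in submission order, so they are written to the file in the original row order, and only the first batch carries the header. A standalone sketch of that pattern, with hypothetical data and a hypothetical ``batch_to_csv`` helper standing in for ``_batch_csv``:

import multiprocessing

ROWS = [{"id": i, "value": i * i} for i in range(10_000)]  # hypothetical data
BATCH_SIZE = 1_000

def batch_to_csv(offset: int) -> bytes:
    """Render one batch of rows as CSV bytes; only the first batch gets a header."""
    lines = ["id,value"] if offset == 0 else []
    for row in ROWS[offset : offset + BATCH_SIZE]:
        lines.append(f"{row['id']},{row['value']}")
    return ("\n".join(lines) + "\n").encode("utf-8")

if __name__ == "__main__":
    offsets = range(0, len(ROWS), BATCH_SIZE)
    with open("out.csv", "wb") as f, multiprocessing.Pool(2) as pool:
        # imap (unlike imap_unordered) yields results in submission order,
        # so the batches land in the file in the original row order.
        for chunk in pool.imap(batch_to_csv, offsets):
            f.write(chunk)
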
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -224,7 +224,7 @@ def arrow_path(tmp_path_factory):
@pytest.fixture(scope="session")
def csv_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
with open(path, "w") as f:
with open(path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
writer.writeheader()
for item in DATA:
@@ -235,7 +235,7 @@ def csv_path(tmp_path_factory):
@pytest.fixture(scope="session")
def csv2_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
with open(path, "w") as f:
with open(path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
writer.writeheader()
for item in DATA:
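
The fixture change is the ``newline=""`` argument: the ``csv`` module documentation asks for files to be opened this way so the writer controls line endings itself; without it, the writer's ``\r\n`` terminators get translated to ``\r\r\n`` on Windows. A minimal illustration with made-up row values:

import csv

# Open with newline="" so csv.DictWriter's own "\r\n" terminators are written
# verbatim instead of being translated to "\r\r\n" on Windows.
with open("dataset.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
    writer.writeheader()
    writer.writerow({"col_1": "0", "col_2": 0, "col_3": 0.0})  # made-up values
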
36 changes: 35 additions & 1 deletion tests/io/test_csv.py
@@ -1,7 +1,10 @@
import csv
import os

import pytest

from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
from datasets.io.csv import CsvDatasetReader
from datasets.io.csv import CsvDatasetReader, CsvDatasetWriter

from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases

@@ -121,3 +124,34 @@ def test_csv_datasetdict_reader_split(split, csv_path, tmp_path):
dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
_check_csv_datasetdict(dataset, expected_features, splits=list(path.keys()))
assert all(dataset[split].split == split for split in path.keys())


def iter_csv_file(csv_path):
with open(csv_path, "r", encoding="utf-8") as csvfile:
yield from csv.reader(csvfile)


def test_dataset_to_csv(csv_path, tmp_path):
cache_dir = tmp_path / "cache"
output_csv = os.path.join(cache_dir, "tmp.csv")
dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write()

original_csv = iter_csv_file(csv_path)
expected_csv = iter_csv_file(output_csv)

for row1, row2 in zip(original_csv, expected_csv):
assert row1 == row2


def test_dataset_to_csv_multiproc(csv_path, tmp_path):
cache_dir = tmp_path / "cache"
output_csv = os.path.join(cache_dir, "tmp.csv")
dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write()

original_csv = iter_csv_file(csv_path)
expected_csv = iter_csv_file(output_csv)

for row1, row2 in zip(original_csv, expected_csv):
assert row1 == row2