From f1c70e8f16e2c3974ab150de5e2e90858313619b Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Sat, 11 Sep 2021 02:36:17 +0530
Subject: [PATCH 1/7] add multi-proc in `to_csv`

---
 src/datasets/arrow_dataset.py |  7 +++-
 src/datasets/io/csv.py        | 70 +++++++++++++++++++++++++----------
 2 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 57e3113695e..ac35193e7f0 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -2849,6 +2849,7 @@ def to_csv(
         self,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ) -> int:
         """Exports the dataset to csv
@@ -2857,6 +2858,10 @@
             path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
             batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (:obj:`int`, optional): Number of processes to use for multiprocessing. Disabled
+                by default. ``batch_size`` in this case still defaults to
+                :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`, but raising it to 5x or 10x the default
+                can speed up the export when enough compute is available.
             to_csv_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_csv`
 
         Returns:
@@ -2865,7 +2870,7 @@
         # Dynamic import to avoid circular dependency
         from .io.csv import CsvDatasetWriter
 
-        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_csv_kwargs).write()
+        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_csv_kwargs).write()
 
     def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
         """Returns the dataset as a Python dict. Can also return a generator for large datasets.
diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py
index a63e5bd61cf..97402e0538b 100644
--- a/src/datasets/io/csv.py
+++ b/src/datasets/io/csv.py
@@ -1,9 +1,11 @@
+import multiprocessing
 import os
 from typing import BinaryIO, Optional, Union
 
-from .. import Dataset, Features, NamedSplit, config
+from .. import Dataset, Features, NamedSplit, config, utils
 from ..formatting import query_table
 from ..packaged_modules.csv.csv import Csv
+from ..utils import logging
 from ..utils.typing import NestedDataStructureLike, PathLike
 from .abc import AbstractDatasetReader
@@ -58,41 +60,69 @@
         dataset: Dataset,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ):
+        assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
         self.dataset = dataset
         self.path_or_buf = path_or_buf
-        self.batch_size = batch_size
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.encoding = "utf-8"
         self.to_csv_kwargs = to_csv_kwargs
 
     def write(self) -> int:
-        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        _ = self.to_csv_kwargs.pop("path_or_buf", None)
 
         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_csv_kwargs)
+                written = self._write(file_obj=buffer, **self.to_csv_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_csv_kwargs)
+            written = self._write(file_obj=self.path_or_buf, **self.to_csv_kwargs)
         return written
 
-    def _write(
-        self, file_obj: BinaryIO, batch_size: int, header: bool = True, encoding: str = "utf-8", **to_csv_kwargs
-    ) -> int:
+    def _batch_csv(self, args):
+        offset, header, to_csv_kwargs = args
+
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        csv_str = batch.to_pandas().to_csv(
+            path_or_buf=None, header=header if (offset == 0) else False, **to_csv_kwargs
+        )
+        return csv_str.encode(self.encoding)
+
+    def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
         """Writes the pyarrow table as CSV to a binary file handle.
 
         Caller is responsible for opening and closing the handle.
         """
         written = 0
-        _ = to_csv_kwargs.pop("path_or_buf", None)
-
-        for offset in range(0, len(self.dataset), batch_size):
-            batch = query_table(
-                table=self.dataset._data,
-                key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices if self.dataset._indices is not None else None,
-            )
-            csv_str = batch.to_pandas().to_csv(
-                path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs
-            )
-            written += file_obj.write(csv_str.encode(encoding))
+
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in utils.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating CSV from Arrow format",
+            ):
+                csv_str = self._batch_csv((offset, header, to_csv_kwargs))
+                written += file_obj.write(csv_str)
+
+        else:
+            with multiprocessing.Pool(self.num_proc) as pool:
+                for csv_str in utils.tqdm(
+                    pool.imap(
+                        self._batch_csv,
+                        [(offset, header, to_csv_kwargs) for offset in range(0, len(self.dataset), self.batch_size)],
+                    ),
+                    total=(len(self.dataset) // self.batch_size) + 1,
+                    unit="ba",
+                    disable=bool(logging.get_verbosity() == logging.NOTSET),
+                    desc="Creating CSV from Arrow format",
+                ):
+                    written += file_obj.write(csv_str)
+
         return written
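
A quick usage note on the API introduced by this patch. The sketch below is illustrative rather than part of the series: the example data and file names are made up, and `index=False` is an ordinary `pandas.DataFrame.to_csv` keyword that gets forwarded through `**to_csv_kwargs`.

    from datasets import Dataset

    # A small in-memory dataset with made-up values.
    dataset = Dataset.from_dict({"col_1": [0, 1, 2], "col_2": ["a", "b", "c"]})

    # Default: a single process; batch_size falls back to
    # datasets.config.DEFAULT_MAX_BATCH_SIZE.
    dataset.to_csv("out.csv", index=False)

    # Multiprocessed export: batches are rendered to CSV strings in a
    # multiprocessing.Pool and written to the file in dataset order.
    dataset.to_csv("out.csv", index=False, num_proc=2, batch_size=10_000)

In both cases `to_csv` returns the number of bytes written.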
From c2eef0463a5f01ec65de3835e4f31e2a90f4d677 Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Thu, 14 Oct 2021 19:31:59 +0530
Subject: [PATCH 2/7] add tests for dataset to csv

---
 tests/io/test_csv.py | 33 ++++++++++++++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index 7053ae09910..5dd190fb19c 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -1,7 +1,9 @@
+import csv
+import io
 import pytest
 
 from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
-from datasets.io.csv import CsvDatasetReader
+from datasets.io.csv import CsvDatasetReader, CsvDatasetWriter
 
 from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases
 
@@ -121,3 +123,32 @@
     dataset = CsvDatasetReader(path, cache_dir=cache_dir).read()
     _check_csv_datasetdict(dataset, expected_features, splits=list(path.keys()))
     assert all(dataset[split].split == split for split in path.keys())
+
+def load_csv(csv_path):
+    csvfile = open(csv_path, 'r', encoding='utf-8')
+    csvreader = csv.reader(csvfile)
+    return csvreader
+
+def test_dataset_to_csv(csv_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    output_csv = cache_dir / "tmp.csv"
+    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
+    CsvDatasetWriter(dataset['train'], output_csv, index=False, num_proc=1).write()
+
+    original_csv = load_csv(csv_path)
+    expected_csv = load_csv(output_csv)
+
+    for row1,row2 in zip(original_csv, expected_csv):
+        assert row1==row2
+
+def test_dataset_to_csv_multiproc(csv_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    output_csv = cache_dir / "tmp.csv"
+    dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
+    CsvDatasetWriter(dataset['train'], output_csv, index=False, num_proc=2).write()
+
+    original_csv = load_csv(csv_path)
+    expected_csv = load_csv(output_csv)
+
+    for row1,row2 in zip(original_csv, expected_csv):
+        assert row1==row2
\ No newline at end of file

From 254d73e50d4bad297291b5d91d71500cade2cecf Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Thu, 14 Oct 2021 19:35:54 +0530
Subject: [PATCH 3/7] make style

---
 tests/io/test_csv.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index 5dd190fb19c..c494a085c5e 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -1,5 +1,6 @@
 import csv
 import io
+
 import pytest
 
 from datasets import Dataset, DatasetDict, Features, NamedSplit, Value
@@ -124,31 +125,34 @@ def test_csv_datasetdict_reader_split(split, csv_path, tmp_path):
     _check_csv_datasetdict(dataset, expected_features, splits=list(path.keys()))
     assert all(dataset[split].split == split for split in path.keys())
 
+
 def load_csv(csv_path):
-    csvfile = open(csv_path, 'r', encoding='utf-8')
+    csvfile = open(csv_path, "r", encoding="utf-8")
     csvreader = csv.reader(csvfile)
     return csvreader
 
+
 def test_dataset_to_csv(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
     output_csv = cache_dir / "tmp.csv"
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
-    CsvDatasetWriter(dataset['train'], output_csv, index=False, num_proc=1).write()
+    CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write()
 
     original_csv = load_csv(csv_path)
     expected_csv = load_csv(output_csv)
 
-    for row1,row2 in zip(original_csv, expected_csv):
-        assert row1==row2
+    for row1, row2 in zip(original_csv, expected_csv):
+        assert row1 == row2
+
 
 def test_dataset_to_csv_multiproc(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
     output_csv = cache_dir / "tmp.csv"
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
-    CsvDatasetWriter(dataset['train'], output_csv, index=False, num_proc=2).write()
+    CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write()
 
     original_csv = load_csv(csv_path)
     expected_csv = load_csv(output_csv)
 
-    for row1,row2 in zip(original_csv, expected_csv):
-        assert row1==row2
\ No newline at end of file
+    for row1, row2 in zip(original_csv, expected_csv):
+        assert row1 == row2
From 09a04a8dfbad24e4c21f9736c98a3cdf9c034c50 Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Thu, 14 Oct 2021 19:38:43 +0530
Subject: [PATCH 4/7] fix imports

---
 tests/io/test_csv.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index c494a085c5e..3442274fb23 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -1,5 +1,4 @@
 import csv
-import io
 
 import pytest
 

From b9deb188ee827a6c9d377b2e345d73522c25f249 Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Thu, 14 Oct 2021 20:06:47 +0530
Subject: [PATCH 5/7] fix path for windows

---
 tests/io/test_csv.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index 3442274fb23..1d0c55cae4b 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -1,4 +1,5 @@
 import csv
+import os
 
 import pytest
 
@@ -133,7 +134,7 @@ def load_csv(csv_path):
 
 def test_dataset_to_csv(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
-    output_csv = cache_dir / "tmp.csv"
+    output_csv = os.path.join(cache_dir,"tmp.csv")
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write()
 
@@ -146,7 +147,7 @@ def test_dataset_to_csv(csv_path, tmp_path):
 
 def test_dataset_to_csv_multiproc(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
-    output_csv = cache_dir / "tmp.csv"
+    output_csv = os.path.join(cache_dir,"tmp.csv")
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write()
 

From 9fe735f452dc41fea9eafb6ef3fac29797b68d21 Mon Sep 17 00:00:00 2001
From: Bhavitvya Malik
Date: Thu, 14 Oct 2021 20:07:34 +0530
Subject: [PATCH 6/7] make style

---
 tests/io/test_csv.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index 1d0c55cae4b..a70b1012a00 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -134,7 +134,7 @@ def load_csv(csv_path):
 
 def test_dataset_to_csv(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
-    output_csv = os.path.join(cache_dir,"tmp.csv")
+    output_csv = os.path.join(cache_dir, "tmp.csv")
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write()
 
@@ -147,7 +147,7 @@ def test_dataset_to_csv(csv_path, tmp_path):
 
 def test_dataset_to_csv_multiproc(csv_path, tmp_path):
     cache_dir = tmp_path / "cache"
-    output_csv = os.path.join(cache_dir,"tmp.csv")
+    output_csv = os.path.join(cache_dir, "tmp.csv")
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write()
 

From 0927a0cc60326d6f7043f80a9218f9cee4fcbfb9 Mon Sep 17 00:00:00 2001
From: mariosasko
Date: Fri, 22 Oct 2021 19:20:59 +0200
Subject: [PATCH 7/7] Fix tests

---
 tests/conftest.py    |  4 ++--
 tests/io/test_csv.py | 15 +++++++--------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index f929e00bb4e..de238deda2e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -224,7 +224,7 @@ def arrow_path(tmp_path_factory):
 @pytest.fixture(scope="session")
 def csv_path(tmp_path_factory):
     path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
-    with open(path, "w") as f:
+    with open(path, "w", newline="") as f:
         writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
         writer.writeheader()
         for item in DATA:
@@ -235,7 +235,7 @@ def csv_path(tmp_path_factory):
 @pytest.fixture(scope="session")
 def csv2_path(tmp_path_factory):
     path = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
-    with open(path, "w") as f:
+    with open(path, "w", newline="") as f:
         writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
         writer.writeheader()
         for item in DATA:
diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py
index a70b1012a00..69c1e1d5092 100644
--- a/tests/io/test_csv.py
+++ b/tests/io/test_csv.py
@@ -126,10 +126,9 @@ def test_csv_datasetdict_reader_split(split, csv_path, tmp_path):
     assert all(dataset[split].split == split for split in path.keys())
 
 
-def load_csv(csv_path):
-    csvfile = open(csv_path, "r", encoding="utf-8")
-    csvreader = csv.reader(csvfile)
-    return csvreader
+def iter_csv_file(csv_path):
+    with open(csv_path, "r", encoding="utf-8") as csvfile:
+        yield from csv.reader(csvfile)
 
 
 def test_dataset_to_csv(csv_path, tmp_path):
@@ -138,8 +137,8 @@ def test_dataset_to_csv(csv_path, tmp_path):
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write()
 
-    original_csv = load_csv(csv_path)
-    expected_csv = load_csv(output_csv)
+    original_csv = iter_csv_file(csv_path)
+    expected_csv = iter_csv_file(output_csv)
 
     for row1, row2 in zip(original_csv, expected_csv):
         assert row1 == row2
@@ -151,8 +150,8 @@ def test_dataset_to_csv_multiproc(csv_path, tmp_path):
     dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read()
     CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write()
 
-    original_csv = load_csv(csv_path)
-    expected_csv = load_csv(output_csv)
+    original_csv = iter_csv_file(csv_path)
+    expected_csv = iter_csv_file(output_csv)
 
     for row1, row2 in zip(original_csv, expected_csv):
         assert row1 == row2
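
One implementation detail from patch 1 is worth spelling out: the writer hands batches to `multiprocessing.Pool.imap` rather than `imap_unordered`. `imap` yields results in submission order, so the header batch (offset 0) is written first and row order is preserved even though batches are serialized in parallel. Below is a self-contained sketch of that pattern, independent of the `datasets` internals; all names and the toy two-column schema are illustrative.

    import multiprocessing

    def chunk_to_csv(args):
        # Render one chunk of rows to CSV bytes; only the first chunk
        # (offset 0) carries the header, mirroring the writer in patch 1.
        offset, rows = args
        header = "col_1,col_2\r\n" if offset == 0 else ""
        body = "".join(f"{a},{b}\r\n" for a, b in rows)
        return (header + body).encode("utf-8")

    if __name__ == "__main__":
        data = [(i, i * i) for i in range(10)]
        chunks = [(off, data[off : off + 4]) for off in range(0, len(data), 4)]
        with multiprocessing.Pool(2) as pool, open("sketch.csv", "wb") as f:
            # imap (unlike imap_unordered) yields chunks in submission
            # order, so the file keeps the original row order.
            for payload in pool.imap(chunk_to_csv, chunks):
                f.write(payload)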
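
The `newline=""` change to the conftest fixtures in patch 7 addresses a classic `csv` module pitfall: `csv.writer` terminates rows with "\r\n" itself, and a file opened in text mode without `newline=""` lets the io layer translate that "\n" again, producing "\r\r\n" on Windows and breaking byte-for-byte comparisons between written and re-read files. A minimal, standard-library-only illustration:

    import csv
    import io

    # csv.writer emits "\r\n" row terminators on its own, so the target
    # stream must not perform any further newline translation.
    buf = io.StringIO()
    csv.writer(buf).writerow(["col_1", "col_2"])
    assert buf.getvalue() == "col_1,col_2\r\n"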