diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 6b31e9be9ac..fe673952e4a 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -2856,6 +2856,7 @@ def to_json(
         self,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_json_kwargs,
     ) -> int:
         """Export the dataset to JSON Lines or JSON.
@@ -2864,6 +2865,10 @@ def to_json(
             path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
             batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (:obj:`int`, optional): Number of processes for multiprocessing. By default it doesn't
+                use multiprocessing. ``batch_size`` in this case defaults to
+                :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE` but feel free to make it 5x or 10x of the default
+                value if you have sufficient compute power.
             lines (:obj:`bool`, default ``True``): Whether output JSON lines format. Only possible if
                 ``orient="records"``. It will throw ValueError with ``orient`` different from ``"records"``,
                 since the others are not list-like.
@@ -2884,7 +2889,7 @@ def to_json(
         # Dynamic import to avoid circular dependency
         from .io.json import JsonDatasetWriter
 
-        return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_json_kwargs).write()
+        return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_json_kwargs).write()
 
     def to_pandas(
         self, batch_size: Optional[int] = None, batched: bool = False
diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py
index 377ff1d9c44..7cb5ea0f75e 100644
--- a/src/datasets/io/json.py
+++ b/src/datasets/io/json.py
@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 from typing import BinaryIO, Optional, Union
 
@@ -63,30 +64,47 @@ def __init__(
         dataset: Dataset,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_json_kwargs,
     ):
+        assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
         self.dataset = dataset
         self.path_or_buf = path_or_buf
-        self.batch_size = batch_size
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.encoding = "utf-8"
         self.to_json_kwargs = to_json_kwargs
 
     def write(self) -> int:
-        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        _ = self.to_json_kwargs.pop("path_or_buf", None)
+        orient = self.to_json_kwargs.pop("orient", "records")
+        lines = self.to_json_kwargs.pop("lines", True)
 
         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_json_kwargs)
+            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
         return written
 
+    def _batch_json(self, args):
+        offset, orient, lines, to_json_kwargs = args
+
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
+        if not json_str.endswith("\n"):
+            json_str += "\n"
+        return json_str.encode(self.encoding)
+
     def _write(
         self,
         file_obj: BinaryIO,
-        batch_size: int,
-        encoding: str = "utf-8",
-        orient="records",
-        lines=True,
+        orient,
+        lines,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -94,18 +112,31 @@ def _write(
         Caller is responsible for opening and closing the handle.
         """
         written = 0
-        _ = to_json_kwargs.pop("path_or_buf", None)
-
-        for offset in utils.tqdm(
-            range(0, len(self.dataset), batch_size), unit="ba", disable=bool(logging.get_verbosity() == logging.NOTSET)
-        ):
-            batch = query_table(
-                table=self.dataset.data,
-                key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices if self.dataset._indices is not None else None,
-            )
-            json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
-            if not json_str.endswith("\n"):
-                json_str += "\n"
-            written += file_obj.write(json_str.encode(encoding))
+
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in utils.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating json from Arrow format",
+            ):
+                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
+                written += file_obj.write(json_str)
+        else:
+            with multiprocessing.Pool(self.num_proc) as pool:
+                for json_str in utils.tqdm(
+                    pool.imap(
+                        self._batch_json,
+                        [
+                            (offset, orient, lines, to_json_kwargs)
+                            for offset in range(0, len(self.dataset), self.batch_size)
+                        ],
+                    ),
+                    total=(len(self.dataset) // self.batch_size) + 1,
+                    unit="ba",
+                    disable=bool(logging.get_verbosity() == logging.NOTSET),
+                    desc="Creating json from Arrow format",
+                ):
+                    written += file_obj.write(json_str)
+
         return written
diff --git a/tests/io/test_json.py b/tests/io/test_json.py
index 9c7473b0abd..51b32898693 100644
--- a/tests/io/test_json.py
+++ b/tests/io/test_json.py
@@ -211,3 +211,42 @@ def test_dataset_to_json_orient(self, orient, container, keys, len_at, dataset):
             assert len(exported_content[len_at]) == 10
         else:
             assert len(exported_content) == 10
+
+    @pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
+    def test_dataset_to_json_lines_multiproc(self, lines, load_json_function, dataset):
+        with io.BytesIO() as buffer:
+            JsonDatasetWriter(dataset, buffer, lines=lines, num_proc=2).write()
+            buffer.seek(0)
+            exported_content = load_json_function(buffer)
+        assert isinstance(exported_content, list)
+        assert isinstance(exported_content[0], dict)
+        assert len(exported_content) == 10
+
+    @pytest.mark.parametrize(
+        "orient, container, keys, len_at",
+        [
+            ("records", list, {"tokens", "labels", "answers", "id"}, None),
+            ("split", dict, {"index", "columns", "data"}, "data"),
+            ("index", dict, set("0123456789"), None),
+            ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
+            ("values", list, None, None),
+            ("table", dict, {"schema", "data"}, "data"),
+        ],
+    )
+    def test_dataset_to_json_orient_multiproc(self, orient, container, keys, len_at, dataset):
+        with io.BytesIO() as buffer:
+            JsonDatasetWriter(dataset, buffer, lines=False, orient=orient, num_proc=2).write()
+            buffer.seek(0)
+            exported_content = load_json(buffer)
+        assert isinstance(exported_content, container)
+        if keys:
+            if container is dict:
+                assert exported_content.keys() == keys
+            else:
+                assert exported_content[0].keys() == keys
+        else:
+            assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
+        if len_at:
+            assert len(exported_content[len_at]) == 10
+        else:
+            assert len(exported_content) == 10