7 changes: 6 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -2856,6 +2856,7 @@ def to_json(
self,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_json_kwargs,
) -> int:
"""Export the dataset to JSON Lines or JSON.
@@ -2864,6 +2865,10 @@ def to_json(
path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
num_proc (:obj:`int`, optional): Number of processes for multiprocessing. By default it doesn't
use multiprocessing. ``batch_size`` still defaults to
:obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE` in this case, but feel free to increase it to 5x or 10x
the default value if you have sufficient compute power.
lines (:obj:`bool`, default ``True``): Whether to output JSON Lines format.
Only possible if ``orient="records"``. It will throw a ValueError with ``orient`` different from
``"records"``, since the other orients are not list-like.
@@ -2884,7 +2889,7 @@ def to_json(
# Dynamic import to avoid circular dependency
from .io.json import JsonDatasetWriter

return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_json_kwargs).write()
return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_json_kwargs).write()

def to_pandas(
self, batch_size: Optional[int] = None, batched: bool = False
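For context, here is a minimal usage sketch of the new parameter. The dataset contents, file names, and the ``num_proc``/``batch_size`` values are illustrative assumptions, not part of the PR:

```python
from datasets import Dataset

# A small in-memory dataset; any Dataset instance works the same way.
ds = Dataset.from_dict({"id": list(range(1000)), "text": ["example"] * 1000})

# Default behaviour: a single process, batches of config.DEFAULT_MAX_BATCH_SIZE rows.
ds.to_json("out_single.jsonl")

# New behaviour: batches are serialized by 4 worker processes and written
# to the file in order. A larger batch_size can help when memory allows.
ds.to_json("out_multiproc.jsonl", num_proc=4, batch_size=10_000)
```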
75 changes: 53 additions & 22 deletions src/datasets/io/json.py
@@ -1,3 +1,4 @@
import multiprocessing
import os
from typing import BinaryIO, Optional, Union

@@ -63,49 +64,79 @@ def __init__(
dataset: Dataset,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_json_kwargs,
):
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
self.dataset = dataset
self.path_or_buf = path_or_buf
self.batch_size = batch_size
self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
self.num_proc = num_proc
self.encoding = "utf-8"
self.to_json_kwargs = to_json_kwargs

def write(self) -> int:
batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
_ = self.to_json_kwargs.pop("path_or_buf", None)
orient = self.to_json_kwargs.pop("orient", "records")
lines = self.to_json_kwargs.pop("lines", True)

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with open(self.path_or_buf, "wb+") as buffer:
written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_json_kwargs)
written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
else:
written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_json_kwargs)
written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
return written

def _batch_json(self, args):
offset, orient, lines, to_json_kwargs = args

batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)
json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
if not json_str.endswith("\n"):
json_str += "\n"
return json_str.encode(self.encoding)

def _write(
self,
file_obj: BinaryIO,
batch_size: int,
encoding: str = "utf-8",
orient="records",
lines=True,
orient,
lines,
**to_json_kwargs,
) -> int:
"""Writes the pyarrow table as JSON lines to a binary file handle.

Caller is responsible for opening and closing the handle.
"""
written = 0
_ = to_json_kwargs.pop("path_or_buf", None)

for offset in utils.tqdm(
range(0, len(self.dataset), batch_size), unit="ba", disable=bool(logging.get_verbosity() == logging.NOTSET)
):
batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + batch_size),
indices=self.dataset._indices if self.dataset._indices is not None else None,
)
json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
if not json_str.endswith("\n"):
json_str += "\n"
written += file_obj.write(json_str.encode(encoding))

if self.num_proc is None or self.num_proc == 1:
for offset in utils.tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=bool(logging.get_verbosity() == logging.NOTSET),
desc="Creating json from Arrow format",
):
json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
written += file_obj.write(json_str)
else:
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in utils.tqdm(
pool.imap(
self._batch_json,
[
(offset, orient, lines, to_json_kwargs)
for offset in range(0, len(self.dataset), self.batch_size)
],
),
total=(len(self.dataset) // self.batch_size) + 1,
unit="ba",
disable=bool(logging.get_verbosity() == logging.NOTSET),
desc="Creating json from Arrow format",
):
written += file_obj.write(json_str)

return written
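The core of the change is that ``_batch_json`` serializes one fixed-size slice of the table, and ``_write`` either loops over the slices in-process or hands the offsets to a ``multiprocessing.Pool`` and writes the returned chunks in order (``imap`` preserves input order). A self-contained sketch of that pattern, independent of ``datasets`` and with illustrative names and data:

```python
import json
import multiprocessing

ROWS = [{"id": i, "value": i * i} for i in range(10_000)]
BATCH_SIZE = 1_000


def batch_to_jsonl(offset):
    # Serialize one slice of rows to a JSON Lines chunk, mirroring _batch_json.
    lines = (json.dumps(row) for row in ROWS[offset : offset + BATCH_SIZE])
    return ("\n".join(lines) + "\n").encode("utf-8")


if __name__ == "__main__":
    with open("out.jsonl", "wb") as f, multiprocessing.Pool(4) as pool:
        # imap preserves input order, so the chunks land in the file in the
        # same order the single-process loop would have produced them.
        for chunk in pool.imap(batch_to_jsonl, range(0, len(ROWS), BATCH_SIZE)):
            f.write(chunk)
```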
39 changes: 39 additions & 0 deletions tests/io/test_json.py
@@ -211,3 +211,42 @@ def test_dataset_to_json_orient(self, orient, container, keys, len_at, dataset):
assert len(exported_content[len_at]) == 10
else:
assert len(exported_content) == 10

@pytest.mark.parametrize("lines, load_json_function", [(True, load_json_lines), (False, load_json)])
def test_dataset_to_json_lines_multiproc(self, lines, load_json_function, dataset):
with io.BytesIO() as buffer:
JsonDatasetWriter(dataset, buffer, lines=lines, num_proc=2).write()
buffer.seek(0)
exported_content = load_json_function(buffer)
assert isinstance(exported_content, list)
assert isinstance(exported_content[0], dict)
assert len(exported_content) == 10

@pytest.mark.parametrize(
"orient, container, keys, len_at",
[
("records", list, {"tokens", "labels", "answers", "id"}, None),
("split", dict, {"index", "columns", "data"}, "data"),
("index", dict, set("0123456789"), None),
("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
("values", list, None, None),
("table", dict, {"schema", "data"}, "data"),
],
)
def test_dataset_to_json_orient_multiproc(self, orient, container, keys, len_at, dataset):
with io.BytesIO() as buffer:
JsonDatasetWriter(dataset, buffer, lines=False, orient=orient, num_proc=2).write()
buffer.seek(0)
exported_content = load_json(buffer)
assert isinstance(exported_content, container)
if keys:
if container is dict:
assert exported_content.keys() == keys
else:
assert exported_content[0].keys() == keys
else:
assert not hasattr(exported_content, "keys") and not hasattr(exported_content[0], "keys")
if len_at:
assert len(exported_content[len_at]) == 10
else:
assert len(exported_content) == 10
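Outside the test suite, the multiprocess path can also be checked end-to-end with a quick round trip; the file name and toy dataset below are hypothetical:

```python
import json

from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(10)), "label": ["a"] * 10})
ds.to_json("roundtrip.jsonl", num_proc=2)

# Default lines=True, orient="records": one JSON object per line.
with open("roundtrip.jsonl", encoding="utf-8") as f:
    rows = [json.loads(line) for line in f]

assert len(rows) == len(ds)
assert rows[0] == {"id": 0, "label": "a"}
```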