Changes from 9 commits
src/datasets/arrow_dataset.py (4 additions, 1 deletion)

@@ -2869,6 +2869,7 @@ def to_json(
         self,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_json_kwargs,
     ) -> int:
         """Export the dataset to JSON Lines or JSON.
@@ -2877,6 +2878,8 @@
             path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
             batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (:obj:`int`, optional): Number of processes for multiprocessing. By default it doesn't
+                use multiprocessing.
             lines (:obj:`bool`, default ``True``): Whether output JSON lines format.
                 Only possible if ``orient="records"``. It will throw ValueError with ``orient`` different from
                 ``"records"``, since the others are not list-like.
@@ -2897,7 +2900,7 @@
         # Dynamic import to avoid circular dependency
         from .io.json import JsonDatasetWriter

-        return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_json_kwargs).write()
+        return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_json_kwargs).write()

     def to_pandas(
         self, batch_size: Optional[int] = None, batched: bool = False
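For reference, the new argument would be used roughly as in the sketch below. The dataset name, output path, and num_proc value are illustrative only; serial export stays the default when num_proc is not passed.

    from datasets import load_dataset

    ds = load_dataset("squad", split="train")   # any Dataset works; "squad" is just an example
    ds.to_json("squad.jsonl")                   # previous behaviour: single process
    ds.to_json("squad.jsonl", num_proc=4)       # new: shard the export across 4 processes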
src/datasets/io/json.py (63 additions, 22 deletions)

@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 from typing import BinaryIO, Optional, Union

@@ -63,49 +64,89 @@ def __init__(
         dataset: Dataset,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_json_kwargs,
     ):
         self.dataset = dataset
         self.path_or_buf = path_or_buf
-        self.batch_size = batch_size
+        if batch_size:
+            self.batch_size = batch_size
+        elif num_proc is not None:
+            self.batch_size = 100_000
+        else:
+            self.batch_size = config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.encoding = "utf-8"
         self.to_json_kwargs = to_json_kwargs

     def write(self) -> int:
-        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        _ = self.to_json_kwargs.pop("path_or_buf", None)
+        orient = self.to_json_kwargs.pop("orient", "records")
+        lines = self.to_json_kwargs.pop("lines", True)

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_json_kwargs)
+            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
         return written

+    def _batch_json(self, args):
+        offset, orient, lines = args
+
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        var = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **self.to_json_kwargs)
+        return var.encode(self.encoding)

     def _write(
         self,
         file_obj: BinaryIO,
-        batch_size: int,
-        encoding: str = "utf-8",
-        orient="records",
-        lines=True,
+        orient,
+        lines,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.

         Caller is responsible for opening and closing the handle.
         """
         written = 0
-        _ = to_json_kwargs.pop("path_or_buf", None)

-        for offset in utils.tqdm(
-            range(0, len(self.dataset), batch_size), unit="ba", disable=bool(logging.get_verbosity() == logging.NOTSET)
-        ):
-            batch = query_table(
-                table=self.dataset.data,
-                key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices if self.dataset._indices is not None else None,
-            )
-            json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
-            if not json_str.endswith("\n"):
-                json_str += "\n"
-            written += file_obj.write(json_str.encode(encoding))
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in utils.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating json from Arrow format",
+            ):
+                batch = query_table(
+                    table=self.dataset.data,
+                    key=slice(offset, offset + self.batch_size),
+                    indices=self.dataset._indices if self.dataset._indices is not None else None,
+                )
+                json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
+                if not json_str.endswith("\n"):
+                    json_str += "\n"
+                written += file_obj.write(json_str.encode(self.encoding))
+        else:
+            pool = multiprocessing.Pool(processes=self.num_proc)
+
+            for json_str in utils.tqdm(
+                pool.imap(
+                    self._batch_json,
+                    [(offset, orient, lines) for offset in range(0, len(self.dataset), self.batch_size)],
+                ),
+                total=(len(self.dataset) // self.batch_size) + 1,
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating json from Arrow format",
+            ):
+                written += file_obj.write(json_str)
+
+            pool.close()
+            pool.join()

         return written
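The parallel branch relies on a standard multiprocessing pattern: workers serialize fixed-size slices, and the parent writes the results to the file handle as they arrive; pool.imap yields results in submission order, so the output file keeps the dataset order. A minimal, self-contained sketch of that pattern follows; the toy DATA records, BATCH_SIZE, and out.jsonl path are invented for illustration and are not part of the PR.

    import json
    import multiprocessing

    DATA = [{"id": i, "text": f"example {i}"} for i in range(10)]  # toy records
    BATCH_SIZE = 3

    def encode_batch(offset):
        # Each worker serializes one slice to JSON Lines and returns bytes,
        # mirroring what JsonDatasetWriter._batch_json does with an Arrow slice.
        batch = DATA[offset:offset + BATCH_SIZE]
        return "".join(json.dumps(record) + "\n" for record in batch).encode("utf-8")

    if __name__ == "__main__":
        with multiprocessing.Pool(processes=2) as pool, open("out.jsonl", "wb") as f:
            written = 0
            # imap yields chunks in submission order, so the file stays in dataset order.
            for chunk in pool.imap(encode_batch, range(0, len(DATA), BATCH_SIZE)):
                written += f.write(chunk)
        print(f"wrote {written} bytes")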