Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,28 @@ def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Un
for offset in range(0, len(self), batch_size)
)

def to_json(
self,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
**to_json_kwargs,
) -> int:
"""Exports the dataset to JSON.

Args:
path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
to_json_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_json`

Returns:
int: The number of characters or bytes written
"""
# Dynamic import to avoid circular dependency
from .io.json import JsonDatasetWriter

return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_json_kwargs).write()

def to_pandas(
self, batch_size: Optional[int] = None, batched: bool = False
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
Expand Down
48 changes: 46 additions & 2 deletions src/datasets/io/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional
import os
from typing import BinaryIO, Optional, Union

from .. import Features, NamedSplit
from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.json.json import Json
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader
Expand Down Expand Up @@ -52,3 +54,45 @@ def read(self):
split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
)
return dataset


class JsonDatasetWriter:
def __init__(
self,
dataset: Dataset,
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
**to_json_kwargs,
):
self.dataset = dataset
self.path_or_buf = path_or_buf
self.batch_size = batch_size
self.to_json_kwargs = to_json_kwargs

def write(self) -> int:
batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with open(self.path_or_buf, "wb+") as buffer:
written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_json_kwargs)
else:
written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_json_kwargs)
return written

def _write(self, file_obj: BinaryIO, batch_size: int, encoding: str = "utf-8", **to_json_kwargs) -> int:
"""Writes the pyarrow table as JSON to a binary file handle.

Caller is responsible for opening and closing the handle.
"""
written = 0
_ = to_json_kwargs.pop("path_or_buf", None)

for offset in range(0, len(self.dataset), batch_size):
batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + batch_size),
indices=self.dataset._indices if self.dataset._indices is not None else None,
)
json_str = batch.to_pandas().to_json(path_or_buf=None, **to_json_kwargs)
written += file_obj.write(json_str.encode(encoding))
return written
10 changes: 10 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2077,6 +2077,16 @@ def test_dataset_from_text(path_type, split, features, keep_in_memory, text_path
assert dataset.features[feature].dtype == expected_dtype


def test_dataset_to_json(dataset, tmp_path):
file_path = tmp_path / "test_path.jsonl"
bytes_written = dataset.to_json(path_or_buf=file_path)
assert file_path.is_file()
assert bytes_written == file_path.stat().st_size
df = pd.read_json(file_path)
assert df.shape == dataset.shape
assert list(df.columns) == list(dataset.column_names)


@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
"method_and_params",
Expand Down