def to_json(
    self,
    path_or_buf: Union[PathLike, BinaryIO],
    batch_size: Optional[int] = None,
    **to_json_kwargs,
) -> int:
    """Exports the dataset to JSON.

    Args:
        path_or_buf (``PathLike`` or ``BinaryIO``): Either a path to a file
            or an already-open binary file object to write into.
        batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
            Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
        to_json_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_json`.

    Returns:
        int: The number of characters or bytes written.
    """
    # Dynamic import to avoid circular dependency
    # (datasets.io.json imports Dataset from this package).
    from .io.json import JsonDatasetWriter

    return JsonDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_json_kwargs).write()
class JsonDatasetWriter:
    """Writes a :class:`Dataset` to JSON via pandas, one memory-bounded batch at a time."""

    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        **to_json_kwargs,
    ):
        # dataset: the Dataset whose underlying arrow table will be exported.
        # path_or_buf: destination file path, or an already-open binary handle.
        # batch_size: rows per write; falls back to config.DEFAULT_MAX_BATCH_SIZE.
        # to_json_kwargs: forwarded to pandas.DataFrame.to_json for each batch.
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size
        self.to_json_kwargs = to_json_kwargs

    def write(self) -> int:
        """Write the whole dataset and return the number of bytes written.

        Opens (and closes) the file itself when given a path; otherwise writes
        into the caller-supplied handle without closing it.
        """
        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            # "wb" is sufficient: the writer only ever writes, never reads back.
            with open(self.path_or_buf, "wb") as buffer:
                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_json_kwargs)
        else:
            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_json_kwargs)
        return written

    def _write(self, file_obj: BinaryIO, batch_size: int, encoding: str = "utf-8", **to_json_kwargs) -> int:
        """Writes the pyarrow table as JSON to a binary file handle.

        Caller is responsible for opening and closing the handle.

        Returns:
            int: the total number of bytes written.
        """
        written = 0
        # Drop any caller-supplied path_or_buf so each batch is rendered to a
        # string (path_or_buf=None) and we control the byte-level write.
        _ = to_json_kwargs.pop("path_or_buf", None)

        # NOTE(review): with pandas' default orient="columns", writing more than
        # one batch concatenates several JSON objects into one file, which is not
        # a single valid JSON document. Callers exporting large datasets should
        # pass orient="records", lines=True — confirm before relying on defaults.
        for offset in range(0, len(self.dataset), batch_size):
            # query_table resolves any indices mapping (e.g. after select/shuffle)
            # so the slice reflects the dataset's logical row order.
            batch = query_table(
                table=self.dataset.data,
                key=slice(offset, offset + batch_size),
                indices=self.dataset._indices if self.dataset._indices is not None else None,
            )
            json_str = batch.to_pandas().to_json(path_or_buf=None, **to_json_kwargs)
            written += file_obj.write(json_str.encode(encoding))
        return written


def test_dataset_to_json(dataset, tmp_path):
    """End-to-end check: to_json writes a real file whose size matches the
    returned byte count and whose contents round-trip through pandas."""
    file_path = tmp_path / "test_path.jsonl"
    bytes_written = dataset.to_json(path_or_buf=file_path)
    assert file_path.is_file()
    assert bytes_written == file_path.stat().st_size
    df = pd.read_json(file_path)
    assert df.shape == dataset.shape
    assert list(df.columns) == list(dataset.column_names)