setup.py (2 changes: 1 addition & 1 deletion)

@@ -115,7 +115,7 @@
    # For smart caching dataset processing
    "dill>=0.3.0,<0.3.8",  # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
    # For performance gains with apache arrow
-   "pandas<2.1.0",  # temporary pin
+   "pandas",
    # for downloading datasets over HTTPS
    "requests>=2.19.0",
    # progress bars in download and scripts
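For context, here is a minimal sketch (plain pandas, illustrative data) of the constraint the rest of this PR works around: `DataFrame.to_json` only accepts an explicit `index` argument for certain `orient` values, so unpinning pandas required the writer below to stop passing `index` unconditionally.

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2]})

    # Without an explicit index kwarg, pandas applies its own default.
    print(df.to_json(orient="records", lines=True))

    # An explicit index is rejected for orient="records"; the exact message
    # varies across pandas versions.
    try:
        df.to_json(orient="records", index=False)
    except ValueError as err:
        print(err)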
src/datasets/arrow_dataset.py (2 changes: 1 addition & 1 deletion)

@@ -4844,7 +4844,7 @@ def to_json(

<Changed version="2.11.0">

-Now, `index` defaults to `False` if `orint` is `"split"` or `"table"` is specified.
+Now, `index` defaults to `False` if `orient` is `"split"` or `"table"`.

If you would like to write the index, pass `index=True`.

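A small usage sketch of the documented behavior (dataset contents and file names are illustrative):

    from datasets import Dataset

    ds = Dataset.from_dict({"a": [1, 2, 3]})

    # Default export: JSON Lines with orient="records"; index is not involved.
    ds.to_json("data.jsonl")

    # Since 2.11.0, orient="split" (and "table") defaults to index=False;
    # pass index=True to write the index.
    ds.to_json("data_split.json", orient="split", index=True)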
src/datasets/io/json.py (20 changes: 8 additions & 12 deletions)

@@ -93,37 +93,34 @@ def write(self) -> int:
        _ = self.to_json_kwargs.pop("path_or_buf", None)
        orient = self.to_json_kwargs.pop("orient", "records")
        lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
-       index = self.to_json_kwargs.pop("index", False if orient in ["split", "table"] else True)
+       if "index" not in self.to_json_kwargs and orient in ["split", "table"]:
+           self.to_json_kwargs["index"] = False
        compression = self.to_json_kwargs.pop("compression", None)

        if compression not in [None, "infer", "gzip", "bz2", "xz"]:
            raise NotImplementedError(f"`datasets` currently does not support {compression} compression")

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with fsspec.open(self.path_or_buf, "wb", compression=compression) as buffer:
-               written = self._write(file_obj=buffer, orient=orient, lines=lines, index=index, **self.to_json_kwargs)
+               written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
        else:
            if compression:
                raise NotImplementedError(
                    f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                    " was passed. Please provide a local path instead."
                )
-           written = self._write(
-               file_obj=self.path_or_buf, orient=orient, lines=lines, index=index, **self.to_json_kwargs
-           )
+           written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
        return written

    def _batch_json(self, args):
-       offset, orient, lines, index, to_json_kwargs = args
+       offset, orient, lines, to_json_kwargs = args

        batch = query_table(
            table=self.dataset.data,
            key=slice(offset, offset + self.batch_size),
            indices=self.dataset._indices,
        )
-       json_str = batch.to_pandas().to_json(
-           path_or_buf=None, orient=orient, lines=lines, index=index, **to_json_kwargs
-       )
+       json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
        if not json_str.endswith("\n"):
            json_str += "\n"
        return json_str.encode(self.encoding)
@@ -133,7 +130,6 @@ def _write(
        file_obj: BinaryIO,
        orient,
        lines,
-       index,
        **to_json_kwargs,
    ) -> int:
        """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -149,15 +145,15 @@
                disable=not logging.is_progress_bar_enabled(),
                desc="Creating json from Arrow format",
            ):
-               json_str = self._batch_json((offset, orient, lines, index, to_json_kwargs))
+               json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
                written += file_obj.write(json_str)
        else:
            num_rows, batch_size = len(self.dataset), self.batch_size
            with multiprocessing.Pool(self.num_proc) as pool:
                for json_str in logging.tqdm(
                    pool.imap(
                        self._batch_json,
-                       [(offset, orient, lines, index, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
+                       [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                    ),
                    total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                    unit="ba",
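Taken together, the io/json.py changes reduce to the kwarg-forwarding pattern sketched below; this is a standalone approximation rather than the library code, and `write_batch_json` is a hypothetical name. The point is that `index` reaches `pandas.DataFrame.to_json` only when the orient supports it, while caller-supplied kwargs always win.

    import pandas as pd

    def write_batch_json(df: pd.DataFrame, orient: str = "records", **to_json_kwargs) -> str:
        # Mirror the writer's default: JSON Lines only for orient="records".
        lines = to_json_kwargs.pop("lines", orient == "records")
        # Inject index=False only when the orient supports an index and the
        # caller did not choose a value themselves.
        if "index" not in to_json_kwargs and orient in ["split", "table"]:
            to_json_kwargs["index"] = False
        return df.to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)

    df = pd.DataFrame({"a": [1, 2]})
    print(write_batch_json(df))                  # records + lines, index untouched
    print(write_batch_json(df, orient="split"))  # index=False injected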