
Commit e8e31dd

Fix to_json ValueError and remove pandas pin (#6201)
* Unpin pandas
* Fix JsonDatasetWriter
* Fix typo in docstring
* Leave default index for orient different from split or table
* Pass index within to_json_kwargs when relevant
1 parent 5cf9bbe commit e8e31dd
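For context, the `ValueError` in the commit title comes from `DataFrame.to_json` in pandas 2.1, which started rejecting an explicit `index` argument for row-oriented output instead of silently ignoring it; the old writer always passed `index`, so the default `orient="records"` path broke. A minimal repro sketch, assuming pandas >= 2.1 (not part of the diff):

import pandas as pd

df = pd.DataFrame({"col_1": [1, 2]})

# Accepted: "split" and "table" orients support an explicit index argument.
df.to_json(orient="table", index=False)

# On pandas >= 2.1 this raises, roughly:
# ValueError: 'index=True' is only valid when 'orient' is 'split', 'table', 'index', or 'columns'.
df.to_json(orient="records", lines=True, index=True)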

File tree

3 files changed: +10 -14 lines changed

setup.py
src/datasets/arrow_dataset.py
src/datasets/io/json.py

setup.py

Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@
     # For smart caching dataset processing
     "dill>=0.3.0,<0.3.8",  # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
     # For performance gains with apache arrow
-    "pandas<2.1.0",  # temporary pin
+    "pandas",
     # for downloading datasets over HTTPS
     "requests>=2.19.0",
     # progress bars in download and scripts

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -4844,7 +4844,7 @@ def to_json(

         <Changed version="2.11.0">

-        Now, `index` defaults to `False` if `orint` is `"split"` or `"table"` is specified.
+        Now, `index` defaults to `False` if `orient` is `"split"` or `"table"`.

         If you would like to write the index, pass `index=True`.
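As a quick illustration of the documented default, a usage sketch (hypothetical file names):

from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2], "b": ["x", "y"]})

# With orient="table", the index is now omitted by default.
ds.to_json("data.json", orient="table")

# Pass index=True explicitly if you still want the index written.
ds.to_json("data_with_index.json", orient="table", index=True)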

src/datasets/io/json.py

Lines changed: 8 additions & 12 deletions

@@ -93,37 +93,34 @@ def write(self) -> int:
         _ = self.to_json_kwargs.pop("path_or_buf", None)
         orient = self.to_json_kwargs.pop("orient", "records")
         lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
-        index = self.to_json_kwargs.pop("index", False if orient in ["split", "table"] else True)
+        if "index" not in self.to_json_kwargs and orient in ["split", "table"]:
+            self.to_json_kwargs["index"] = False
         compression = self.to_json_kwargs.pop("compression", None)

         if compression not in [None, "infer", "gzip", "bz2", "xz"]:
             raise NotImplementedError(f"`datasets` currently does not support {compression} compression")

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with fsspec.open(self.path_or_buf, "wb", compression=compression) as buffer:
-                written = self._write(file_obj=buffer, orient=orient, lines=lines, index=index, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
         else:
             if compression:
                 raise NotImplementedError(
                     f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                     " was passed. Please provide a local path instead."
                 )
-            written = self._write(
-                file_obj=self.path_or_buf, orient=orient, lines=lines, index=index, **self.to_json_kwargs
-            )
+            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
         return written

     def _batch_json(self, args):
-        offset, orient, lines, index, to_json_kwargs = args
+        offset, orient, lines, to_json_kwargs = args

         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
-        json_str = batch.to_pandas().to_json(
-            path_or_buf=None, orient=orient, lines=lines, index=index, **to_json_kwargs
-        )
+        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
         if not json_str.endswith("\n"):
             json_str += "\n"
         return json_str.encode(self.encoding)
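Net effect of the write() change above: `index` is no longer a dedicated writer argument; it stays inside `to_json_kwargs` and is only defaulted to `False` when the orient actually accepts an explicit index. A standalone sketch of that kwarg handling (a hypothetical resolve_to_json_kwargs helper that mirrors the logic; not a datasets API):

def resolve_to_json_kwargs(to_json_kwargs: dict) -> dict:
    # Mirrors JsonDatasetWriter.write(): pop orient/lines, inject index only for split/table.
    kwargs = dict(to_json_kwargs)
    orient = kwargs.pop("orient", "records")
    lines = kwargs.pop("lines", orient == "records")
    if "index" not in kwargs and orient in ["split", "table"]:
        kwargs["index"] = False
    return {"orient": orient, "lines": lines, **kwargs}

assert "index" not in resolve_to_json_kwargs({})  # records: pandas default untouched
assert resolve_to_json_kwargs({"orient": "table"})["index"] is False
assert resolve_to_json_kwargs({"orient": "table", "index": True})["index"] is True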
@@ -133,7 +130,6 @@ def _write(
         file_obj: BinaryIO,
         orient,
         lines,
-        index,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -149,15 +145,15 @@ def _write(
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating json from Arrow format",
             ):
-                json_str = self._batch_json((offset, orient, lines, index, to_json_kwargs))
+                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
                 written += file_obj.write(json_str)
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
                 for json_str in logging.tqdm(
                     pool.imap(
                         self._batch_json,
-                        [(offset, orient, lines, index, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
