Skip to content

Commit 1a9385d

Browse files
authored
Add support for parsing JSON files in array form (#4997)
* Support parsing JSON lists * Add error handling * Minor improvements * Add tests * Comment
1 parent ace149f commit 1a9385d

File tree

3 files changed

+109
-8
lines changed

3 files changed

+109
-8
lines changed

src/datasets/packaged_modules/json/json.py

Lines changed: 37 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,25 @@
1515
logger = datasets.utils.logging.get_logger(__name__)
1616

1717

18+
if datasets.config.PYARROW_VERSION.major >= 7:
19+
20+
def pa_table_from_pylist(mapping):
21+
return pa.Table.from_pylist(mapping)
22+
23+
else:
24+
25+
def pa_table_from_pylist(mapping):
26+
# Copied from: https://github.com/apache/arrow/blob/master/python/pyarrow/table.pxi#L5193
27+
arrays = []
28+
names = []
29+
if mapping:
30+
names = list(mapping[0].keys())
31+
for n in names:
32+
v = [row[n] if n in row else None for row in mapping]
33+
arrays.append(v)
34+
return pa.Table.from_arrays(arrays, names)
35+
36+
1837
@dataclass
1938
class JsonConfig(datasets.BuilderConfig):
2039
"""BuilderConfig for JSON."""
@@ -125,18 +144,29 @@ def _generate_tables(self, files):
125144
)
126145
block_size *= 2
127146
except pa.ArrowInvalid as e:
128-
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
129147
try:
130148
with open(file, encoding="utf-8") as f:
131149
dataset = json.load(f)
132150
except json.JSONDecodeError:
151+
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
133152
raise e
134-
raise ValueError(
135-
f"Not able to read records in the JSON file at {file}. "
136-
f"You should probably indicate the field of the JSON file containing your records. "
137-
f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
138-
f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
139-
) from None
153+
# If possible, parse the file as a list of json objects and exit the loop
154+
if isinstance(dataset, list): # list is the only sequence type supported in JSON
155+
try:
156+
pa_table = pa_table_from_pylist(dataset)
157+
except (pa.ArrowInvalid, AttributeError) as e:
158+
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
159+
raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
160+
yield file_idx, self._cast_table(pa_table)
161+
break
162+
else:
163+
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
164+
raise ValueError(
165+
f"Not able to read records in the JSON file at {file}. "
166+
f"You should probably indicate the field of the JSON file containing your records. "
167+
f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
168+
f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
169+
) from None
140170
# Uncomment for debugging (will print the Arrow table size and elements)
141171
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
142172
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))

tests/packaged_modules/test_csv.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
@pytest.fixture
1414
def csv_file(tmp_path):
15-
filename = tmp_path / "malformed_file.csv"
15+
filename = tmp_path / "file.csv"
1616
data = textwrap.dedent(
1717
"""\
1818
header1,header2
Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,71 @@
1+
import textwrap
2+
3+
import pyarrow as pa
4+
import pytest
5+
6+
from datasets.packaged_modules.json.json import Json
7+
8+
9+
@pytest.fixture
10+
def jsonl_file(tmp_path):
11+
filename = tmp_path / "file.jsonl"
12+
data = textwrap.dedent(
13+
"""\
14+
{"col_1": 1, "col_2": 2}
15+
{"col_1": 10, "col_2": 20}
16+
"""
17+
)
18+
with open(filename, "w") as f:
19+
f.write(data)
20+
return str(filename)
21+
22+
23+
@pytest.fixture
24+
def json_file_with_list_of_dicts(tmp_path):
25+
filename = tmp_path / "file_with_list_of_dicts.json"
26+
data = textwrap.dedent(
27+
"""\
28+
[
29+
{"col_1": 1, "col_2": 2},
30+
{"col_1": 10, "col_2": 20}
31+
]
32+
"""
33+
)
34+
with open(filename, "w") as f:
35+
f.write(data)
36+
return str(filename)
37+
38+
39+
@pytest.fixture
40+
def json_file_with_list_of_dicts_field(tmp_path):
41+
filename = tmp_path / "file_with_list_of_dicts_field.json"
42+
data = textwrap.dedent(
43+
"""\
44+
{
45+
"field1": 1,
46+
"field2": "aabb",
47+
"field3": [
48+
{"col_1": 1, "col_2": 2},
49+
{"col_1": 10, "col_2": 20}
50+
]
51+
}
52+
"""
53+
)
54+
with open(filename, "w") as f:
55+
f.write(data)
56+
return str(filename)
57+
58+
59+
@pytest.mark.parametrize(
60+
"file_fixture, config_kwargs",
61+
[
62+
("jsonl_file", {}),
63+
("json_file_with_list_of_dicts", {}),
64+
("json_file_with_list_of_dicts_field", {"field": "field3"}),
65+
],
66+
)
67+
def test_json_generate_tables(file_fixture, config_kwargs, request):
68+
json = Json(**config_kwargs)
69+
generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
70+
pa_table = pa.concat_tables([table for _, table in generator])
71+
assert pa_table.to_pydict() == {"col_1": [1, 10], "col_2": [2, 20]}

0 commit comments

Comments
 (0)