110 changes: 66 additions & 44 deletions src/datasets/packaged_modules/json/json.py
@@ -11,24 +11,20 @@
import datasets


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class JsonConfig(datasets.BuilderConfig):
"""BuilderConfig for JSON."""

features: Optional[datasets.Features] = None
field: Optional[str] = None
use_threads: bool = True
block_size: Optional[int] = None
use_threads: bool = True # deprecated
block_size: Optional[int] = None # deprecated
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None

@property
def pa_read_options(self):
return paj.ReadOptions(use_threads=self.use_threads, block_size=self.block_size)

@property
def pa_parse_options(self):
return paj.ParseOptions(newlines_in_values=self.newlines_in_values)

@property
def schema(self):
return pa.schema(self.features.type) if self.features is not None else None
@@ -38,6 +34,15 @@ class Json(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig

def _info(self):
if self.config.block_size is not None:
logger.warning("The JSON loader parameter `block_size` is deprecated. Please use `chunksize` instead")
self.config.chunksize = self.config.block_size
if self.config.use_threads is not True:
logger.warning(
"The JSON loader parameter `use_threads` is deprecated and doesn't have any effect anymore."
)
if self.config.newlines_in_values is not None:
raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
return datasets.DatasetInfo(features=self.config.features)
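As a side note, a minimal sketch of how the deprecation shim above looks from the user side; the file name and the `load_dataset` call are illustrative assumptions, not part of this diff.

from datasets import load_dataset

# Passing the deprecated `block_size` still works: the loader logs a warning
# and forwards the value to `chunksize` (file name is illustrative).
ds = load_dataset("json", data_files="example.jsonl", block_size=4 << 20, split="train")

# Equivalent call with the new parameter name.
ds = load_dataset("json", data_files="example.jsonl", chunksize=4 << 20, split="train")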

def _split_generators(self, dl_manager):
@@ -57,8 +62,25 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_classlabels(self, pa_table: pa.Table) -> pa.Table:
if self.config.features:
# Encode column if ClassLabel
for i, col in enumerate(self.config.features.keys()):
if isinstance(self.config.features[col], datasets.ClassLabel):
pa_table = pa_table.set_column(
i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
)
# Cast allows str <-> int/float
# Before casting, rearrange JSON field names to match passed features schema field names order
pa_table = pa.Table.from_arrays(
[pa_table[name] for name in self.config.features], schema=self.config.schema
)
return pa_table
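To make the ClassLabel encoding above concrete, here is a small self-contained sketch of the same idea; the feature names, label names and sample table are illustrative assumptions, and `.to_pylist()` is used for readability rather than passing the Arrow column directly as the builder does.

import pyarrow as pa
from datasets import ClassLabel, Features, Value

# Illustrative features: the raw JSON stores "label" as strings,
# while the ClassLabel feature expects integer ids.
features = Features({"text": Value("string"), "label": ClassLabel(names=["neg", "pos"])})
schema = pa.schema(features.type)

pa_table = pa.table({"text": ["good", "bad"], "label": ["pos", "neg"]})

# Encode string labels as integers, mirroring _cast_classlabels above.
for i, col in enumerate(features):
    if isinstance(features[col], ClassLabel):
        encoded = pa.array(features[col].str2int(pa_table[col].to_pylist()))
        pa_table = pa_table.set_column(i, schema.field(col), encoded)

# Reorder the columns to the features order and apply the schema.
pa_table = pa.Table.from_arrays([pa_table[name] for name in features], schema=schema)
print(pa_table.to_pydict())  # {'text': ['good', 'bad'], 'label': [1, 0]}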

def _generate_tables(self, files):
for i, file in enumerate(files):
for file_idx, file in enumerate(files):

# If the file is one json object and we need to look at the list of items in one specific field
if self.config.field is not None:
with open(file, encoding="utf-8") as f:
dataset = json.load(f)
@@ -68,38 +90,38 @@ def _generate_tables(self, files):

# We accept two formats: a list of dicts or a dict of lists
if isinstance(dataset, (list, tuple)):
pa_table = paj.read_json(
BytesIO("\n".join(json.dumps(row) for row in dataset).encode("utf-8")),
read_options=self.config.pa_read_options,
parse_options=self.config.pa_parse_options,
)
mapping = {col: [dataset[i][col] for i in range(len(dataset))] for col in dataset[0].keys()}
else:
pa_table = pa.Table.from_pydict(mapping=dataset)
mapping = dataset
pa_table = pa.Table.from_pydict(mapping=mapping)
yield file_idx, self._cast_classlabels(pa_table)

# If the file has one json object per line
else:
try:
with open(file, "rb") as f:
pa_table = paj.read_json(
f, read_options=self.config.pa_read_options, parse_options=self.config.pa_parse_options
)
except pa.ArrowInvalid:
with open(file, encoding="utf-8") as f:
dataset = json.load(f)
raise ValueError(
f"Not able to read records in the JSON file at {file}. "
f"You should probably indicate the field of the JSON file containing your records. "
f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
)
if self.config.features:
# Encode column if ClassLabel
for i, col in enumerate(self.config.features.keys()):
if isinstance(self.config.features[col], datasets.ClassLabel):
pa_table = pa_table.set_column(
i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
)
# Cast allows str <-> int/float, while parse_option explicit_schema does NOT
# Before casting, rearrange JSON field names to match passed features schema field names order
pa_table = pa.Table.from_arrays(
[pa_table[name] for name in self.config.features], schema=self.config.schema
)
yield i, pa_table
with open(file, "rb") as f:
batch_idx = 0
while True:
batch = f.read(self.config.chunksize)
if not batch:
break
batch += f.readline() # finish current line
try:
pa_table = paj.read_json(BytesIO(batch))
except pa.ArrowInvalid as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
try:
with open(file, encoding="utf-8") as f:
dataset = json.load(f)
except json.JSONDecodeError:
raise e
raise ValueError(
f"Not able to read records in the JSON file at {file}. "
f"You should probably indicate the field of the JSON file containing your records. "
f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), self._cast_classlabels(pa_table)
batch_idx += 1
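The new chunked reading loop can be illustrated outside the builder. The sketch below is a minimal, self-contained version of the same approach (read `chunksize` bytes, extend the batch to the next newline so no record is split across batches, parse each batch with pyarrow); the file name, its contents and the chunk size are illustrative assumptions, not part of this diff.

import json
from io import BytesIO

import pyarrow.json as paj

# Write a small JSON Lines file to read back in chunks (illustrative data).
with open("example.jsonl", "w", encoding="utf-8") as f:
    for i in range(1000):
        f.write(json.dumps({"id": i, "text": f"row {i}"}) + "\n")

chunksize = 4 << 10  # 4 KiB per batch, deliberately small for the example

with open("example.jsonl", "rb") as f:
    batch_idx = 0
    while True:
        batch = f.read(chunksize)
        if not batch:
            break
        batch += f.readline()  # extend the batch to the end of the current line
        pa_table = paj.read_json(BytesIO(batch))
        print(batch_idx, pa_table.num_rows)
        batch_idx += 1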
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -234,7 +234,7 @@ def jsonl_312_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl")
with open(path, "w") as f:
for item in DATA_312:
f.write(json.dumps(item))
f.write(json.dumps(item) + "\n")
return path


@@ -243,7 +243,7 @@ def jsonl_str_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset-str.jsonl")
with open(path, "w") as f:
for item in DATA_STR:
f.write(json.dumps(item))
f.write(json.dumps(item) + "\n")
return path


17 changes: 15 additions & 2 deletions tests/test_dataset_common.py
@@ -84,6 +84,16 @@ def get_packaged_dataset_dummy_data_files(dataset_name, path_to_dummy_data):
}


def get_packaged_dataset_config_attributes(dataset_name):
if dataset_name == "json":
# The json dummy data are formatted as the squad format
# which has the list of examples in the field named "data".
# Therefore we have to tell the json loader to load this field.
return {"field": "data"}
else:
return {}


class DatasetTester:
def __init__(self, parent):
self.parent = parent if parent is not None else TestCase()
@@ -141,12 +151,15 @@ def check_if_url_is_valid(url):
)

# packaged datasets like csv, text, json or pandas require some data files
if dataset_builder.__class__.__name__.lower() in _PACKAGED_DATASETS_MODULES:
builder_name = dataset_builder.__class__.__name__.lower()
if builder_name in _PACKAGED_DATASETS_MODULES:
mock_dl_manager.download_dummy_data()
path_to_dummy_data = mock_dl_manager.dummy_file
dataset_builder.config.data_files = get_packaged_dataset_dummy_data_files(
dataset_builder.__class__.__name__.lower(), path_to_dummy_data
builder_name, path_to_dummy_data
)
for config_attr, value in get_packaged_dataset_config_attributes(builder_name).items():
setattr(dataset_builder.config, config_attr, value)

# mock size needed for dummy data instead of actual dataset
if dataset_builder.info is not None:
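For reference, a minimal sketch of how the `field='data'` attribute set above is used when loading a squad-style JSON file; the file name, its contents and the direct `load_dataset` call are illustrative assumptions rather than part of the test suite.

import json

from datasets import load_dataset

# A squad-style JSON file: the list of examples lives under the "data" field.
with open("squad_style.json", "w", encoding="utf-8") as f:
    json.dump({"version": "0.1", "data": [{"question": "Q1", "answer": "A1"},
                                          {"question": "Q2", "answer": "A2"}]}, f)

# Tell the JSON loader which field holds the records.
ds = load_dataset("json", data_files="squad_style.json", field="data", split="train")
print(ds[0])  # {'question': 'Q1', 'answer': 'A1'}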