diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index c0a085b0731..71dbad71647 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -11,24 +11,20 @@
 import datasets
 
 
+logger = datasets.utils.logging.get_logger(__name__)
+
+
 @dataclass
 class JsonConfig(datasets.BuilderConfig):
     """BuilderConfig for JSON."""
 
     features: Optional[datasets.Features] = None
     field: Optional[str] = None
-    use_threads: bool = True
-    block_size: Optional[int] = None
+    use_threads: bool = True  # deprecated
+    block_size: Optional[int] = None  # deprecated
+    chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
 
-    @property
-    def pa_read_options(self):
-        return paj.ReadOptions(use_threads=self.use_threads, block_size=self.block_size)
-
-    @property
-    def pa_parse_options(self):
-        return paj.ParseOptions(newlines_in_values=self.newlines_in_values)
-
     @property
     def schema(self):
         return pa.schema(self.features.type) if self.features is not None else None
@@ -38,6 +34,15 @@ class Json(datasets.ArrowBasedBuilder):
     BUILDER_CONFIG_CLASS = JsonConfig
 
     def _info(self):
+        if self.config.block_size is not None:
+            logger.warning("The JSON loader parameter `block_size` is deprecated. Please use `chunksize` instead")
+            self.config.chunksize = self.config.block_size
+        if self.config.use_threads is not True:
+            logger.warning(
+                "The JSON loader parameter `use_threads` is deprecated and doesn't have any effect anymore."
+            )
+        if self.config.newlines_in_values is not None:
+            raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
         return datasets.DatasetInfo(features=self.config.features)
 
     def _split_generators(self, dl_manager):
@@ -57,8 +62,25 @@ def _split_generators(self, dl_manager):
             splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
         return splits
 
+    def _cast_classlabels(self, pa_table: pa.Table) -> pa.Table:
+        if self.config.features:
+            # Encode column if ClassLabel
+            for i, col in enumerate(self.config.features.keys()):
+                if isinstance(self.config.features[col], datasets.ClassLabel):
+                    pa_table = pa_table.set_column(
+                        i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
+                    )
+            # Cast allows str <-> int/float
+            # Before casting, rearrange JSON field names to match passed features schema field names order
+            pa_table = pa.Table.from_arrays(
+                [pa_table[name] for name in self.config.features], schema=self.config.schema
+            )
+        return pa_table
+
     def _generate_tables(self, files):
-        for i, file in enumerate(files):
+        for file_idx, file in enumerate(files):
+
+            # If the file is one json object and we need to look at the list of items in one specific field
             if self.config.field is not None:
                 with open(file, encoding="utf-8") as f:
                     dataset = json.load(f)
@@ -68,38 +90,38 @@ def _generate_tables(self, files):
 
                 # We accept two format: a list of dicts or a dict of lists
                 if isinstance(dataset, (list, tuple)):
-                    pa_table = paj.read_json(
-                        BytesIO("\n".join(json.dumps(row) for row in dataset).encode("utf-8")),
-                        read_options=self.config.pa_read_options,
-                        parse_options=self.config.pa_parse_options,
-                    )
+                    mapping = {col: [dataset[i][col] for i in range(len(dataset))] for col in dataset[0].keys()}
                 else:
-                    pa_table = pa.Table.from_pydict(mapping=dataset)
+                    mapping = dataset
+                pa_table = pa.Table.from_pydict(mapping=mapping)
+                yield file_idx, self._cast_classlabels(pa_table)
+
+            # If the file has one json object per line
             else:
-                try:
-                    with open(file, "rb") as f:
-                        pa_table = paj.read_json(
-                            f, read_options=self.config.pa_read_options, parse_options=self.config.pa_parse_options
-                        )
-                except pa.ArrowInvalid:
-                    with open(file, encoding="utf-8") as f:
-                        dataset = json.load(f)
-                    raise ValueError(
-                        f"Not able to read records in the JSON file at {file}. "
-                        f"You should probably indicate the field of the JSON file containing your records. "
-                        f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
-                        f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
-                    )
-            if self.config.features:
-                # Encode column if ClassLabel
-                for i, col in enumerate(self.config.features.keys()):
-                    if isinstance(self.config.features[col], datasets.ClassLabel):
-                        pa_table = pa_table.set_column(
-                            i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
-                        )
-                # Cast allows str <-> int/float, while parse_option explicit_schema does NOT
-                # Before casting, rearrange JSON field names to match passed features schema field names order
-                pa_table = pa.Table.from_arrays(
-                    [pa_table[name] for name in self.config.features], schema=self.config.schema
-                )
-            yield i, pa_table
+                with open(file, "rb") as f:
+                    batch_idx = 0
+                    while True:
+                        batch = f.read(self.config.chunksize)
+                        if not batch:
+                            break
+                        batch += f.readline()  # finish current line
+                        try:
+                            pa_table = paj.read_json(BytesIO(batch))
+                        except pa.ArrowInvalid as e:
+                            logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                            try:
+                                with open(file, encoding="utf-8") as f:
+                                    dataset = json.load(f)
+                            except json.JSONDecodeError:
+                                raise e
+                            raise ValueError(
+                                f"Not able to read records in the JSON file at {file}. "
+                                f"You should probably indicate the field of the JSON file containing your records. "
+                                f"This JSON file contains the following fields: {str(list(dataset.keys()))}. "
+                                f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
+                            )
+                        # Uncomment for debugging (will print the Arrow table size and elements)
+                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
+                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+                        yield (file_idx, batch_idx), self._cast_classlabels(pa_table)
+                        batch_idx += 1
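
For context on the chunked reading that `_generate_tables` now does, here is a minimal, self-contained sketch of the same pattern outside the builder class. It is only a sketch: the helper name `iter_json_chunks` and the example file name are illustrative, not part of the `datasets` library.

```python
# Minimal sketch of the chunked JSON Lines pattern used in `_generate_tables` above.
# `iter_json_chunks` is an illustrative helper, not a datasets API.
from io import BytesIO

import pyarrow as pa
import pyarrow.json as paj


def iter_json_chunks(path, chunksize=10 << 20):
    """Yield one pyarrow.Table per ~chunksize bytes of a JSON Lines file."""
    with open(path, "rb") as f:
        while True:
            batch = f.read(chunksize)
            if not batch:
                break
            batch += f.readline()  # extend the chunk to the end of the current line
            yield paj.read_json(BytesIO(batch))


# Usage (illustrative file name):
# full_table = pa.concat_tables(list(iter_json_chunks("data.jsonl")))
```

Reading fixed-size chunks and finishing the current line keeps memory bounded for large files while still handing pyarrow complete JSON records.
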
- with open(file, "rb") as f: - pa_table = paj.read_json( - f, read_options=self.config.pa_read_options, parse_options=self.config.pa_parse_options - ) - except pa.ArrowInvalid: - with open(file, encoding="utf-8") as f: - dataset = json.load(f) - raise ValueError( - f"Not able to read records in the JSON file at {file}. " - f"You should probably indicate the field of the JSON file containing your records. " - f"This JSON file contain the following fields: {str(list(dataset.keys()))}. " - f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. " - ) - if self.config.features: - # Encode column if ClassLabel - for i, col in enumerate(self.config.features.keys()): - if isinstance(self.config.features[col], datasets.ClassLabel): - pa_table = pa_table.set_column( - i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])] - ) - # Cast allows str <-> int/float, while parse_option explicit_schema does NOT - # Before casting, rearrange JSON field names to match passed features schema field names order - pa_table = pa.Table.from_arrays( - [pa_table[name] for name in self.config.features], schema=self.config.schema - ) - yield i, pa_table + with open(file, "rb") as f: + batch_idx = 0 + while True: + batch = f.read(self.config.chunksize) + if not batch: + break + batch += f.readline() # finish current line + try: + pa_table = paj.read_json(BytesIO(batch)) + except json.JSONDecodeError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + try: + with open(file, encoding="utf-8") as f: + dataset = json.load(f) + except json.JSONDecodeError: + raise e + raise ValueError( + f"Not able to read records in the JSON file at {file}. " + f"You should probably indicate the field of the JSON file containing your records. " + f"This JSON file contain the following fields: {str(list(dataset.keys()))}. " + f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. " + ) + # Uncomment for debugging (will print the Arrow table size and elements) + # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield (file_idx, batch_idx), self._cast_classlabels(pa_table) + batch_idx += 1 diff --git a/tests/conftest.py b/tests/conftest.py index 152606a30fd..052d039b91c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -234,7 +234,7 @@ def jsonl_312_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl") with open(path, "w") as f: for item in DATA_312: - f.write(json.dumps(item)) + f.write(json.dumps(item) + "\n") return path @@ -243,7 +243,7 @@ def jsonl_str_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset-str.jsonl") with open(path, "w") as f: for item in DATA_STR: - f.write(json.dumps(item)) + f.write(json.dumps(item) + "\n") return path diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py index 88a777b47fc..c24a8f94b0b 100644 --- a/tests/test_dataset_common.py +++ b/tests/test_dataset_common.py @@ -84,6 +84,16 @@ def get_packaged_dataset_dummy_data_files(dataset_name, path_to_dummy_data): } +def get_pachakged_dataset_config_attributes(dataset_name): + if dataset_name == "json": + # The json dummy data are formatted as the squad format + # which has the list of examples in the field named "data". + # Therefore we have to tell the json loader to load this field. 
+ return {"field": "data"} + else: + return {} + + class DatasetTester: def __init__(self, parent): self.parent = parent if parent is not None else TestCase() @@ -141,12 +151,15 @@ def check_if_url_is_valid(url): ) # packaged datasets like csv, text, json or pandas require some data files - if dataset_builder.__class__.__name__.lower() in _PACKAGED_DATASETS_MODULES: + builder_name = dataset_builder.__class__.__name__.lower() + if builder_name in _PACKAGED_DATASETS_MODULES: mock_dl_manager.download_dummy_data() path_to_dummy_data = mock_dl_manager.dummy_file dataset_builder.config.data_files = get_packaged_dataset_dummy_data_files( - dataset_builder.__class__.__name__.lower(), path_to_dummy_data + builder_name, path_to_dummy_data ) + for config_attr, value in get_pachakged_dataset_config_attributes(builder_name).items(): + setattr(dataset_builder.config, config_attr, value) # mock size needed for dummy data instead of actual dataset if dataset_builder.info is not None: