# Patch summary (free text before the first "diff --git" header; not part of any hunk).
#
# src/datasets/packaged_modules/json/json.py
#   - Replaces `from io import BytesIO` with `import io` and imports `readline`
#     from `datasets.utils.file_utils`.
#   - In `_generate_tables`, after reading `self.config.chunksize` bytes, the rest
#     of the current line is appended to the batch with `f.readline()`; if that
#     raises `AttributeError` or `io.UnsupportedOperation`, the helper
#     `readline(f)` is used instead.
#   - `BytesIO(batch)` becomes `io.BytesIO(batch)` to match the new import.
#
# src/datasets/utils/file_utils.py
#   - Adds `import io` and a module-level `readline(f)` helper. It calls
#     `f.read(1)` in a loop, accumulating into a bytearray, and stops when the
#     read returns a falsy value or the buffer ends with b"\n"; the result is
#     returned as `bytes` (empty if nothing could be read). The in-code comment
#     credits CPython's `Lib/_pyio.py`.
#
# src/datasets/utils/streaming_download_manager.py
#   - `_get_extraction_protocol` returns "tar" for paths ending in ".tar". The
#     new branch sits after the ".gz" check and before the ".zip" check.
#
# NOTE(review): presumably the fallback exists for file objects that expose
# `read` but not a usable `readline` -- confirm against the callers.
# NOTE(review): line breaks and indentation below were restored from a
# whitespace-collapsed copy of this patch; verify the context lines against the
# target files before applying.
diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 0327996f6da..e16d7942c42 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -1,14 +1,14 @@
 # coding=utf-8
-
+import io
 import json
 from dataclasses import dataclass
-from io import BytesIO
 from typing import Optional
 
 import pyarrow as pa
 import pyarrow.json as paj
 
 import datasets
+from datasets.utils.file_utils import readline
 
 
 logger = datasets.utils.logging.get_logger(__name__)
@@ -107,12 +107,16 @@ def _generate_tables(self, files):
                         batch = f.read(self.config.chunksize)
                         if not batch:
                             break
-                        batch += f.readline()  # finish current line
+                        # Finish current line
+                        try:
+                            batch += f.readline()
+                        except (AttributeError, io.UnsupportedOperation):
+                            batch += readline(f)
                         try:
                             while True:
                                 try:
                                     pa_table = paj.read_json(
-                                        BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
+                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                     )
                                     break
                                 except pa.ArrowInvalid as e:
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 6035bc69ff5..0468c8241f8 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -5,6 +5,7 @@
 """
 
 import copy
+import io
 import json
 import os
 import re
@@ -682,3 +683,16 @@ def docstring_decorator(fn):
 
 def estimate_dataset_size(paths):
     return sum(path.stat().st_size for path in paths)
+
+
+def readline(f: io.RawIOBase):
+    # From: https://github.com/python/cpython/blob/d27e2f4d118e7a9909b6a3e5da06c5ff95806a85/Lib/_pyio.py#L525
+    res = bytearray()
+    while True:
+        b = f.read(1)
+        if not b:
+            break
+        res += b
+        if res.endswith(b"\n"):
+            break
+    return bytes(res)
diff --git a/src/datasets/utils/streaming_download_manager.py b/src/datasets/utils/streaming_download_manager.py
index 5085d08c979..6bbcdd71523 100644
--- a/src/datasets/utils/streaming_download_manager.py
+++ b/src/datasets/utils/streaming_download_manager.py
@@ -134,6 +134,8 @@ def _get_extraction_protocol(self, urlpath) -> Optional[str]:
             return None
         elif path.endswith(".gz") and not path.endswith(".tar.gz"):
             return "gzip"
+        elif path.endswith(".tar"):
+            return "tar"
         elif path.endswith(".zip"):
             return "zip"
         raise NotImplementedError(f"Extraction protocol for file at {urlpath} is not implemented yet")