diff --git a/src/datasets/filesystems/compression.py b/src/datasets/filesystems/compression.py index 416d5b15f82..237b06d7500 100644 --- a/src/datasets/filesystems/compression.py +++ b/src/datasets/filesystems/compression.py @@ -14,7 +14,7 @@ class BaseCompressedFileFileSystem(AbstractArchiveFileSystem): None # protocol passed in prefix to the url. ex: "gzip", for gzip://file.txt::http://foo.bar/file.txt.gz ) compression: str = None # compression type in fsspec. ex: "gzip" - extension: str = None # extension of the filename to strip. ex: "".gz" to get file.txt from file.txt.gz + extensions: list[str] = None # extensions of the filename to strip. ex: ".gz" to get file.txt from file.txt.gz def __init__( self, fo: str = "", target_protocol: Optional[str] = None, target_options: Optional[dict] = None, **kwargs @@ -90,7 +90,7 @@ class Bz2FileSystem(BaseCompressedFileFileSystem): protocol = "bz2" compression = "bz2" - extension = ".bz2" + extensions = [".bz2"] class GzipFileSystem(BaseCompressedFileFileSystem): @@ -98,7 +98,7 @@ class GzipFileSystem(BaseCompressedFileFileSystem): protocol = "gzip" compression = "gzip" - extension = ".gz" + extensions = [".gz", ".gzip"] class Lz4FileSystem(BaseCompressedFileFileSystem): @@ -106,7 +106,7 @@ class Lz4FileSystem(BaseCompressedFileFileSystem): protocol = "lz4" compression = "lz4" - extension = ".lz4" + extensions = [".lz4"] class XzFileSystem(BaseCompressedFileFileSystem): @@ -114,7 +114,7 @@ class XzFileSystem(BaseCompressedFileFileSystem): protocol = "xz" compression = "xz" - extension = ".xz" + extensions = [".xz"] class ZstdFileSystem(BaseCompressedFileFileSystem): @@ -124,4 +124,4 @@ class ZstdFileSystem(BaseCompressedFileFileSystem): protocol = "zstd" compression = "zstd" - extension = ".zst" + extensions = [".zst", ".zstd"] diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 426083fc718..c5d8bcd03fc 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -90,6 +90,20 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: for column_name in set(self.config.features) - set(pa_table.column_names): type = self.config.features.arrow_schema.field(column_name).type pa_table = pa_table.append_column(column_name, pa.array([None] * len(pa_table), type=type)) + # convert to string when needed + for i, column_name in enumerate(pa_table.column_names): + if pa.types.is_struct(pa_table[column_name].type) and self.config.features.get( + column_name, None + ) == datasets.Value("string"): + jsonl = ( + pa_table[column_name] + .to_pandas(types_mapper=pd.ArrowDtype) + .to_json(orient="records", lines=True) + ) + string_array = pa.array( + ("{" + x.rstrip() for x in ("\n" + jsonl).split("\n{") if x), type=pa.string() + ) + pa_table = pa_table.set_column(i, column_name, string_array) # more expensive cast to support nested structures with keys in a different order # allows str <-> int/float or str to Audio for example pa_table = table_cast(pa_table, self.config.features.arrow_schema) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index dff7cc3e754..81be4f295c4 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -461,12 +461,18 @@ def readline(f: io.RawIOBase): ] COMPRESSION_EXTENSION_TO_PROTOCOL = { # single file compression - **{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}, + **{ + extension.lstrip("."): fs_class.protocol + for fs_class in COMPRESSION_FILESYSTEMS + for extension in fs_class.extensions + }, # archive compression "zip": "zip", } SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = { - fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS + extension.lstrip("."): fs_class.protocol + for fs_class in COMPRESSION_FILESYSTEMS + for extension in fs_class.extensions } SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS} SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?