Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/datasets/filesystems/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BaseCompressedFileFileSystem(AbstractArchiveFileSystem):
None # protocol passed in prefix to the url. ex: "gzip", for gzip://file.txt::http://foo.bar/file.txt.gz
)
compression: str = None # compression type in fsspec. ex: "gzip"
extension: str = None # extension of the filename to strip. ex: "".gz" to get file.txt from file.txt.gz
extensions: list[str] = None # extensions of the filename to strip. ex: ".gz" to get file.txt from file.txt.gz

def __init__(
self, fo: str = "", target_protocol: Optional[str] = None, target_options: Optional[dict] = None, **kwargs
Expand Down Expand Up @@ -90,31 +90,31 @@ class Bz2FileSystem(BaseCompressedFileFileSystem):

protocol = "bz2"
compression = "bz2"
extension = ".bz2"
extensions = [".bz2"]


class GzipFileSystem(BaseCompressedFileFileSystem):
"""Read contents of GZIP file as a filesystem with one file inside."""

protocol = "gzip"
compression = "gzip"
extension = ".gz"
extensions = [".gz", ".gzip"]


class Lz4FileSystem(BaseCompressedFileFileSystem):
"""Read contents of LZ4 file as a filesystem with one file inside."""

protocol = "lz4"
compression = "lz4"
extension = ".lz4"
extensions = [".lz4"]


class XzFileSystem(BaseCompressedFileFileSystem):
"""Read contents of .xz (LZMA) file as a filesystem with one file inside."""

protocol = "xz"
compression = "xz"
extension = ".xz"
extensions = [".xz"]


class ZstdFileSystem(BaseCompressedFileFileSystem):
Expand All @@ -124,4 +124,4 @@ class ZstdFileSystem(BaseCompressedFileFileSystem):

protocol = "zstd"
compression = "zstd"
extension = ".zst"
extensions = [".zst", ".zstd"]
14 changes: 14 additions & 0 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
for column_name in set(self.config.features) - set(pa_table.column_names):
type = self.config.features.arrow_schema.field(column_name).type
pa_table = pa_table.append_column(column_name, pa.array([None] * len(pa_table), type=type))
# convert to string when needed
for i, column_name in enumerate(pa_table.column_names):
if pa.types.is_struct(pa_table[column_name].type) and self.config.features.get(
column_name, None
) == datasets.Value("string"):
jsonl = (
pa_table[column_name]
.to_pandas(types_mapper=pd.ArrowDtype)
.to_json(orient="records", lines=True)
)
string_array = pa.array(
("{" + x.rstrip() for x in ("\n" + jsonl).split("\n{") if x), type=pa.string()
)
pa_table = pa_table.set_column(i, column_name, string_array)
# more expensive cast to support nested structures with keys in a different order
# allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
Expand Down
10 changes: 8 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,18 @@ def readline(f: io.RawIOBase):
]
COMPRESSION_EXTENSION_TO_PROTOCOL = {
# single file compression
**{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS},
**{
extension.lstrip("."): fs_class.protocol
for fs_class in COMPRESSION_FILESYSTEMS
for extension in fs_class.extensions
},
# archive compression
"zip": "zip",
}
SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS
extension.lstrip("."): fs_class.protocol
for fs_class in COMPRESSION_FILESYSTEMS
for extension in fs_class.extensions
}
SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?<!:):/")
Expand Down
Loading