Skip to content

Commit 98159c4

Browse files
Support streaming compressed files (#2786)
* Pass compression to stream zstd file
* Implement custom readline for io.RawIOBase like
* Fix readline in json module for io.RawIOBase
* Simplify custom readline
* Test load dataset streaming compressed files
* Test xz files
* Support streaming xz compressed files
* Support streaming bz2 and lz4 compressed files
* Fix style in test
* Fix test
* Test zip files
* Test tar files
* Test gzip files
* Implement _add_retries_to_fsspec_open_file
* Add retries to fsspec OpenFile
* Make _add_retries_to_file_obj_read_method return
* Refactor _add_retries_to_fsspec_open_file
1 parent c9fca18 commit 98159c4

2 files changed

Lines changed: 40 additions & 3 deletions

File tree

src/datasets/utils/streaming_download_manager.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
logger = get_logger(__name__)
1717
BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]
18+
COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "lz4", "xz", "zst"]
1819

1920

2021
def xjoin(a, *p):
@@ -63,6 +64,19 @@ def read_with_retries(*args, **kwargs):
6364
return out
6465

6566
file_obj.read = read_with_retries
67+
return file_obj
68+
69+
70+
def _add_retries_to_fsspec_open_file(fsspec_open_file):
71+
open_ = fsspec_open_file.open
72+
73+
def open_with_retries():
74+
file_obj = open_()
75+
_add_retries_to_file_obj_read_method(file_obj)
76+
return file_obj
77+
78+
fsspec_open_file.open = open_with_retries
79+
return fsspec_open_file
6680

6781

6882
def xopen(file, mode="r", *args, **kwargs):
@@ -74,8 +88,13 @@ def xopen(file, mode="r", *args, **kwargs):
7488
"""
7589
if fsspec.get_fs_token_paths(file)[0].protocol == "https":
7690
kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None))
77-
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
78-
_add_retries_to_file_obj_read_method(file_obj)
91+
compression = fsspec.core.get_compression(file, "infer")
92+
if not compression or compression in ["gzip", "zip"]:
93+
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
94+
file_obj = _add_retries_to_file_obj_read_method(file_obj)
95+
else:
96+
file_obj = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs)
97+
file_obj = _add_retries_to_fsspec_open_file(file_obj)
7998
return file_obj
8099

81100

@@ -130,7 +149,7 @@ def _extract(self, urlpath):
130149

131150
def _get_extraction_protocol(self, urlpath) -> Optional[str]:
132151
path = urlpath.split("::")[0]
133-
if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS:
152+
if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS:
134153
return None
135154
elif path.endswith(".gz") and not path.endswith(".tar.gz"):
136155
return "gzip"

tests/test_load.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,24 @@ def test_load_dataset_streaming_gz_json(jsonl_gz_path):
247247
assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
248248

249249

250+
@require_streaming
251+
@pytest.mark.parametrize(
252+
"path", ["sample.jsonl", "sample.jsonl.gz", "sample.tar", "sample.jsonl.xz", "sample.zip", "sample.jsonl.zst"]
253+
)
254+
def test_load_dataset_streaming_compressed_files(path):
255+
repo_id = "albertvillanova/datasets-tests-compression"
256+
data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{path}"
257+
ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
258+
assert isinstance(ds, IterableDataset)
259+
ds_item = next(iter(ds))
260+
assert ds_item == {
261+
"tokens": ["Ministeri", "de", "Justícia", "d'Espanya"],
262+
"ner_tags": [1, 2, 2, 2],
263+
"langs": ["ca", "ca", "ca", "ca"],
264+
"spans": ["PER: Ministeri de Justícia d'Espanya"],
265+
}
266+
267+
250268
def test_loading_from_the_datasets_hub():
251269
with tempfile.TemporaryDirectory() as tmp_dir:
252270
dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir)

0 commit comments

Comments (0)