25 changes: 22 additions & 3 deletions src/datasets/utils/streaming_download_manager.py
@@ -15,6 +15,7 @@

logger = get_logger(__name__)
BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]
COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "lz4", "xz", "zst"]


def xjoin(a, *p):
@@ -63,6 +64,19 @@ def read_with_retries(*args, **kwargs):
return out

file_obj.read = read_with_retries
return file_obj


def _add_retries_to_fsspec_open_file(fsspec_open_file):
open_ = fsspec_open_file.open

def open_with_retries():
file_obj = open_()
_add_retries_to_file_obj_read_method(file_obj)
return file_obj

fsspec_open_file.open = open_with_retries
return fsspec_open_file


def xopen(file, mode="r", *args, **kwargs):
@@ -74,8 +88,13 @@ def xopen(file, mode="r", *args, **kwargs):
"""
if fsspec.get_fs_token_paths(file)[0].protocol == "https":
kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None))
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
_add_retries_to_file_obj_read_method(file_obj)
compression = fsspec.core.get_compression(file, "infer")
if not compression or compression in ["gzip", "zip"]:
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
file_obj = _add_retries_to_file_obj_read_method(file_obj)
else:
file_obj = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs)
file_obj = _add_retries_to_fsspec_open_file(file_obj)
Comment on lines +91 to +97
Member

xopen is an extension of open to make it work with remote files.

Here you change its behavior for compressed files: you automatically uncompress them. Therefore, if you try to open a compressed file and then use gzip (or any other compression tool) to uncompress it, it won't work, since it's already uncompressed.

I think we should revert this change and explicitly use some tool in the dataset scripts to uncompress the files, as we do in standard Python. Otherwise we may end up with code that works in streaming mode but not in standard mode, and vice versa.

Let me know what you think @albertvillanova

Member Author

No, fsspec.open (even when passed the compression parameter) does not uncompress the file immediately: it returns an OpenFile instance which, when used within a context manager, returns a file object wrapped with a decompressor instance that decompresses on the fly... ;)
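
For illustration, here is a minimal sketch of the lazy behavior described above (the file name is just an example, not one from this PR):

```python
import fsspec

# Building the OpenFile reads and decompresses nothing yet.
open_file = fsspec.open("sample.jsonl.gz", mode="rt", compression="infer")

# Entering the context (or calling .open()) yields a file object wrapped with a
# gzip decompressor, so data is decompressed on the fly as it is read.
with open_file as f:
    first_line = f.readline()
```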

Member Author

And yes, in the end, the result (after having called dl_manager.download_and_extract) will be an uncompressed file, whether streaming or not. That is the objective! 😉

Member Author

The issue is: how do you make StreamingDownloadManager.extract() pass the parameter compression=compression to fsspec.open(urlpath, compression=compression) if they can communicate only through the urlpath parameter?

Because of this, I always pass compression="infer", which assumes that all dataset scripts have called .extract (or .download_and_extract) before calling fsspec.open. This assumption is sensible and will work for all dataset scripts, except for oscar (as you told me yesterday), because you changed oscar to use a call like gzip.open(open()).
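
A minimal sketch of the extension-based inference this relies on (URLs are illustrative; expected return values are shown as comments):

```python
from fsspec.core import get_compression

# The codec is derived purely from the path suffix, so the urlpath alone
# carries all the information xopen needs.
get_compression("https://host/data/sample.jsonl.gz", "infer")  # -> "gzip"
get_compression("https://host/data/sample.jsonl", "infer")     # -> None
```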

return file_obj


@@ -130,7 +149,7 @@ def _extract(self, urlpath):

def _get_extraction_protocol(self, urlpath) -> Optional[str]:
path = urlpath.split("::")[0]
if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS:
if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS:
return None
elif path.endswith(".gz") and not path.endswith(".tar.gz"):
return "gzip"
18 changes: 18 additions & 0 deletions tests/test_load.py
@@ -247,6 +247,24 @@ def test_load_dataset_streaming_gz_json(jsonl_gz_path):
assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


@require_streaming
@pytest.mark.parametrize(
"path", ["sample.jsonl", "sample.jsonl.gz", "sample.tar", "sample.jsonl.xz", "sample.zip", "sample.jsonl.zst"]
)
def test_load_dataset_streaming_compressed_files(path):
repo_id = "albertvillanova/datasets-tests-compression"
data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{path}"
Member
@lhoestq Aug 16, 2021

This is a really nice feature @albertvillanova!

I think the glob logic has to be moved into a data files resolution module, as done in #2662:

def _resolve_data_files_locally_or_by_urls(

The current implementation may not be robust enough to handle path manipulations by users on compressed files.

Member Author

I have not touched the glob logic in this PR though... 🤔

ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
assert isinstance(ds, IterableDataset)
ds_item = next(iter(ds))
assert ds_item == {
"tokens": ["Ministeri", "de", "Justícia", "d'Espanya"],
"ner_tags": [1, 2, 2, 2],
"langs": ["ca", "ca", "ca", "ca"],
"spans": ["PER: Ministeri de Justícia d'Espanya"],
}


def test_loading_from_the_datasets_hub():
with tempfile.TemporaryDirectory() as tmp_dir:
dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir)