Skip to content

Commit 6a1c6b1

Browse files
Make Extractor accept Path as input (#4718)
* Make Extractor accept Path as input
* Remove unnecessary casting of Path to str
* Remove other unnecessary casting of Path to str
* Add type hints
* Fix type hints with TYPE_CHECKING
1 parent 5088e95 commit 6a1c6b1

File tree

6 files changed

+40
-32
lines changed

6 files changed

+40
-32
lines changed

src/datasets/utils/extract.py

Lines changed: 36 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -7,32 +7,37 @@
77
import warnings
88
import zipfile
99
from abc import ABC, abstractmethod
10+
from typing import TYPE_CHECKING, Optional, Union
1011

1112
from .. import config
1213
from .filelock import FileLock
1314

1415

16+
if TYPE_CHECKING:
17+
import pathlib
18+
19+
1520
class ExtractManager:
16-
def __init__(self, cache_dir=None):
21+
def __init__(self, cache_dir: Optional[str] = None):
1722
self.extract_dir = (
1823
os.path.join(cache_dir, config.EXTRACTED_DATASETS_DIR) if cache_dir else config.EXTRACTED_DATASETS_PATH
1924
)
2025
self.extractor = Extractor
2126

22-
def _get_output_path(self, path):
27+
def _get_output_path(self, path: str) -> str:
2328
from .file_utils import hash_url_to_filename
2429

2530
# Path where we extract compressed archives
2631
# We extract in the cache dir, and get the extracted path name by hashing the original path
2732
abs_path = os.path.abspath(path)
2833
return os.path.join(self.extract_dir, hash_url_to_filename(abs_path))
2934

30-
def _do_extract(self, output_path, force_extract):
35+
def _do_extract(self, output_path: str, force_extract: bool) -> bool:
3136
return force_extract or (
3237
not os.path.isfile(output_path) and not (os.path.isdir(output_path) and os.listdir(output_path))
3338
)
3439

35-
def extract(self, input_path, force_extract=False):
40+
def extract(self, input_path: str, force_extract: bool = False) -> str:
3641
extractor_format = self.extractor.infer_extractor_format(input_path)
3742
if not extractor_format:
3843
return input_path
@@ -45,25 +50,25 @@ def extract(self, input_path, force_extract=False):
4550
class BaseExtractor(ABC):
4651
@classmethod
4752
@abstractmethod
48-
def is_extractable(cls, path: str, **kwargs) -> bool:
53+
def is_extractable(cls, path: Union["pathlib.Path", str], **kwargs) -> bool:
4954
...
5055

5156
@staticmethod
5257
@abstractmethod
53-
def extract(input_path: str, output_path: str) -> None:
58+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
5459
...
5560

5661

5762
class MagicNumberBaseExtractor(BaseExtractor, ABC):
5863
magic_number = b""
5964

6065
@staticmethod
61-
def read_magic_number(path: str, magic_number_length: int):
66+
def read_magic_number(path: Union["pathlib.Path", str], magic_number_length: int):
6267
with open(path, "rb") as f:
6368
return f.read(magic_number_length)
6469

6570
@classmethod
66-
def is_extractable(cls, path: str, magic_number: bytes = b"") -> bool:
71+
def is_extractable(cls, path: Union["pathlib.Path", str], magic_number: bytes = b"") -> bool:
6772
if not magic_number:
6873
try:
6974
magic_number = cls.read_magic_number(path, len(cls.magic_number))
@@ -74,11 +79,11 @@ def is_extractable(cls, path: str, magic_number: bytes = b"") -> bool:
7479

7580
class TarExtractor(BaseExtractor):
7681
@classmethod
77-
def is_extractable(cls, path: str, **kwargs) -> bool:
82+
def is_extractable(cls, path: Union["pathlib.Path", str], **kwargs) -> bool:
7883
return tarfile.is_tarfile(path)
7984

8085
@staticmethod
81-
def extract(input_path: str, output_path: str) -> None:
86+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
8287
os.makedirs(output_path, exist_ok=True)
8388
tar_file = tarfile.open(input_path)
8489
tar_file.extractall(output_path)
@@ -89,19 +94,19 @@ class GzipExtractor(MagicNumberBaseExtractor):
8994
magic_number = b"\x1F\x8B"
9095

9196
@staticmethod
92-
def extract(input_path: str, output_path: str) -> None:
97+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
9398
with gzip.open(input_path, "rb") as gzip_file:
9499
with open(output_path, "wb") as extracted_file:
95100
shutil.copyfileobj(gzip_file, extracted_file)
96101

97102

98103
class ZipExtractor(BaseExtractor):
99104
@classmethod
100-
def is_extractable(cls, path: str, **kwargs) -> bool:
105+
def is_extractable(cls, path: Union["pathlib.Path", str], **kwargs) -> bool:
101106
return zipfile.is_zipfile(path)
102107

103108
@staticmethod
104-
def extract(input_path: str, output_path: str) -> None:
109+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
105110
os.makedirs(output_path, exist_ok=True)
106111
with zipfile.ZipFile(input_path, "r") as zip_file:
107112
zip_file.extractall(output_path)
@@ -112,7 +117,7 @@ class XzExtractor(MagicNumberBaseExtractor):
112117
magic_number = b"\xFD\x37\x7A\x58\x5A\x00"
113118

114119
@staticmethod
115-
def extract(input_path: str, output_path: str) -> None:
120+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
116121
with lzma.open(input_path) as compressed_file:
117122
with open(output_path, "wb") as extracted_file:
118123
shutil.copyfileobj(compressed_file, extracted_file)
@@ -123,14 +128,14 @@ class RarExtractor(BaseExtractor):
123128
RAR5_ID = b"Rar!\x1a\x07\x01\x00"
124129

125130
@classmethod
126-
def is_extractable(cls, path: str, **kwargs) -> bool:
131+
def is_extractable(cls, path: Union["pathlib.Path", str], **kwargs) -> bool:
127132
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
128133
with open(path, "rb") as f:
129134
magic_number = f.read(len(cls.RAR5_ID))
130135
return magic_number == cls.RAR5_ID or magic_number.startswith(cls.RAR_ID)
131136

132137
@staticmethod
133-
def extract(input_path: str, output_path: str) -> None:
138+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
134139
if not config.RARFILE_AVAILABLE:
135140
raise OSError("Please pip install rarfile")
136141
import rarfile
@@ -145,7 +150,7 @@ class ZstdExtractor(MagicNumberBaseExtractor):
145150
magic_number = b"\x28\xb5\x2F\xFD"
146151

147152
@staticmethod
148-
def extract(input_path: str, output_path: str) -> None:
153+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
149154
if not config.ZSTANDARD_AVAILABLE:
150155
raise OSError("Please pip install zstandard")
151156
import zstandard as zstd
@@ -159,7 +164,7 @@ class Bzip2Extractor(MagicNumberBaseExtractor):
159164
magic_number = b"\x42\x5A\x68"
160165

161166
@staticmethod
162-
def extract(input_path: str, output_path: str) -> None:
167+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
163168
with bz2.open(input_path, "rb") as compressed_file:
164169
with open(output_path, "wb") as extracted_file:
165170
shutil.copyfileobj(compressed_file, extracted_file)
@@ -169,7 +174,7 @@ class SevenZipExtractor(MagicNumberBaseExtractor):
169174
magic_number = b"\x37\x7A\xBC\xAF\x27\x1C"
170175

171176
@staticmethod
172-
def extract(input_path: str, output_path: str) -> None:
177+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
173178
if not config.PY7ZR_AVAILABLE:
174179
raise OSError("Please pip install py7zr")
175180
import py7zr
@@ -183,7 +188,7 @@ class Lz4Extractor(MagicNumberBaseExtractor):
183188
magic_number = b"\x04\x22\x4D\x18"
184189

185190
@staticmethod
186-
def extract(input_path: str, output_path: str) -> None:
191+
def extract(input_path: Union["pathlib.Path", str], output_path: Union["pathlib.Path", str]) -> None:
187192
if not config.LZ4_AVAILABLE:
188193
raise OSError("Please pip install lz4")
189194
import lz4.frame
@@ -219,14 +224,14 @@ def _get_magic_number_max_length(cls):
219224
return magic_number_max_length
220225

221226
@staticmethod
222-
def _read_magic_number(path: str, magic_number_length: int):
227+
def _read_magic_number(path: Union["pathlib.Path", str], magic_number_length: int):
223228
try:
224229
return MagicNumberBaseExtractor.read_magic_number(path, magic_number_length=magic_number_length)
225230
except OSError:
226231
return b""
227232

228233
@classmethod
229-
def is_extractable(cls, path, return_extractor=False):
234+
def is_extractable(cls, path: Union["pathlib.Path", str], return_extractor: bool = False) -> bool:
230235
warnings.warn(
231236
"Method 'is_extractable' was deprecated in version 2.4.0 and will be removed in 3.0.0. "
232237
"Use 'infer_extractor_format' instead.",
@@ -238,17 +243,23 @@ def is_extractable(cls, path, return_extractor=False):
238243
return False if not return_extractor else (False, None)
239244

240245
@classmethod
241-
def infer_extractor_format(cls, path):
246+
def infer_extractor_format(cls, path: Union["pathlib.Path", str]) -> str:
242247
magic_number_max_length = cls._get_magic_number_max_length()
243248
magic_number = cls._read_magic_number(path, magic_number_max_length)
244249
for extractor_format, extractor in cls.extractors.items():
245250
if extractor.is_extractable(path, magic_number=magic_number):
246251
return extractor_format
247252

248253
@classmethod
249-
def extract(cls, input_path, output_path, extractor_format=None, extractor="deprecated"):
254+
def extract(
255+
cls,
256+
input_path: Union["pathlib.Path", str],
257+
output_path: Union["pathlib.Path", str],
258+
extractor_format: Optional[str] = None,
259+
extractor: Optional[BaseExtractor] = "deprecated",
260+
) -> None:
250261
# Prevent parallel extractions
251-
lock_path = input_path + ".lock"
262+
lock_path = str(input_path) + ".lock"
252263
with FileLock(lock_path):
253264
shutil.rmtree(output_path, ignore_errors=True)
254265
os.makedirs(os.path.dirname(output_path), exist_ok=True)

tests/test_download_manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def test_download_manager_download(urls_type, tmp_path, monkeypatch):
4242
urls = {"train": url}
4343
dataset_name = "dummy"
4444
cache_subdir = "downloads"
45-
cache_dir_root = str(tmp_path)
45+
cache_dir_root = tmp_path
4646
download_config = DownloadConfig(
4747
cache_dir=os.path.join(cache_dir_root, cache_subdir),
4848
use_etag=False,

tests/test_extract.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,6 @@ def test_extractor(
123123
elif compression_format == "zstd":
124124
reason += require_zstandard.kwargs["reason"]
125125
pytest.skip(reason)
126-
input_path = str(input_path)
127126
extractor_format = Extractor.infer_extractor_format(input_path)
128127
assert extractor_format is not None
129128
output_path = tmp_path / ("extracted" if is_archive else "extracted.txt")

tests/test_file_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def zstd_path(tmp_path_factory):
2626
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
2727
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
2828
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
29-
input_path = str(input_paths[compression_format])
29+
input_path = input_paths[compression_format]
3030
cache_dir = tmp_path / "cache"
3131
download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True)
3232
extracted_path = cached_path(input_path, download_config=download_config)

tests/test_filesystem.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ def test_compression_filesystems(compression_fs_class, gz_file, bz2_file, lz4_fi
7171
elif compression_fs_class.protocol == "zstd":
7272
reason += require_zstandard.kwargs["reason"]
7373
pytest.skip(reason)
74-
input_path = str(input_path)
7574
fs = fsspec.filesystem(compression_fs_class.protocol, fo=input_path)
7675
assert isinstance(fs, compression_fs_class)
7776
expected_filename = os.path.basename(input_path)

tests/test_streaming_download_manager.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -699,7 +699,6 @@ def test_streaming_dl_manager_extract_all_supported_single_file_compression_type
699699
elif compression_fs_class.protocol == "zstd":
700700
reason += require_zstandard.kwargs["reason"]
701701
pytest.skip(reason)
702-
input_path = str(input_path)
703702
dl_manager = StreamingDownloadManager()
704703
output_path = dl_manager.extract(input_path)
705704
path = os.path.basename(input_path)
@@ -791,7 +790,7 @@ def _test_jsonl(path, file):
791790

792791
def test_iter_archive_path(tar_jsonl_path):
793792
dl_manager = StreamingDownloadManager()
794-
archive_iterable = dl_manager.iter_archive(str(tar_jsonl_path))
793+
archive_iterable = dl_manager.iter_archive(tar_jsonl_path)
795794
num_jsonl = 0
796795
for num_jsonl, (path, file) in enumerate(archive_iterable, start=1):
797796
_test_jsonl(path, file)
@@ -805,7 +804,7 @@ def test_iter_archive_path(tar_jsonl_path):
805804

806805
def test_iter_archive_file(tar_nested_jsonl_path):
807806
dl_manager = StreamingDownloadManager()
808-
files_iterable = dl_manager.iter_archive(str(tar_nested_jsonl_path))
807+
files_iterable = dl_manager.iter_archive(tar_nested_jsonl_path)
809808
num_tar, num_jsonl = 0, 0
810809
for num_tar, (path, file) in enumerate(files_iterable, start=1):
811810
for num_jsonl, (subpath, subfile) in enumerate(dl_manager.iter_archive(file), start=1):

0 commit comments

Comments
 (0)