11 changes: 7 additions & 4 deletions src/datasets/builder.py
@@ -45,6 +45,7 @@
from .utils.file_utils import DownloadConfig, is_remote_url
from .utils.filelock import FileLock
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
from .utils.patching import extend_module_for_deleting_file_on_close


logger = logging.get_logger(__name__)
@@ -1124,10 +1125,12 @@ def _prepare_split(self, split_generator):

generator = self._generate_tables(**split_generator.gen_kwargs)
with ArrowWriter(features=self.info.features, path=fpath) as writer:
for key, table in utils.tqdm(
generator, unit=" tables", leave=False, disable=bool(logging.get_verbosity() == logging.NOTSET)
):
writer.write_table(table)
with extend_module_for_deleting_file_on_close(self.__module__):
for key, table in utils.tqdm(
generator, unit=" tables", leave=False, disable=bool(logging.get_verbosity() == logging.NOTSET)
):

writer.write_table(table)
num_examples, num_bytes = writer.finalize()

split_generator.split_info.num_examples = num_examples
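The wrapped loop above relies on extend_module_for_deleting_file_on_close patching the open function in the builder's own module, so a file the generator reads (typically an archive extracted into the cache) is removed as soon as the generator closes it. The following is a minimal sketch of that pattern outside of a real DatasetBuilder, not code from this PR; the toy_builder module and its generate_rows function are hypothetical stand-ins for a dataset script module and its _generate_tables logic.

import os
import sys
import tempfile
import types

from datasets.utils.patching import extend_module_for_deleting_file_on_close

# Hypothetical stand-in for a builder module: a generator that reads a file
# through the module-level open(), much like _generate_tables in a packaged builder.
toy_builder = types.ModuleType("toy_builder")
exec(
    "def generate_rows(path):\n"
    "    with open(path) as f:\n"
    "        for line in f:\n"
    "            yield line.strip()\n",
    toy_builder.__dict__,
)
sys.modules["toy_builder"] = toy_builder  # registered so importlib.import_module can find it

# Pretend this file was extracted into the datasets cache.
extracted_path = os.path.join(tempfile.mkdtemp(), "extracted.jsonl")
with open(extracted_path, "w") as f:
    f.write("row 0\nrow 1\n")

# Same shape as the _prepare_split change: consume the generator inside the patch.
with extend_module_for_deleting_file_on_close("toy_builder"):
    rows = list(toy_builder.generate_rows(extracted_path))

assert rows == ["row 0", "row 1"]
assert not os.path.exists(extracted_path)  # removed as soon as the generator closed it

Because only self.__module__ is patched, code outside the dataset script module keeps the builtin open and is unaffected.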
64 changes: 1 addition & 63 deletions src/datasets/streaming.py
@@ -3,75 +3,13 @@
from typing import Optional, Union

from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import xjoin, xopen


logger = get_logger(__name__)


class _PatchedModuleObj:
"""Set all the modules components as attributes of the _PatchedModuleObj object"""

def __init__(self, module):
if module is not None:
for key in getattr(module, "__all__", module.__dict__):
if not key.startswith("__"):
setattr(self, key, getattr(module, key))


class patch_submodule:
"""
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.

Example::

>>> import importlib
>>> from datasets.load import prepare_module
>>> from datasets.streaming import patch_submodule, xjoin
>>>
>>> snli_module_path, _ = prepare_module("snli")
>>> snli_module = importlib.import_module(snli_module_path)
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
>>> patcher.start()
>>> assert snli_module.os.path.join is xjoin
"""

_active_patches = []

def __init__(self, obj, target: str, new):
self.obj = obj
self.target = target
self.new = new
self.key = target.split(".")[0]
self.original = getattr(obj, self.key, None)

def __enter__(self):
*submodules, attr = self.target.split(".")
current = self.obj
for key in submodules:
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
current = getattr(current, key)
setattr(current, attr, self.new)

def __exit__(self, *exc_info):
setattr(self.obj, self.key, self.original)

def start(self):
"""Activate a patch."""
self.__enter__()
self._active_patches.append(self)

def stop(self):
"""Stop an active patch."""
try:
self._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
return None

return self.__exit__()


def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None):
"""
Extend the `open` and `os.path.join` functions of the module to support data streaming.
91 changes: 91 additions & 0 deletions src/datasets/utils/patching.py
@@ -0,0 +1,91 @@
import importlib
import os
from contextlib import contextmanager

from .logging import get_logger


logger = get_logger(__name__)


class _PatchedModuleObj:
"""Set all the modules components as attributes of the _PatchedModuleObj object"""

def __init__(self, module):
if module is not None:
for key in getattr(module, "__all__", module.__dict__):
if not key.startswith("__"):
setattr(self, key, getattr(module, key))


class patch_submodule:
"""
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.

Example::

>>> import importlib
>>> from datasets.load import prepare_module
>>> from datasets.streaming import patch_submodule, xjoin
>>>
>>> snli_module_path, _ = prepare_module("snli")
>>> snli_module = importlib.import_module(snli_module_path)
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
>>> patcher.start()
>>> assert snli_module.os.path.join is xjoin
"""

_active_patches = []

def __init__(self, obj, target: str, new):
self.obj = obj
self.target = target
self.new = new
self.key = target.split(".")[0]
self.original = getattr(obj, self.key, None)

def __enter__(self):
*submodules, attr = self.target.split(".")
current = self.obj
for key in submodules:
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
current = getattr(current, key)
setattr(current, attr, self.new)

def __exit__(self, *exc_info):
setattr(self.obj, self.key, self.original)

def start(self):
"""Activate a patch."""
self.__enter__()
self._active_patches.append(self)

def stop(self):
"""Stop an active patch."""
try:
self._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
return None

return self.__exit__()


@contextmanager
def extend_module_for_deleting_file_on_close(module_path):
@contextmanager
def open_and_delete_on_close(file, *args, **kwargs):
f = open(file, *args, **kwargs)
try:
yield f
finally:
f.close()
os.remove(file)

module = importlib.import_module(module_path)
patch = patch_submodule(module, "open", open_and_delete_on_close)
patch.start()
try:
yield patch
finally:
patch.stop()
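The docstring example above activates a patch with start() on a dataset script module. Below is a small self-contained sketch of the dotted-target behaviour, not part of the PR: it patches datasets.utils.patching itself only because the new file above is known to import os, and any imported module that imports os would behave the same; fake_join is a throwaway replacement used for the illustration.

import importlib

from datasets.utils.patching import patch_submodule


def fake_join(a, *p):
    # Throwaway replacement for os.path.join, used only for this illustration.
    return "::".join([a, *p])


module = importlib.import_module("datasets.utils.patching")  # any module that imports os
patcher = patch_submodule(module, "os.path.join", fake_join)
patcher.start()
assert module.os.path.join is fake_join      # the target attribute is swapped
assert module.os.path.exists("/")            # sibling attributes survive via _PatchedModuleObj
patcher.stop()
assert module.os.path.join is not fake_join  # the original os module is put back

Only the attribute tree on the patched module is touched: intermediate levels are replaced by _PatchedModuleObj copies rather than by mutating the real os module, so every other module in the process keeps the untouched os.path.join.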
8 changes: 8 additions & 0 deletions tests/test_load.py
@@ -335,3 +335,11 @@ def test_load_from_disk_with_default_in_memory(

with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
_ = load_from_disk(dataset_path)


def test_load_dataset_deletes_extracted_files(jsonl_gz_path, tmp_path):
data_files = jsonl_gz_path
cache_dir = tmp_path / "cache"
ds = load_dataset("json", split="train", data_files=data_files, cache_dir=cache_dir)
assert ds[0] == {"col_1": "0", "col_2": 0, "col_3": 0.0}
assert sorted((cache_dir / "downloads" / "extracted").iterdir()) == []
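The jsonl_gz_path fixture is not shown in this diff, so here is a self-contained variant of the same check, a sketch rather than the test itself; the generated rows only mirror the shape asserted above, and the file names and temporary directories are illustrative.

import gzip
import json
from pathlib import Path
from tempfile import mkdtemp

from datasets import load_dataset

tmp_path = Path(mkdtemp())
data_file = tmp_path / "dataset.jsonl.gz"
with gzip.open(data_file, "wt") as f:
    for i in range(3):
        f.write(json.dumps({"col_1": str(i), "col_2": i, "col_3": float(i)}) + "\n")

cache_dir = tmp_path / "cache"
ds = load_dataset("json", split="train", data_files=str(data_file), cache_dir=str(cache_dir))
assert ds[0] == {"col_1": "0", "col_2": 0, "col_3": 0.0}

# The gzip archive is still extracted into the cache during preparation, but with
# this change the extracted copy is deleted as soon as the json builder closes it.
extracted_dir = cache_dir / "downloads" / "extracted"
assert not extracted_dir.exists() or list(extracted_dir.iterdir()) == []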