Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion src/huggingface_hub/hf_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from urllib.parse import quote, unquote

import fsspec
from requests import Response

from ._commit_api import CommitOperationCopy, CommitOperationDelete
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
Expand Down Expand Up @@ -216,11 +217,15 @@ def _open(
path: str,
mode: str = "rb",
revision: Optional[str] = None,
block_size: Optional[int] = None,
**kwargs,
) -> "HfFileSystemFile":
if "a" in mode:
raise NotImplementedError("Appending to remote files is not yet supported.")
return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)
if block_size == 0:
return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
else:
return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)

def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
resolved_path = self.resolve_path(path, revision=revision)
Expand Down Expand Up @@ -649,6 +654,94 @@ def _upload_chunk(self, final: bool = False) -> None:
)


class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
    """Read-only file object that streams a Hub file through a single HTTP response.

    Nothing is buffered or cached: bytes are pulled from the open response as
    they are requested and `self.loc` tracks how many have been returned so far.
    Only `block_size=0`, `cache_type="none"` and non-writing modes are accepted,
    and `seek` only accepts calls that keep the current position.
    """

    def __init__(
        self,
        fs: HfFileSystem,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        block_size: int = 0,
        cache_type: str = "none",
        **kwargs,
    ):
        """Validate the streaming constraints and resolve `path` on `fs`.

        Raises:
            ValueError: if `block_size` is not 0, `cache_type` is not "none",
                or `mode` contains "w".
            FileNotFoundError: propagated from `fs.resolve_path` when the path
                cannot be resolved.
        """
        if block_size != 0:
            raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}")
        if cache_type != "none":
            raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}")
        if "w" in mode:
            raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'")
        # Let FileNotFoundError propagate. Write modes are rejected above, so there
        # is no "missing repo while writing" case to special-case here, and
        # swallowing the error would only surface later as an AttributeError on
        # `self.resolved_path`.
        self.resolved_path = fs.resolve_path(path, revision=revision)
        # avoid an unnecessary .info() call to instantiate .details
        self.details = {"name": self.resolved_path.unresolve(), "size": None}
        super().__init__(
            fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
        )
        # Lazily opened by `read`; None until the first read.
        self.response: Optional[Response] = None
        self.fs: HfFileSystem

    def _request_stream(self, range_start: Optional[int] = None) -> Response:
        """Send the streaming GET for this file, optionally resuming at byte `range_start`.

        The response status is NOT checked here; callers run `hf_raise_for_status`.
        """
        url = hf_hub_url(
            repo_id=self.resolved_path.repo_id,
            revision=self.resolved_path.revision,
            filename=self.resolved_path.path_in_repo,
            repo_type=self.resolved_path.repo_type,
            endpoint=self.fs.endpoint,
        )
        headers = self.fs._api._build_hf_headers()
        if range_start is not None:
            headers = {"Range": "bytes=%d-" % range_start, **headers}
        return http_backoff(
            "GET",
            url,
            headers=headers,
            retry_on_status_codes=(502, 503, 504),
            stream=True,
        )

    def seek(self, loc: int, whence: int = 0):
        """Accept only seeks that leave the position unchanged.

        Returns:
            The current position (`self.loc`) for the accepted no-op seeks.

        Raises:
            ValueError: for any seek that would move the stream position.
        """
        if loc == 0 and whence == 1:
            return self.loc
        if loc == self.loc and whence == 0:
            return self.loc
        raise ValueError("Cannot seek streaming HF file")

    def read(self, length: int = -1):
        """Read up to `length` bytes from the stream; a negative or `None` length reads everything left.

        The HTTP connection is opened on first use (or when the previous raw
        stream reports closed). If reading fails, the response is closed and the
        request is retried once with a `Range` header starting at `self.loc`; a
        second failure is re-raised after closing the response.
        """
        # `None` is the "read everything" value for the raw stream, so negative
        # lengths are mapped to it instead of being forwarded as-is.
        amt = None if (length is None or length < 0) else length
        if self.response is None or self.response.raw.isclosed():
            # NOTE(review): this request carries no Range header, so a reopen after
            # `self.loc` has advanced starts again at byte 0 - confirm this is intended.
            self.response = self._request_stream()
            hf_raise_for_status(self.response)
        try:
            out = self.response.raw.read(amt)
        except Exception:
            self.response.close()

            # Retry by recreating the connection, resuming where we stopped.
            self.response = self._request_stream(range_start=self.loc)
            hf_raise_for_status(self.response)
            try:
                out = self.response.raw.read(amt)
            except Exception:
                self.response.close()
                raise
        self.loc += len(out)
        return out

    def __del__(self):
        if not hasattr(self, "resolved_path"):
            # Means that the constructor failed. Nothing to do.
            return
        return super().__del__()

    def __reduce__(self):
        # Pickle as "re-open the same path with the same options" via `reopen`.
        return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)


def safe_revision(revision: str) -> str:
    """Return `revision` as-is when it matches `SPECIAL_REFS_REVISION_REGEX`, else pass it through `safe_quote`."""
    if SPECIAL_REFS_REVISION_REGEX.match(revision):
        return revision
    return safe_quote(revision)

Expand All @@ -666,3 +759,7 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
elif isinstance(err, HFValidationError):
msg = f"{path} (invalid repository id)"
raise FileNotFoundError(msg) from err


def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
    """Open `path` on `fs` again with the given mode, block size and cache type.

    This is the callable returned by `HfFileSystemStreamFile.__reduce__`.
    """
    open_options = {"mode": mode, "block_size": block_size, "cache_type": cache_type}
    return fs.open(path, **open_options)