
[Optimization] Prevent per-thread instantiation of Cloud Storage FileSystem during Data loading initialization #8149

@ankitaluthra1

Description

Feature request

Modify dataset loading initialization so that fsspec filesystem instances (such as GCSFileSystem or S3FileSystem) are instantiated once in the main thread and explicitly passed down to the background threads.

By pre-instantiating the fs object in the main thread (where the directory cache from glob is still hot) and passing it to the worker threads, we can share a single cached filesystem instance across the entire thread pool.
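The sharing relies on fsspec's built-in instance caching, which can be observed directly. A minimal sketch, using MemoryFileSystem as a stand-in for GCSFileSystem/S3FileSystem (the proposal is to hold on to one such instance explicitly instead of re-resolving it in every worker thread):

```python
from fsspec.core import url_to_fs

# fsspec caches filesystem instances keyed by protocol + constructor options,
# so two url_to_fs calls with the same options return the *same* object.
fs_a, path_a = url_to_fs("memory://bucket/a.txt")
fs_b, path_b = url_to_fs("memory://bucket/b.txt")

print(fs_a is fs_b)    # True: one cached instance serves both calls
print(path_a, path_b)  # /bucket/a.txt /bucket/b.txt
```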

Motivation

Currently, when resolving data files with wildcards, the library spins up a ThreadPoolExecutor to fetch metadata concurrently (code ref). Each worker calls _get_single_origin_metadata, which instantiates a new filesystem instance via url_to_fs(data_file, **storage_options). This can be optimised to use a single filesystem by creating the instance beforehand and passing it to _get_single_origin_metadata.
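The proposed pattern, build the expensive object once in the main thread and bind it into the worker callable with functools.partial, can be sketched with stdlib primitives. FakeClient and get_metadata below are hypothetical stand-ins for the filesystem and for _get_single_origin_metadata:

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class FakeClient:
    """Hypothetical stand-in for a filesystem client that is costly to build."""
    def info(self, path):
        return {"name": path, "etag": "abc"}

def get_metadata(path, client=None):
    # Fall back to building a client per call (the current behaviour)
    # only when no shared instance was handed down.
    client = client if client is not None else FakeClient()
    return client.info(path)["etag"]

shared = FakeClient()  # built once, in the main thread
paths = [f"gs://bucket/file-{i}.parquet" for i in range(4)]
with ThreadPoolExecutor(max_workers=2) as pool:
    # partial binds the shared client, so every worker thread reuses it.
    etags = list(pool.map(partial(get_metadata, client=shared), paths))
print(etags)  # ['abc', 'abc', 'abc', 'abc']
```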

Your contribution

Create a shared filesystem instance

def _get_origin_metadata(
    data_files: list[str],
    download_config: Optional[DownloadConfig] = None,
    max_workers: Optional[int] = None,
) -> list[SingleOriginMetadata]:
    max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
    
    if all("hf://" in data_file for data_file in data_files):
        # No need for multithreading here since the origin metadata of HF files
        # is (repo_id, revision) and is cached after first .info() call.
        return [
            _get_single_origin_metadata(data_file, download_config=download_config)
            for data_file in hf_tqdm(
                data_files,
                desc="Resolving data files",
                disable=len(data_files) <= 16 or None,
            )
        ]
        
    # --- NEW: Pre-instantiate the filesystem in the main thread ---
    # Assuming the batch of data_files shares the same storage backend
    shared_fs = None
    if data_files:
        sample_file, storage_options = _prepare_path_and_storage_options(data_files[0], download_config=download_config)
        shared_fs, _ = url_to_fs(sample_file, **storage_options)
    # --------------------------------------------------------------

    return thread_map(
        # --- NEW: Pass the shared_fs down to the worker threads ---
        partial(_get_single_origin_metadata, download_config=download_config, fs=shared_fs),
        data_files,
        max_workers=max_workers,
        tqdm_class=hf_tqdm,
        desc="Resolving data files",
        disable=len(data_files) <= 16 or None,
    )

Use the shared filesystem instance

def _get_single_origin_metadata(
    data_file: str,
    download_config: Optional[DownloadConfig] = None,
    fs=None,  # <-- NEW: Accept a pre-instantiated fs
) -> SingleOriginMetadata:
    
    if data_file.startswith(config.HF_ENDPOINT):
        if fs is None:
            fs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
        data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :]
        data_file = data_file.replace("/resolve/", "/" if data_file.startswith("hf://buckets/") else "@", 1)
        fs_path = data_file
    else:
        data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
        if fs is None:
            fs, fs_path = url_to_fs(data_file, **storage_options)
        else:
            # A shared fs was provided: derive the path with _strip_protocol
            # instead of url_to_fs, which would otherwise resolve (and
            # potentially re-instantiate) a filesystem again.
            fs_path = fs._strip_protocol(data_file)
            
    if isinstance(fs, HfFileSystem):
        resolved_path = fs.resolve_path(fs_path)
        if hasattr(resolved_path, "revision"):  # no revision for buckets
            return resolved_path.repo_id, resolved_path.revision
            
    info = fs.info(fs_path)
    # This correctly handles metadata across different filesystems natively:
    # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
    for key in ["ETag", "etag", "mtime"]:
        if key in info:
            return (str(info[key]),)
    return ()
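In the shared-fs branch, the path can be derived without constructing anything: _strip_protocol is a classmethod available on every fsspec filesystem. A quick sketch with MemoryFileSystem (again standing in for a cloud backend):

```python
from fsspec.implementations.memory import MemoryFileSystem

# _strip_protocol converts a URL into the filesystem-internal path without
# resolving or instantiating a filesystem, which is all the shared-fs
# branch needs.
fs_path = MemoryFileSystem._strip_protocol("memory://bucket/data.csv")
print(fs_path)  # /bucket/data.csv
```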

Labels

enhancement (New feature or request)