From 4f2e0e1e235924488ff8c1e9286251e0bab95a66 Mon Sep 17 00:00:00 2001 From: Aditya2755 Date: Sat, 29 Nov 2025 11:47:11 +0530 Subject: [PATCH] Add type overloads to load_dataset for better static type inference Fixes #7883 This PR adds @overload decorators to load_dataset() to help type checkers like Pylance and mypy correctly infer the return type based on the split and streaming parameters. Changes: - Added typing imports (Literal, overload) to load.py - Added 4 @overload signatures that map argument combinations to specific return types: * split=None, streaming=False -> DatasetDict * split specified, streaming=False -> Dataset * split=None, streaming=True -> IterableDatasetDict * split specified, streaming=True -> IterableDataset This resolves the Pylance error where to_csv() was not recognized on Dataset objects returned by load_dataset(..., split='train'), since the type checker previously saw the return type as a Union that included types without to_csv(). No runtime behavior changes - this is purely a static typing improvement. --- src/datasets/load.py | 98 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index ae3b9825970..0398c4cad8a 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -25,8 +25,7 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, Union - +from typing import Any, Optional, Union, Literal, overload import fsspec import httpx import requests @@ -1187,6 +1186,101 @@ def load_dataset_builder( return builder_instance +@overload +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + split: None = None, + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[str, Version]] = None, + token: Optional[Union[bool, str]] = None, + streaming: Literal[False] = False, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + **config_kwargs: Any, +) -> DatasetDict: ... + + +@overload +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + *, + split: Union[str, Split, list[str], list[Split]], + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[Version, str]] = None, + token: Optional[Union[bool, str]] = None, + streaming: Literal[False] = False, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + **config_kwargs: Any, +) -> Dataset: ... + + +@overload +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + split: None = None, + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[Version, str]] = None, + token: Optional[Union[bool, str]] = None, + *, + streaming: Literal[True], + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + **config_kwargs: Any, +) -> IterableDatasetDict: ... + + +@overload +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + *, + split: Union[str, Split, list[str], list[Split]], + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[Version, str]] = None, + token: Optional[Union[bool, str]] = None, + streaming: Literal[True], + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + **config_kwargs: Any, +) -> IterableDataset: ... + + def load_dataset( path: str, name: Optional[str] = None,