diff --git a/docs/source/_redirects.yml b/docs/source/_redirects.yml index be8e3196c28..50373bfea2f 100644 --- a/docs/source/_redirects.yml +++ b/docs/source/_redirects.yml @@ -10,4 +10,5 @@ faiss_and_ea: faiss_es features: about_dataset_features using_metrics: how_to_metrics exploring: access +package_reference/logging_methods: package_reference/utilities # end of first_section diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index fa137faa9c6..e3155286c20 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -121,8 +121,8 @@ title: Loading methods - local: package_reference/table_classes title: Table Classes - - local: package_reference/logging_methods - title: Logging methods + - local: package_reference/utilities + title: Utilities - local: package_reference/task_templates title: Task templates title: "Reference" diff --git a/docs/source/package_reference/logging_methods.mdx b/docs/source/package_reference/utilities.mdx similarity index 69% rename from docs/source/package_reference/logging_methods.mdx rename to docs/source/package_reference/utilities.mdx index 28fc533afb6..59668dbc120 100644 --- a/docs/source/package_reference/logging_methods.mdx +++ b/docs/source/package_reference/utilities.mdx @@ -1,4 +1,6 @@ -# Logging methods +# Utilities + +## Configure logging 🤗 Datasets strives to be transparent and explicit about how it works, but this can be quite verbose at times. We have included a series of logging methods which allow you to easily adjust the level of verbosity of the entire library. Currently the default verbosity of the library is set to `WARNING`. @@ -28,10 +30,6 @@ In order from the least to the most verbose (with their corresponding `int` valu 4. `logging.INFO` (int value, 20): reports error, warnings and basic information. 5. `logging.DEBUG` (int value, 10): report all information. -By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. 
[`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior. - -## Functions - [[autodoc]] datasets.logging.get_verbosity [[autodoc]] datasets.logging.set_verbosity @@ -48,44 +46,13 @@ By default, `tqdm` progress bars will be displayed during dataset download and p [[autodoc]] datasets.logging.enable_propagation -[[autodoc]] datasets.logging.get_logger - -[[autodoc]] datasets.logging.enable_progress_bar - -[[autodoc]] datasets.logging.disable_progress_bar - -[[autodoc]] datasets.is_progress_bar_enabled - -## Levels - -### datasets.logging.CRITICAL - -datasets.logging.CRITICAL = 50 - -### datasets.logging.DEBUG - -datasets.logging.DEBUG = 10 - -### datasets.logging.ERROR - -datasets.logging.ERROR = 40 - -### datasets.logging.FATAL - -datasets.logging.FATAL = 50 - -### datasets.logging.INFO - -datasets.logging.INFO = 20 - -### datasets.logging.NOTSET - -datasets.logging.NOTSET = 0 +## Configure progress bars -### datasets.logging.WARN +By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. You can disable them globally by setting `HF_DATASETS_DISABLE_PROGRESS_BARS` +environment variable. You can also enable/disable them using [`~utils.enable_progress_bars`] and [`~utils.disable_progress_bars`]. If set, the environment variable has priority on the helpers. 
-datasets.logging.WARN = 30 +[[autodoc]] datasets.utils.enable_progress_bars -### datasets.logging.WARNING +[[autodoc]] datasets.utils.disable_progress_bars -datasets.logging.WARNING = 30 +[[autodoc]] datasets.utils.are_progress_bars_disabled \ No newline at end of file diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index ffd57c3a8ca..472a09c9a6e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -111,6 +111,7 @@ ) from .tasks import TaskTemplate from .utils import logging +from .utils import tqdm as hf_tqdm from .utils.deprecation_utils import deprecated from .utils.file_utils import _retry, estimate_dataset_size from .utils.info_utils import is_small_dataset @@ -1494,8 +1495,7 @@ def save_to_disk( dataset_info = asdict(self._info) shards_done = 0 - pbar = logging.tqdm( - disable=not logging.is_progress_bar_enabled(), + pbar = hf_tqdm( unit=" examples", total=len(self), desc=f"Saving the dataset ({shards_done}/{num_shards} shards)", @@ -3080,8 +3080,7 @@ def load_processed_shard_from_cache(shard_kwargs): except NonExistentDatasetError: pass if transformed_dataset is None: - with logging.tqdm( - disable=not logging.is_progress_bar_enabled(), + with hf_tqdm( unit=" examples", total=pbar_total, desc=desc or "Map", @@ -3173,8 +3172,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: with Pool(len(kwargs_per_job)) as pool: os.environ = prev_env logger.info(f"Spawning {num_proc} processes") - with logging.tqdm( - disable=not logging.is_progress_bar_enabled(), + with hf_tqdm( unit=" examples", total=pbar_total, desc=(desc or "Map") + f" (num_proc={num_proc})", @@ -5195,11 +5193,10 @@ def shards_with_embedded_external_files(shards): uploaded_size = 0 additions = [] - for index, shard in logging.tqdm( + for index, shard in hf_tqdm( enumerate(shards), desc="Uploading the dataset shards", total=num_shards, - disable=not logging.is_progress_bar_enabled(), ): shard_path_in_repo = 
f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet" buffer = BytesIO() diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index bc6c869bd10..5d6d8141f6d 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -41,6 +41,7 @@ from .keyhash import DuplicatedKeysError, KeyHasher from .table import array_cast, array_concat, cast_array_to_feature, embed_table_storage, table_cast from .utils import logging +from .utils import tqdm as hf_tqdm from .utils.file_utils import hash_url_to_filename from .utils.py_utils import asdict, first_non_null_value @@ -689,9 +690,8 @@ def finalize(self, metrics_query_result: dict): for metadata in beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list ] try: # stream conversion - disable = not logging.is_progress_bar_enabled() num_bytes = 0 - for shard in logging.tqdm(shards, unit="shards", disable=disable): + for shard in hf_tqdm(shards, unit="shards"): with beam.io.filesystems.FileSystems.open(shard) as source: with beam.io.filesystems.FileSystems.create( shard.replace(".parquet", ".arrow") @@ -706,9 +706,8 @@ def finalize(self, metrics_query_result: dict): ) local_convert_dir = os.path.join(self._cache_dir, "beam_convert") os.makedirs(local_convert_dir, exist_ok=True) - disable = not logging.is_progress_bar_enabled() num_bytes = 0 - for shard in logging.tqdm(shards, unit="shards", disable=disable): + for shard in hf_tqdm(shards, unit="shards"): local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") beam_utils.download_remote_to_local(shard, local_parquet_path) local_arrow_path = local_parquet_path.replace(".parquet", ".arrow") @@ -727,8 +726,7 @@ def finalize(self, metrics_query_result: dict): def get_parquet_lengths(sources) -> List[int]: shard_lengths = [] - disable = not logging.is_progress_bar_enabled() - for source in logging.tqdm(sources, unit="parquet files", disable=disable): + for source in 
hf_tqdm(sources, unit="parquet files"): parquet_file = pa.parquet.ParquetFile(source) shard_lengths.append(parquet_file.metadata.num_rows) return shard_lengths diff --git a/src/datasets/builder.py b/src/datasets/builder.py index b4df8f44ade..b58036d80e0 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -65,6 +65,7 @@ from .splits import Split, SplitDict, SplitGenerator, SplitInfo from .streaming import extend_dataset_builder_for_streaming from .utils import logging +from .utils import tqdm as hf_tqdm from .utils.file_utils import cached_path, is_remote_url from .utils.filelock import FileLock from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits @@ -1526,8 +1527,7 @@ def _prepare_split( ) num_proc = num_input_shards - pbar = logging.tqdm( - disable=not logging.is_progress_bar_enabled(), + pbar = hf_tqdm( unit=" examples", total=split_info.num_examples, desc=f"Generating {split_info.name} split", @@ -1784,8 +1784,7 @@ def _prepare_split( ) num_proc = num_input_shards - pbar = logging.tqdm( - disable=not logging.is_progress_bar_enabled(), + pbar = hf_tqdm( unit=" examples", total=split_info.num_examples, desc=f"Generating {split_info.name} split", diff --git a/src/datasets/config.py b/src/datasets/config.py index 3e1f20475ac..4b6cca1e696 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -1,15 +1,15 @@ import importlib import importlib.metadata +import logging import os import platform from pathlib import Path +from typing import Optional from packaging import version -from .utils.logging import get_logger - -logger = get_logger(__name__) +logger = logging.getLogger(__name__.split(".", 1)[0]) # to avoid circular import from .utils.logging # Datasets S3_DATASETS_BUCKET_PREFIX = "https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets" @@ -192,6 +192,18 @@ # Offline mode HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES +# 
Here, `True` will disable progress bars globally without possibility of enabling it +# programmatically. `False` will enable them without possibility of disabling them. +# If environment variable is not set (None), then the user is free to enable/disable +# them programmatically. +# TL;DR: env variable has priority over code +__HF_DATASETS_DISABLE_PROGRESS_BARS = os.environ.get("HF_DATASETS_DISABLE_PROGRESS_BARS") +HF_DATASETS_DISABLE_PROGRESS_BARS: Optional[bool] = ( + __HF_DATASETS_DISABLE_PROGRESS_BARS.upper() in ENV_VARS_TRUE_VALUES + if __HF_DATASETS_DISABLE_PROGRESS_BARS is not None + else None +) + # In-memory DEFAULT_IN_MEMORY_MAX_SIZE = 0 # Disabled IN_MEMORY_MAX_SIZE = float(os.environ.get("HF_DATASETS_IN_MEMORY_MAX_SIZE", DEFAULT_IN_MEMORY_MAX_SIZE)) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 9eb1c499299..f318e893738 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -17,6 +17,7 @@ from .download.streaming_download_manager import _prepare_path_and_storage_options, xbasename, xjoin from .splits import Split from .utils import logging +from .utils import tqdm as hf_tqdm from .utils.file_utils import is_local_path, is_relative_path from .utils.py_utils import glob_pattern_to_regex, string_to_dict @@ -515,9 +516,9 @@ def _get_origin_metadata( partial(_get_single_origin_metadata, download_config=download_config), data_files, max_workers=max_workers, - tqdm_class=logging.tqdm, + tqdm_class=hf_tqdm, desc="Resolving data files", - disable=len(data_files) <= 16 or not logging.is_progress_bar_enabled(), + disable=len(data_files) <= 16, ) diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py index 889f3fc5d47..d4fe3c9fafa 100644 --- a/src/datasets/download/download_manager.py +++ b/src/datasets/download/download_manager.py @@ -28,10 +28,11 @@ from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union from .. 
import config +from ..utils import tqdm as hf_tqdm from ..utils.deprecation_utils import DeprecatedEnum, deprecated from ..utils.file_utils import cached_path, get_from_cache, hash_url_to_filename, is_relative_path, url_or_path_join from ..utils.info_utils import get_size_checksum_dict -from ..utils.logging import get_logger, is_progress_bar_enabled, tqdm +from ..utils.logging import get_logger from ..utils.py_utils import NestedDataStructure, map_nested, size_str from .download_config import DownloadConfig @@ -327,18 +328,16 @@ def upload(local_file_path): uploaded_path_or_paths = map_nested( lambda local_file_path: upload(local_file_path), downloaded_path_or_paths, - disable_tqdm=not is_progress_bar_enabled(), ) return uploaded_path_or_paths def _record_sizes_checksums(self, url_or_urls: NestedDataStructure, downloaded_path_or_paths: NestedDataStructure): """Record size/checksum of downloaded files.""" delay = 5 - for url, path in tqdm( + for url, path in hf_tqdm( list(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten())), delay=delay, desc="Computing checksums", - disable=not is_progress_bar_enabled(), ): # call str to support PathLike objects self._recorded_sizes_checksums[str(url)] = get_size_checksum_dict( @@ -373,9 +372,7 @@ def download_custom(self, url_or_urls, custom_download): def url_to_downloaded_path(url): return os.path.join(cache_dir, hash_url_to_filename(url)) - downloaded_path_or_paths = map_nested( - url_to_downloaded_path, url_or_urls, disable_tqdm=not is_progress_bar_enabled() - ) + downloaded_path_or_paths = map_nested(url_to_downloaded_path, url_or_urls) url_or_urls = NestedDataStructure(url_or_urls) downloaded_path_or_paths = NestedDataStructure(downloaded_path_or_paths) for url, path in zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten()): @@ -426,7 +423,6 @@ def download(self, url_or_urls): url_or_urls, map_tuple=True, num_proc=download_config.num_proc, - disable_tqdm=not is_progress_bar_enabled(), desc="Downloading 
data files", ) duration = datetime.now() - start_time @@ -534,7 +530,6 @@ def extract(self, path_or_paths, num_proc="deprecated"): partial(cached_path, download_config=download_config), path_or_paths, num_proc=download_config.num_proc, - disable_tqdm=not is_progress_bar_enabled(), desc="Extracting data files", ) path_or_paths = NestedDataStructure(path_or_paths) diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py index e052ee101e4..f5091e1352e 100644 --- a/src/datasets/io/csv.py +++ b/src/datasets/io/csv.py @@ -5,7 +5,7 @@ from .. import Dataset, Features, NamedSplit, config from ..formatting import query_table from ..packaged_modules.csv.csv import Csv -from ..utils import logging +from ..utils import tqdm as hf_tqdm from ..utils.typing import NestedDataStructureLike, PathLike from .abc import AbstractDatasetReader @@ -117,10 +117,9 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int: written = 0 if self.num_proc is None or self.num_proc == 1: - for offset in logging.tqdm( + for offset in hf_tqdm( range(0, len(self.dataset), self.batch_size), unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating CSV from Arrow format", ): csv_str = self._batch_csv((offset, header, index, to_csv_kwargs)) @@ -129,14 +128,13 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int: else: num_rows, batch_size = len(self.dataset), self.batch_size with multiprocessing.Pool(self.num_proc) as pool: - for csv_str in logging.tqdm( + for csv_str in hf_tqdm( pool.imap( self._batch_csv, [(offset, header, index, to_csv_kwargs) for offset in range(0, num_rows, batch_size)], ), total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating CSV from Arrow format", ): written += file_obj.write(csv_str) diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py index 5f43efec542..ae4710e0726 100644 --- 
a/src/datasets/io/json.py +++ b/src/datasets/io/json.py @@ -7,7 +7,7 @@ from .. import Dataset, Features, NamedSplit, config from ..formatting import query_table from ..packaged_modules.json.json import Json -from ..utils import logging +from ..utils import tqdm as hf_tqdm from ..utils.typing import NestedDataStructureLike, PathLike from .abc import AbstractDatasetReader @@ -139,10 +139,9 @@ def _write( written = 0 if self.num_proc is None or self.num_proc == 1: - for offset in logging.tqdm( + for offset in hf_tqdm( range(0, len(self.dataset), self.batch_size), unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating json from Arrow format", ): json_str = self._batch_json((offset, orient, lines, to_json_kwargs)) @@ -150,14 +149,13 @@ def _write( else: num_rows, batch_size = len(self.dataset), self.batch_size with multiprocessing.Pool(self.num_proc) as pool: - for json_str in logging.tqdm( + for json_str in hf_tqdm( pool.imap( self._batch_json, [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)], ), total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating json from Arrow format", ): written += file_obj.write(json_str) diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py index 39ff70836bb..118525fd81a 100644 --- a/src/datasets/io/parquet.py +++ b/src/datasets/io/parquet.py @@ -9,7 +9,7 @@ from ..formatting import query_table from ..packaged_modules import _PACKAGED_DATASETS_MODULES from ..packaged_modules.parquet.parquet import Parquet -from ..utils import logging +from ..utils import tqdm as hf_tqdm from ..utils.typing import NestedDataStructureLike, PathLike from .abc import AbstractDatasetReader @@ -140,10 +140,9 @@ def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) - writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs) - for offset in 
logging.tqdm( + for offset in hf_tqdm( range(0, len(self.dataset), batch_size), unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating parquet from Arrow format", ): batch = query_table( diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 0553ac2cf1e..ceb425447c2 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -4,7 +4,7 @@ from .. import Dataset, Features, config from ..formatting import query_table from ..packaged_modules.sql.sql import Sql -from ..utils import logging +from ..utils import tqdm as hf_tqdm from .abc import AbstractDatasetInputStream @@ -102,24 +102,22 @@ def _write(self, index, **to_sql_kwargs) -> int: written = 0 if self.num_proc is None or self.num_proc == 1: - for offset in logging.tqdm( + for offset in hf_tqdm( range(0, len(self.dataset), self.batch_size), unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating SQL from Arrow format", ): written += self._batch_sql((offset, index, to_sql_kwargs)) else: num_rows, batch_size = len(self.dataset), self.batch_size with multiprocessing.Pool(self.num_proc) as pool: - for num_rows in logging.tqdm( + for num_rows in hf_tqdm( pool.imap( self._batch_sql, [(offset, index, to_sql_kwargs) for offset in range(0, num_rows, batch_size)], ), total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, unit="ba", - disable=not logging.is_progress_bar_enabled(), desc="Creating SQL from Arrow format", ): written += num_rows diff --git a/src/datasets/search.py b/src/datasets/search.py index 890043a1261..5ec41bbc3e0 100644 --- a/src/datasets/search.py +++ b/src/datasets/search.py @@ -8,6 +8,7 @@ import numpy as np from .utils import logging +from .utils import tqdm as hf_tqdm if TYPE_CHECKING: @@ -150,7 +151,7 @@ def add_documents(self, documents: Union[List[str], "Dataset"], column: Optional index_config = self.es_index_config self.es_client.indices.create(index=index_name, body=index_config) number_of_docs = 
len(documents) - progress = logging.tqdm(unit="docs", total=number_of_docs, disable=not logging.is_progress_bar_enabled()) + progress = hf_tqdm(unit="docs", total=number_of_docs) successes = 0 def passage_generator(): @@ -301,7 +302,7 @@ def add_vectors( # Add vectors logger.info(f"Adding {len(vectors)} vectors to the faiss index") - for i in logging.tqdm(range(0, len(vectors), batch_size), disable=not logging.is_progress_bar_enabled()): + for i in hf_tqdm(range(0, len(vectors), batch_size)): vecs = vectors[i : i + batch_size] if column is None else vectors[i : i + batch_size][column] self.faiss_index.add(vecs) diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py index f6ad0e6fb2f..001fca727b3 100644 --- a/src/datasets/utils/__init__.py +++ b/src/datasets/utils/__init__.py @@ -14,18 +14,15 @@ # flake8: noqa # Lint as: python3 -"""Util import.""" - -__all__ = [ - "VerificationMode", - "Version", - "disable_progress_bar", - "enable_progress_bar", - "is_progress_bar_enabled", - "experimental", -] +from . import tqdm as _tqdm # _tqdm is the module from .info_utils import VerificationMode from .logging import disable_progress_bar, enable_progress_bar, is_progress_bar_enabled from .version import Version from .experimental import experimental +from .tqdm import ( + disable_progress_bars, + enable_progress_bars, + are_progress_bars_disabled, + tqdm, +) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index d33cf4647f6..e6aec445035 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -24,13 +24,16 @@ import fsspec import huggingface_hub import requests +from fsspec.core import strip_protocol +from fsspec.utils import can_be_local from huggingface_hub import HfFolder from huggingface_hub.utils import insecure_hashlib from packaging import version from .. import __version__, config from ..download.download_config import DownloadConfig -from . import logging +from . 
import _tqdm, logging +from . import tqdm as hf_tqdm from .extract import ExtractManager from .filelock import FileLock @@ -177,6 +180,10 @@ def cached_path( if isinstance(url_or_filename, Path): url_or_filename = str(url_or_filename) + # Convert fsspec URL in the format "file://local/path" to "local/path" + if can_be_local(url_or_filename): + url_or_filename = strip_protocol(url_or_filename) + if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) output_path = get_from_cache( @@ -348,7 +355,7 @@ def fsspec_head(url, storage_options=None): class TqdmCallback(fsspec.callbacks.TqdmCallback): def __init__(self, tqdm_kwargs=None, *args, **kwargs): super().__init__(tqdm_kwargs, *args, **kwargs) - self._tqdm = logging # replace tqdm.tqdm by datasets.logging.tqdm + self._tqdm = _tqdm # replace tqdm.tqdm by datasets.tqdm.tqdm def fsspec_get(url, temp_file, storage_options=None, desc=None): @@ -359,7 +366,6 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None): callback = TqdmCallback( tqdm_kwargs={ "desc": desc or "Downloading", - "disable": not logging.is_progress_bar_enabled(), "unit": "B", "unit_scale": True, } @@ -408,13 +414,12 @@ def http_get( return content_length = response.headers.get("Content-Length") total = resume_size + int(content_length) if content_length is not None else None - with logging.tqdm( + with hf_tqdm( unit="B", unit_scale=True, total=total, initial=resume_size, desc=desc or "Downloading", - disable=not logging.is_progress_bar_enabled(), ) as progress: for chunk in response.iter_content(chunk_size=1024): progress.update(len(chunk)) diff --git a/src/datasets/utils/logging.py b/src/datasets/utils/logging.py index c379820143f..b3ea17d6cad 100644 --- a/src/datasets/utils/logging.py +++ b/src/datasets/utils/logging.py @@ -27,7 +27,12 @@ ) from typing import Optional -from tqdm import auto as tqdm_lib +from .tqdm import ( # noqa: F401 # imported for backward compatibility + disable_progress_bar, + 
enable_progress_bar, + is_progress_bar_enabled, + tqdm, +) log_levels = { @@ -172,76 +177,3 @@ def enable_propagation() -> None: # Configure the library root logger at the module level (singleton-like) _configure_library_root_logger() - - -class EmptyTqdm: - """Dummy tqdm which doesn't do anything.""" - - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument - self._iterator = args[0] if args else None - - def __iter__(self): - return iter(self._iterator) - - def __getattr__(self, _): - """Return empty function.""" - - def empty_fn(*args, **kwargs): # pylint: disable=unused-argument - return - - return empty_fn - - def __enter__(self): - return self - - def __exit__(self, type_, value, traceback): - return - - -_tqdm_active = True - - -class _tqdm_cls: - def __call__(self, *args, disable=False, **kwargs): - if _tqdm_active and not disable: - return tqdm_lib.tqdm(*args, **kwargs) - else: - return EmptyTqdm(*args, **kwargs) - - def set_lock(self, *args, **kwargs): - self._lock = None - if _tqdm_active: - return tqdm_lib.tqdm.set_lock(*args, **kwargs) - - def get_lock(self): - if _tqdm_active: - return tqdm_lib.tqdm.get_lock() - - def __delattr__(self, attr): - """fix for https://github.com/huggingface/datasets/issues/6066""" - try: - del self.__dict__[attr] - except KeyError: - if attr != "_lock": - raise AttributeError(attr) - - -tqdm = _tqdm_cls() - - -def is_progress_bar_enabled() -> bool: - """Return a boolean indicating whether tqdm progress bars are enabled.""" - global _tqdm_active - return bool(_tqdm_active) - - -def enable_progress_bar(): - """Enable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = True - - -def disable_progress_bar(): - """Disable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = False diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 243bc0f99c9..4d49c3b5865 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -46,6 
+46,7 @@ from .. import config from ..parallel import parallel_map from . import logging +from . import tqdm as hf_tqdm try: # pragma: no branch @@ -377,7 +378,7 @@ def _single_map_nested(args): # Loop over single examples or batches and write to buffer/file if examples are to be updated pbar_iterable = data_struct.items() if isinstance(data_struct, dict) else data_struct pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc - with logging.tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar: + with hf_tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar: if isinstance(data_struct, dict): return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar} else: @@ -455,7 +456,6 @@ def map_nested( if not isinstance(data_struct, dict) and not isinstance(data_struct, types): return function(data_struct) - disable_tqdm = disable_tqdm or not logging.is_progress_bar_enabled() iterable = list(data_struct.values()) if isinstance(data_struct, dict) else data_struct if num_proc is None: @@ -463,7 +463,7 @@ def map_nested( if num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length: mapped = [ _single_map_nested((function, obj, types, None, True, None)) - for obj in logging.tqdm(iterable, disable=disable_tqdm, desc=desc) + for obj in hf_tqdm(iterable, disable=disable_tqdm, desc=desc) ] else: with warnings.catch_warnings(): diff --git a/src/datasets/utils/tqdm.py b/src/datasets/utils/tqdm.py new file mode 100644 index 00000000000..0ea73a1c4a6 --- /dev/null +++ b/src/datasets/utils/tqdm.py @@ -0,0 +1,130 @@ +"""Utility helpers to handle progress bars in `datasets`. + +Example: + 1. Use `datasets.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`. + 2. To disable progress bars, either use `disable_progress_bars()` helper or set the + environment variable `HF_DATASETS_DISABLE_PROGRESS_BARS` to 1. + 3. 
To re-enable progress bars, use `enable_progress_bars()`.
+ 4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`.
+
+NOTE: Environment variable `HF_DATASETS_DISABLE_PROGRESS_BARS` has priority.
+
+Example:
+    ```py
+    from datasets.utils import (
+        are_progress_bars_disabled,
+        disable_progress_bars,
+        enable_progress_bars,
+        tqdm,
+    )
+
+    # Disable progress bars globally
+    disable_progress_bars()
+
+    # Use as normal `tqdm`
+    for _ in tqdm(range(5)):
+        do_something()
+
+    # Still not showing progress bars, as `disable=False` is overwritten to `True`.
+    for _ in tqdm(range(5), disable=False):
+        do_something()
+
+    are_progress_bars_disabled() # True
+
+    # Re-enable progress bars globally
+    enable_progress_bars()
+
+    # Progress bar will be shown!
+    for _ in tqdm(range(5)):
+        do_something()
+    ```
+"""
+import warnings
+
+from tqdm.auto import tqdm as old_tqdm
+
+from ..config import HF_DATASETS_DISABLE_PROGRESS_BARS
+
+
+# `HF_DATASETS_DISABLE_PROGRESS_BARS` is `Optional[bool]` while `_hf_datasets_progress_bars_disabled`
+# is a `bool`. If `HF_DATASETS_DISABLE_PROGRESS_BARS` is set to True or False, it has priority.
+# If `HF_DATASETS_DISABLE_PROGRESS_BARS` is None, it means the user has not set the
+# environment variable and is free to enable/disable progress bars programmatically.
+# TL;DR: env variable has priority over code.
+#
+# By default, progress bars are enabled.
+_hf_datasets_progress_bars_disabled: bool = HF_DATASETS_DISABLE_PROGRESS_BARS or False
+
+
+def disable_progress_bars() -> None:
+    """
+    Globally disable progress bars used in `datasets`, except if the `HF_DATASETS_DISABLE_PROGRESS_BARS` environment
+    variable has been set.
+
+    Use [`~utils.enable_progress_bars`] to re-enable them.
+    """
+    if HF_DATASETS_DISABLE_PROGRESS_BARS is False:
+        warnings.warn(
+            "Cannot disable progress bars: environment variable `HF_DATASETS_DISABLE_PROGRESS_BARS=0` is set and has"
+            " priority."
+        )
+        return
+    global _hf_datasets_progress_bars_disabled
+    _hf_datasets_progress_bars_disabled = True
+
+
+def enable_progress_bars() -> None:
+    """
+    Globally enable progress bars used in `datasets`, except if the `HF_DATASETS_DISABLE_PROGRESS_BARS` environment
+    variable has been set.
+
+    Use [`~utils.disable_progress_bars`] to disable them.
+    """
+    if HF_DATASETS_DISABLE_PROGRESS_BARS is True:
+        warnings.warn(
+            "Cannot enable progress bars: environment variable `HF_DATASETS_DISABLE_PROGRESS_BARS=1` is set and has"
+            " priority."
+        )
+        return
+    global _hf_datasets_progress_bars_disabled
+    _hf_datasets_progress_bars_disabled = False
+
+
+def are_progress_bars_disabled() -> bool:
+    """Return whether progress bars are globally disabled or not.
+
+    Progress bars used in `datasets` can be enabled or disabled globally using [`~utils.enable_progress_bars`]
+    and [`~utils.disable_progress_bars`] or by setting the `HF_DATASETS_DISABLE_PROGRESS_BARS` environment variable.
+    """
+    global _hf_datasets_progress_bars_disabled
+    return _hf_datasets_progress_bars_disabled
+
+
+class tqdm(old_tqdm):
+    """
+    Class to override `disable` argument in case progress bars are globally disabled.
+
+    Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
+ """ + + def __init__(self, *args, **kwargs): + if are_progress_bars_disabled(): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + def __delattr__(self, attr: str) -> None: + """Fix for https://github.com/huggingface/datasets/issues/6066""" + try: + super().__delattr__(attr) + except AttributeError: + if attr != "_lock": + raise + + +# backward compatibility +enable_progress_bar = enable_progress_bars +disable_progress_bar = disable_progress_bars + + +def is_progress_bar_enabled(): + return not are_progress_bars_disabled() diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index a6175c3dd17..54f11bccce4 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -80,12 +80,14 @@ def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, def test_cached_path_local(text_file): - # absolute path - text_file = str(Path(text_file).resolve()) - assert cached_path(text_file) == text_file - # relative path - text_file = str(Path(__file__).resolve().relative_to(Path(os.getcwd()))) - assert cached_path(text_file) == text_file + # input absolute path -> output absolute path + text_file_abs = str(Path(text_file).resolve()) + assert os.path.samefile(cached_path(text_file_abs), text_file_abs) + # input relative path -> output absolute path + text_file = __file__ + text_file_abs = str(Path(text_file).resolve()) + text_file_rel = str(Path(text_file).resolve().relative_to(Path(os.getcwd()))) + assert os.path.samefile(cached_path(text_file_rel), text_file_abs) def test_cached_path_missing_local(tmp_path): diff --git a/tests/test_logging.py b/tests/test_logging.py deleted file mode 100644 index a7669a3fff6..00000000000 --- a/tests/test_logging.py +++ /dev/null @@ -1,19 +0,0 @@ -from unittest.mock import patch - -import datasets -from datasets import Dataset - - -def test_enable_disable_progress_bar(): - dset = Dataset.from_dict({"col_1": [3, 2, 0, 1]}) - - with patch("tqdm.auto.tqdm") as mock_tqdm: - 
datasets.disable_progress_bar() - dset.map(lambda x: {"col_2": x["col_1"] + 1}) - mock_tqdm.assert_not_called() - - mock_tqdm.reset_mock() - - datasets.enable_progress_bar() - dset.map(lambda x: {"col_2": x["col_1"] + 1}) - mock_tqdm.assert_called() diff --git a/tests/test_tqdm.py b/tests/test_tqdm.py new file mode 100644 index 00000000000..e6ddb86de1d --- /dev/null +++ b/tests/test_tqdm.py @@ -0,0 +1,116 @@ +import unittest +from unittest.mock import patch + +import pytest +from pytest import CaptureFixture + +from datasets.utils import ( + are_progress_bars_disabled, + disable_progress_bars, + enable_progress_bars, + tqdm, +) + + +class TestTqdmUtils(unittest.TestCase): + @pytest.fixture(autouse=True) + def capsys(self, capsys: CaptureFixture) -> None: + """Workaround to make capsys work in unittest framework. + + Capsys is a convenient pytest fixture to capture stdout. + See https://waylonwalker.com/pytest-capsys/. + + Taken from https://github.com/pytest-dev/pytest/issues/2504#issuecomment-309475790. 
+ """ + self.capsys = capsys + + def setUp(self) -> None: + """Get verbosity to set it back after the tests.""" + self._previous_are_progress_bars_disabled = are_progress_bars_disabled() + return super().setUp() + + def tearDown(self) -> None: + """Set back progress bars verbosity as before testing.""" + if self._previous_are_progress_bars_disabled: + disable_progress_bars() + else: + enable_progress_bars() + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", None) + def test_tqdm_helpers(self) -> None: + """Test helpers to enable/disable progress bars.""" + disable_progress_bars() + self.assertTrue(are_progress_bars_disabled()) + + enable_progress_bars() + self.assertFalse(are_progress_bars_disabled()) + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", True) + def test_cannot_enable_tqdm_when_env_variable_is_set(self) -> None: + """ + Test helpers cannot enable/disable progress bars when + `HF_DATASETS_DISABLE_PROGRESS_BARS` is set. + """ + disable_progress_bars() + self.assertTrue(are_progress_bars_disabled()) + + with self.assertWarns(UserWarning): + enable_progress_bars() + self.assertTrue(are_progress_bars_disabled()) # Still disabled ! + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", False) + def test_cannot_disable_tqdm_when_env_variable_is_set(self) -> None: + """ + Test helpers cannot enable/disable progress bars when + `HF_DATASETS_DISABLE_PROGRESS_BARS` is set. + """ + enable_progress_bars() + self.assertFalse(are_progress_bars_disabled()) + + with self.assertWarns(UserWarning): + disable_progress_bars() + self.assertFalse(are_progress_bars_disabled()) # Still enabled ! 
+ + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", None) + def test_tqdm_disabled(self) -> None: + """Test TQDM not outputting anything when globally disabled.""" + disable_progress_bars() + for _ in tqdm(range(10)): + pass + + captured = self.capsys.readouterr() + self.assertEqual(captured.out, "") + self.assertEqual(captured.err, "") + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", None) + def test_tqdm_disabled_cannot_be_forced(self) -> None: + """Test TQDM cannot be forced when globally disabled.""" + disable_progress_bars() + for _ in tqdm(range(10), disable=False): + pass + + captured = self.capsys.readouterr() + self.assertEqual(captured.out, "") + self.assertEqual(captured.err, "") + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", None) + def test_tqdm_can_be_disabled_when_globally_enabled(self) -> None: + """Test TQDM can still be locally disabled even when globally enabled.""" + enable_progress_bars() + for _ in tqdm(range(10), disable=True): + pass + + captured = self.capsys.readouterr() + self.assertEqual(captured.out, "") + self.assertEqual(captured.err, "") + + @patch("datasets.utils._tqdm.HF_DATASETS_DISABLE_PROGRESS_BARS", None) + def test_tqdm_enabled(self) -> None: + """Test TQDM work normally when globally enabled.""" + enable_progress_bars() + for _ in tqdm(range(10)): + pass + + captured = self.capsys.readouterr() + self.assertEqual(captured.out, "") + self.assertIn("10/10", captured.err) # tqdm log