Commit b319382

mariosasko and lhoestq authored
Better tqdm wrapper (#6433)
* Better TQDM wrapper
* Improve docs
* Minor doc fix
* More fixes
* Fix name error
* More fixes
* Fix fsspec import
* Nit
* Final fix?
* Apply suggestions from code review
* Use `fsspec` to handle local URIs in `cached_path`

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent bc44d21 commit b319382

22 files changed, with 340 additions and 213 deletions.


docs/source/_redirects.yml

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@ faiss_and_ea: faiss_es
 features: about_dataset_features
 using_metrics: how_to_metrics
 exploring: access
+package_reference/logging_methods: package_reference/utilities
 # end of first_section

docs/source/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -121,8 +121,8 @@
     title: Loading methods
   - local: package_reference/table_classes
     title: Table Classes
-  - local: package_reference/logging_methods
-    title: Logging methods
+  - local: package_reference/utilities
+    title: Utilities
   - local: package_reference/task_templates
     title: Task templates
   title: "Reference"
docs/source/package_reference/{logging_methods.mdx → utilities.mdx}

Lines changed: 9 additions & 42 deletions

@@ -1,4 +1,6 @@
-# Logging methods
+# Utilities
+
+## Configure logging
 
 🤗 Datasets strives to be transparent and explicit about how it works, but this can be quite verbose at times. We have included a series of logging methods which allow you to easily adjust the level of verbosity of the entire library. Currently the default verbosity of the library is set to `WARNING`.
 
@@ -28,10 +30,6 @@ In order from the least to the most verbose (with their corresponding `int` valu
 4. `logging.INFO` (int value, 20): reports error, warnings and basic information.
 5. `logging.DEBUG` (int value, 10): report all information.
 
-By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.
-
-## Functions
-
 [[autodoc]] datasets.logging.get_verbosity
 
 [[autodoc]] datasets.logging.set_verbosity
@@ -48,44 +46,13 @@ By default, `tqdm` progress bars will be displayed during dataset download and p
 
 [[autodoc]] datasets.logging.enable_propagation
 
-[[autodoc]] datasets.logging.get_logger
-
-[[autodoc]] datasets.logging.enable_progress_bar
-
-[[autodoc]] datasets.logging.disable_progress_bar
-
-[[autodoc]] datasets.is_progress_bar_enabled
-
-## Levels
-
-### datasets.logging.CRITICAL
-
-datasets.logging.CRITICAL = 50
-
-### datasets.logging.DEBUG
-
-datasets.logging.DEBUG = 10
-
-### datasets.logging.ERROR
-
-datasets.logging.ERROR = 40
-
-### datasets.logging.FATAL
-
-datasets.logging.FATAL = 50
-
-### datasets.logging.INFO
-
-datasets.logging.INFO = 20
-
-### datasets.logging.NOTSET
-
-datasets.logging.NOTSET = 0
+## Configure progress bars
 
-### datasets.logging.WARN
+By default, `tqdm` progress bars will be displayed during dataset download and preprocessing. You can disable them globally by setting `HF_DATASETS_DISABLE_PROGRESS_BARS`
+environment variable. You can also enable/disable them using [`~utils.enable_progress_bars`] and [`~utils.disable_progress_bars`]. If set, the environment variable has priority on the helpers.
 
-datasets.logging.WARN = 30
+[[autodoc]] datasets.utils.enable_progress_bars
 
-### datasets.logging.WARNING
+[[autodoc]] datasets.utils.disable_progress_bars
 
-datasets.logging.WARNING = 30
+[[autodoc]] datasets.utils.are_progress_bars_disabled
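In practice, the new controls documented above can be exercised like this — a minimal sketch, assuming this commit's API where the helpers live under `datasets.utils` as the autodoc entries indicate:

```python
from datasets.utils import (
    are_progress_bars_disabled,
    disable_progress_bars,
    enable_progress_bars,
)

# Globally silence tqdm bars for downloads and preprocessing
disable_progress_bars()
assert are_progress_bars_disabled()

# Re-enable them; per the note above, this is overridden when the
# HF_DATASETS_DISABLE_PROGRESS_BARS environment variable is set
enable_progress_bars()
```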

src/datasets/arrow_dataset.py

Lines changed: 5 additions & 8 deletions
@@ -111,6 +111,7 @@
 )
 from .tasks import TaskTemplate
 from .utils import logging
+from .utils import tqdm as hf_tqdm
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import _retry, estimate_dataset_size
 from .utils.info_utils import is_small_dataset
@@ -1494,8 +1495,7 @@ def save_to_disk(
         dataset_info = asdict(self._info)
 
         shards_done = 0
-        pbar = logging.tqdm(
-            disable=not logging.is_progress_bar_enabled(),
+        pbar = hf_tqdm(
             unit=" examples",
             total=len(self),
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
@@ -3080,8 +3080,7 @@ def load_processed_shard_from_cache(shard_kwargs):
         except NonExistentDatasetError:
             pass
         if transformed_dataset is None:
-            with logging.tqdm(
-                disable=not logging.is_progress_bar_enabled(),
+            with hf_tqdm(
                 unit=" examples",
                 total=pbar_total,
                 desc=desc or "Map",
@@ -3173,8 +3172,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             with Pool(len(kwargs_per_job)) as pool:
                 os.environ = prev_env
                 logger.info(f"Spawning {num_proc} processes")
-                with logging.tqdm(
-                    disable=not logging.is_progress_bar_enabled(),
+                with hf_tqdm(
                     unit=" examples",
                     total=pbar_total,
                     desc=(desc or "Map") + f" (num_proc={num_proc})",
@@ -5195,11 +5193,10 @@ def shards_with_embedded_external_files(shards):
 
         uploaded_size = 0
         additions = []
-        for index, shard in logging.tqdm(
+        for index, shard in hf_tqdm(
             enumerate(shards),
             desc="Uploading the dataset shards",
             total=num_shards,
-            disable=not logging.is_progress_bar_enabled(),
         ):
             shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
             buffer = BytesIO()
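The call sites above simply drop the explicit `disable=not logging.is_progress_bar_enabled()` argument and rely on the wrapper instead. The wrapper itself (`datasets.utils.tqdm`, imported as `hf_tqdm`) is not part of this excerpt; a plausible sketch, assuming a module-level flag toggled by `enable_progress_bars`/`disable_progress_bars` — the real class name and internals may differ:

```python
from tqdm.auto import tqdm as _tqdm

# Assumed module-level state, flipped by disable_progress_bars()/enable_progress_bars()
_progress_bars_disabled = False

class tqdm(_tqdm):
    """A tqdm subclass that honors the library-wide progress-bar switch."""

    def __init__(self, *args, **kwargs):
        if _progress_bars_disabled:
            kwargs["disable"] = True  # the global switch wins over call sites
        super().__init__(*args, **kwargs)
```

With such a wrapper, call sites no longer need to consult the logging module on every use; they construct the bar and the global state decides whether it renders.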

src/datasets/arrow_writer.py

Lines changed: 4 additions & 6 deletions
@@ -41,6 +41,7 @@
 from .keyhash import DuplicatedKeysError, KeyHasher
 from .table import array_cast, array_concat, cast_array_to_feature, embed_table_storage, table_cast
 from .utils import logging
+from .utils import tqdm as hf_tqdm
 from .utils.file_utils import hash_url_to_filename
 from .utils.py_utils import asdict, first_non_null_value
 
@@ -689,9 +690,8 @@ def finalize(self, metrics_query_result: dict):
             for metadata in beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list
         ]
         try:  # stream conversion
-            disable = not logging.is_progress_bar_enabled()
             num_bytes = 0
-            for shard in logging.tqdm(shards, unit="shards", disable=disable):
+            for shard in hf_tqdm(shards, unit="shards"):
                 with beam.io.filesystems.FileSystems.open(shard) as source:
                     with beam.io.filesystems.FileSystems.create(
                         shard.replace(".parquet", ".arrow")
@@ -706,9 +706,8 @@ def finalize(self, metrics_query_result: dict):
             )
             local_convert_dir = os.path.join(self._cache_dir, "beam_convert")
             os.makedirs(local_convert_dir, exist_ok=True)
-            disable = not logging.is_progress_bar_enabled()
             num_bytes = 0
-            for shard in logging.tqdm(shards, unit="shards", disable=disable):
+            for shard in hf_tqdm(shards, unit="shards"):
                 local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
                 beam_utils.download_remote_to_local(shard, local_parquet_path)
                 local_arrow_path = local_parquet_path.replace(".parquet", ".arrow")
@@ -727,8 +726,7 @@ def finalize(self, metrics_query_result: dict):
 
 def get_parquet_lengths(sources) -> List[int]:
     shard_lengths = []
-    disable = not logging.is_progress_bar_enabled()
-    for source in logging.tqdm(sources, unit="parquet files", disable=disable):
+    for source in hf_tqdm(sources, unit="parquet files"):
         parquet_file = pa.parquet.ParquetFile(source)
         shard_lengths.append(parquet_file.metadata.num_rows)
     return shard_lengths

src/datasets/builder.py

Lines changed: 3 additions & 4 deletions
@@ -65,6 +65,7 @@
 from .splits import Split, SplitDict, SplitGenerator, SplitInfo
 from .streaming import extend_dataset_builder_for_streaming
 from .utils import logging
+from .utils import tqdm as hf_tqdm
 from .utils.file_utils import cached_path, is_remote_url
 from .utils.filelock import FileLock
 from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
@@ -1526,8 +1527,7 @@ def _prepare_split(
             )
             num_proc = num_input_shards
 
-        pbar = logging.tqdm(
-            disable=not logging.is_progress_bar_enabled(),
+        pbar = hf_tqdm(
             unit=" examples",
             total=split_info.num_examples,
             desc=f"Generating {split_info.name} split",
@@ -1784,8 +1784,7 @@ def _prepare_split(
             )
             num_proc = num_input_shards
 
-        pbar = logging.tqdm(
-            disable=not logging.is_progress_bar_enabled(),
+        pbar = hf_tqdm(
             unit=" examples",
             total=split_info.num_examples,
             desc=f"Generating {split_info.name} split",
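These bars are created without an iterable, so the builder presumably advances them manually as examples are written. The standard tqdm pattern for that, with illustrative values rather than the builder's real loop:

```python
from tqdm import tqdm  # stand-in for hf_tqdm

num_examples = 10_000
with tqdm(unit=" examples", total=num_examples, desc="Generating train split") as pbar:
    for _ in range(0, num_examples, 1_000):
        # ... generate and write a batch of 1_000 examples ...
        pbar.update(1_000)  # advance by the number of examples written
```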

src/datasets/config.py

Lines changed: 15 additions & 3 deletions
@@ -1,15 +1,15 @@
 import importlib
 import importlib.metadata
+import logging
 import os
 import platform
 from pathlib import Path
+from typing import Optional
 
 from packaging import version
 
-from .utils.logging import get_logger
 
-
-logger = get_logger(__name__)
+logger = logging.getLogger(__name__.split(".", 1)[0])  # to avoid circular import from .utils.logging
 
 # Datasets
 S3_DATASETS_BUCKET_PREFIX = "https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets"
@@ -192,6 +192,18 @@
 # Offline mode
 HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES
 
+# Here, `True` will disable progress bars globally without possibility of enabling it
+# programmatically. `False` will enable them without possibility of disabling them.
+# If environment variable is not set (None), then the user is free to enable/disable
+# them programmatically.
+# TL;DR: env variable has priority over code
+__HF_DATASETS_DISABLE_PROGRESS_BARS = os.environ.get("HF_DATASETS_DISABLE_PROGRESS_BARS")
+HF_DATASETS_DISABLE_PROGRESS_BARS: Optional[bool] = (
+    __HF_DATASETS_DISABLE_PROGRESS_BARS.upper() in ENV_VARS_TRUE_VALUES
+    if __HF_DATASETS_DISABLE_PROGRESS_BARS is not None
+    else None
+)
+
 # In-memory
 DEFAULT_IN_MEMORY_MAX_SIZE = 0  # Disabled
 IN_MEMORY_MAX_SIZE = float(os.environ.get("HF_DATASETS_IN_MEMORY_MAX_SIZE", DEFAULT_IN_MEMORY_MAX_SIZE))
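The flag is deliberately tri-state: `True` forces bars off, `False` forces them on, and `None` (env var unset) defers to the programmatic helpers. A standalone illustration of the same resolution logic — `ENV_VARS_TRUE_VALUES` is assumed here to match the set `config.py` already uses:

```python
import os
from typing import Optional

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to match datasets.config

def resolve_disable_flag() -> Optional[bool]:
    raw = os.environ.get("HF_DATASETS_DISABLE_PROGRESS_BARS")
    return raw.upper() in ENV_VARS_TRUE_VALUES if raw is not None else None

os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1"
assert resolve_disable_flag() is True   # bars forced off; helpers are ignored

os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "0"
assert resolve_disable_flag() is False  # bars forced on; helpers are ignored

del os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"]
assert resolve_disable_flag() is None   # unset: the helpers decide
```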

src/datasets/data_files.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
 from .download.streaming_download_manager import _prepare_path_and_storage_options, xbasename, xjoin
 from .splits import Split
 from .utils import logging
+from .utils import tqdm as hf_tqdm
 from .utils.file_utils import is_local_path, is_relative_path
 from .utils.py_utils import glob_pattern_to_regex, string_to_dict
 
@@ -515,9 +516,9 @@ def _get_origin_metadata(
         partial(_get_single_origin_metadata, download_config=download_config),
         data_files,
         max_workers=max_workers,
-        tqdm_class=logging.tqdm,
+        tqdm_class=hf_tqdm,
         desc="Resolving data files",
-        disable=len(data_files) <= 16 or not logging.is_progress_bar_enabled(),
+        disable=len(data_files) <= 16,
     )
 
 
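The keyword arguments here (`tqdm_class`, `max_workers`, `desc`, `disable`) match the signature of `tqdm.contrib.concurrent.thread_map`, which is presumably what receives them; note that the call now keeps only the small-list heuristic (no bar for 16 or fewer files) and leaves global disabling to the wrapper. A hedged sketch of the pattern, with a hypothetical worker standing in for `_get_single_origin_metadata`:

```python
from tqdm import tqdm  # stand-in for the hf_tqdm wrapper
from tqdm.contrib.concurrent import thread_map

def fetch_origin_metadata(path: str) -> dict:
    # hypothetical worker; the real one resolves ETags/commit hashes per file
    return {"path": path}

data_files = [f"data/shard-{i:05d}.parquet" for i in range(32)]
metadata = thread_map(
    fetch_origin_metadata,
    data_files,
    max_workers=16,
    tqdm_class=tqdm,                # any tqdm subclass is accepted here
    desc="Resolving data files",
    disable=len(data_files) <= 16,  # skip the bar for small file lists
)
```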
src/datasets/download/download_manager.py

Lines changed: 4 additions & 9 deletions
@@ -28,10 +28,11 @@
 from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
 
 from .. import config
+from ..utils import tqdm as hf_tqdm
 from ..utils.deprecation_utils import DeprecatedEnum, deprecated
 from ..utils.file_utils import cached_path, get_from_cache, hash_url_to_filename, is_relative_path, url_or_path_join
 from ..utils.info_utils import get_size_checksum_dict
-from ..utils.logging import get_logger, is_progress_bar_enabled, tqdm
+from ..utils.logging import get_logger
 from ..utils.py_utils import NestedDataStructure, map_nested, size_str
 from .download_config import DownloadConfig
 
@@ -327,18 +328,16 @@ def upload(local_file_path):
         uploaded_path_or_paths = map_nested(
             lambda local_file_path: upload(local_file_path),
             downloaded_path_or_paths,
-            disable_tqdm=not is_progress_bar_enabled(),
         )
         return uploaded_path_or_paths
 
     def _record_sizes_checksums(self, url_or_urls: NestedDataStructure, downloaded_path_or_paths: NestedDataStructure):
         """Record size/checksum of downloaded files."""
         delay = 5
-        for url, path in tqdm(
+        for url, path in hf_tqdm(
             list(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten())),
             delay=delay,
             desc="Computing checksums",
-            disable=not is_progress_bar_enabled(),
         ):
             # call str to support PathLike objects
             self._recorded_sizes_checksums[str(url)] = get_size_checksum_dict(
@@ -373,9 +372,7 @@ def download_custom(self, url_or_urls, custom_download):
         def url_to_downloaded_path(url):
             return os.path.join(cache_dir, hash_url_to_filename(url))
 
-        downloaded_path_or_paths = map_nested(
-            url_to_downloaded_path, url_or_urls, disable_tqdm=not is_progress_bar_enabled()
-        )
+        downloaded_path_or_paths = map_nested(url_to_downloaded_path, url_or_urls)
         url_or_urls = NestedDataStructure(url_or_urls)
         downloaded_path_or_paths = NestedDataStructure(downloaded_path_or_paths)
         for url, path in zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten()):
@@ -426,7 +423,6 @@ def download(self, url_or_urls):
             url_or_urls,
             map_tuple=True,
             num_proc=download_config.num_proc,
-            disable_tqdm=not is_progress_bar_enabled(),
             desc="Downloading data files",
         )
         duration = datetime.now() - start_time
@@ -534,7 +530,6 @@ def extract(self, path_or_paths, num_proc="deprecated"):
             partial(cached_path, download_config=download_config),
             path_or_paths,
             num_proc=download_config.num_proc,
-            disable_tqdm=not is_progress_bar_enabled(),
             desc="Extracting data files",
         )
         path_or_paths = NestedDataStructure(path_or_paths)
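With progress-bar state now global, call sites no longer thread a `disable_tqdm` flag through `map_nested`. For readers unfamiliar with the helper, a toy version of the nested-map idea — not the library's actual implementation, which also supports tuples, multiprocessing, and progress reporting — looks like:

```python
from typing import Any, Callable

def map_nested_simple(fn: Callable[[Any], Any], data: Any) -> Any:
    """Apply fn to every leaf of a nested dict/list structure."""
    if isinstance(data, dict):
        return {k: map_nested_simple(fn, v) for k, v in data.items()}
    if isinstance(data, list):
        return [map_nested_simple(fn, v) for v in data]
    return fn(data)  # leaf value

urls = {"train": ["http://host/1", "http://host/2"], "test": "http://host/3"}
print(map_nested_simple(str.upper, urls))
# {'train': ['HTTP://HOST/1', 'HTTP://HOST/2'], 'test': 'HTTP://HOST/3'}
```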

src/datasets/io/csv.py

Lines changed: 3 additions & 5 deletions
@@ -5,7 +5,7 @@
 from .. import Dataset, Features, NamedSplit, config
 from ..formatting import query_table
 from ..packaged_modules.csv.csv import Csv
-from ..utils import logging
+from ..utils import tqdm as hf_tqdm
 from ..utils.typing import NestedDataStructureLike, PathLike
 from .abc import AbstractDatasetReader
 
@@ -117,10 +117,9 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
         written = 0
 
         if self.num_proc is None or self.num_proc == 1:
-            for offset in logging.tqdm(
+            for offset in hf_tqdm(
                 range(0, len(self.dataset), self.batch_size),
                 unit="ba",
-                disable=not logging.is_progress_bar_enabled(),
                 desc="Creating CSV from Arrow format",
             ):
                 csv_str = self._batch_csv((offset, header, index, to_csv_kwargs))
@@ -129,14 +128,13 @@ def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
-                for csv_str in logging.tqdm(
+                for csv_str in hf_tqdm(
                     pool.imap(
                         self._batch_csv,
                         [(offset, header, index, to_csv_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
-                    disable=not logging.is_progress_bar_enabled(),
                     desc="Creating CSV from Arrow format",
                 ):
                     written += file_obj.write(csv_str)
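The `total=` expression above is a long-hand ceiling division over the batch count; for example:

```python
import math

num_rows, batch_size = 1000, 300
total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
assert total == math.ceil(num_rows / batch_size) == 4  # batches of 300, 300, 300, 100
```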
