diff --git a/setup.py b/setup.py index 0f46586deec..d49fddde177 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ def find_version(*file_paths): 'pillow >= 4.1.1', 'six', 'torch', + 'tqdm' ] setup( diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 9fa3b0b8c9b..4675013ac33 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -2,6 +2,15 @@ import os.path import hashlib import errno +from tqdm import tqdm + + +def gen_bar_updator(pbar): + def bar_update(count, block_size, total_size): + pbar.total = total_size / block_size + pbar.update(count) + + return bar_update def check_integrity(fpath, md5): @@ -38,7 +47,7 @@ def download_url(url, root, filename, md5): else: try: print('Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updator(tqdm())) except: if url[:5] == 'https': url = url.replace('https:', 'http:')