Skip to content

Commit 6e8607b

Browse files
authored
unify dist download (PaddlePaddle#3866)
1 parent 63ead1e commit 6e8607b

File tree

2 files changed

+65
-42
lines changed

2 files changed

+65
-42
lines changed

ppdet/utils/checkpoint.py

Lines changed: 3 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints):
5555
return unique_endpoints
5656

5757

58-
def get_weights_path_dist(path):
59-
env = os.environ
60-
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
61-
trainer_id = int(env['PADDLE_TRAINER_ID'])
62-
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
63-
if num_trainers <= 1:
64-
path = get_weights_path(path)
65-
else:
66-
from ppdet.utils.download import map_path, WEIGHTS_HOME
67-
weight_path = map_path(path, WEIGHTS_HOME)
68-
lock_path = weight_path + '.lock'
69-
if not os.path.exists(weight_path):
70-
from paddle.distributed import ParallelEnv
71-
unique_endpoints = _get_unique_endpoints(ParallelEnv()
72-
.trainer_endpoints[:])
73-
try:
74-
os.makedirs(os.path.dirname(weight_path))
75-
except OSError as e:
76-
if e.errno != errno.EEXIST:
77-
raise
78-
with open(lock_path, 'w'): # touch
79-
os.utime(lock_path, None)
80-
if ParallelEnv().current_endpoint in unique_endpoints:
81-
get_weights_path(path)
82-
os.remove(lock_path)
83-
else:
84-
while os.path.exists(lock_path):
85-
time.sleep(1)
86-
path = weight_path
87-
else:
88-
path = get_weights_path(path)
89-
90-
return path
91-
92-
9358
def _strip_postfix(path):
9459
path, ext = os.path.splitext(path)
9560
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
@@ -99,7 +64,7 @@ def _strip_postfix(path):
9964

10065
def load_weight(model, weight, optimizer=None):
10166
if is_url(weight):
102-
weight = get_weights_path_dist(weight)
67+
weight = get_weights_path(weight)
10368

10469
path = _strip_postfix(weight)
10570
pdparam_path = path + '.pdparams'
@@ -205,7 +170,7 @@ def match(a, b):
205170

206171
def load_pretrain_weight(model, pretrain_weight):
207172
if is_url(pretrain_weight):
208-
pretrain_weight = get_weights_path_dist(pretrain_weight)
173+
pretrain_weight = get_weights_path(pretrain_weight)
209174

210175
path = _strip_postfix(pretrain_weight)
211176
if not (os.path.isdir(path) or os.path.isfile(path) or
@@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
251216
state_dict = optimizer.state_dict()
252217
state_dict['last_epoch'] = last_epoch
253218
paddle.save(state_dict, save_path + ".pdopt")
254-
logger.info("Save checkpoint: {}".format(save_dir))
219+
logger.info("Save checkpoint: {}".format(save_dir))

ppdet/utils/download.py

Lines changed: 62 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import os.path as osp
2121
import sys
2222
import yaml
23+
import time
2324
import shutil
2425
import requests
2526
import tqdm
@@ -29,6 +30,7 @@
2930
import tarfile
3031
import zipfile
3132

33+
from paddle.utils.download import _get_unique_endpoints
3234
from ppdet.core.workspace import BASE_KEY
3335
from .logger import setup_logger
3436
from .voc_utils import create_list
@@ -147,8 +149,8 @@ def get_config_path(url):
147149
cfg_url = parse_url(cfg_url)
148150

149151
# 3. download and decompress
150-
cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME))
151-
_decompress(cfg_fullname)
152+
cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME))
153+
_decompress_dist(cfg_fullname)
152154

153155
# 4. check config file existing
154156
if os.path.isfile(path):
@@ -284,12 +286,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
284286
else:
285287
os.remove(fullpath)
286288

287-
fullname = _download(url, root_dir, md5sum)
289+
fullname = _download_dist(url, root_dir, md5sum)
288290

289291
# new weights format which postfix is 'pdparams' not
290292
# need to decompress
291293
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
292-
_decompress(fullname)
294+
_decompress_dist(fullname)
293295

294296
return fullpath, False
295297

@@ -384,6 +386,38 @@ def _download(url, path, md5sum=None):
384386
return fullname
385387

386388

389+
def _download_dist(url, path, md5sum=None):
    """
    Download url into path, coordinating trainers of a distributed job.

    When 'PADDLE_TRAINERS_NUM' and 'PADDLE_TRAINER_ID' are not both set in
    the environment, or the job has at most one trainer, this simply
    delegates to _download. Otherwise only the trainers whose endpoint is
    returned by _get_unique_endpoints perform the download (presumably one
    per machine -- confirm against paddle.utils.download); every other
    trainer polls a '<file>.download.lock' marker until it disappears.

    Args:
        url (str): address of the file to fetch; its last path component
            is used as the local file name.
        path (str): directory the file is stored in, created if missing.
        md5sum (str): optional checksum forwarded to _download.

    Returns:
        str: the value returned by _download in the non-distributed case,
            otherwise osp.join(path, <last component of url>).
    """
    env = os.environ
    distributed = 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env
    if not distributed or int(env['PADDLE_TRAINERS_NUM']) <= 1:
        return _download(url, path, md5sum)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    lock_path = fullname + '.download.lock'

    # Several trainers can reach this point simultaneously, so another
    # process may create the directory between a check and the makedirs
    # call; only re-raise when the directory really is missing afterwards.
    try:
        os.makedirs(path)
    except OSError:
        if not osp.isdir(path):
            raise

    if osp.exists(fullname):
        return fullname

    from paddle.distributed import ParallelEnv
    parallel_env = ParallelEnv()
    unique_endpoints = _get_unique_endpoints(
        parallel_env.trainer_endpoints[:])

    with open(lock_path, 'w'):  # touch
        os.utime(lock_path, None)

    if parallel_env.current_endpoint in unique_endpoints:
        # Always release the lock, even when the download fails, so the
        # trainers waiting below are not left polling forever.
        try:
            _download(url, path, md5sum)
        finally:
            os.remove(lock_path)
    else:
        # NOTE(review): a waiter that touches the lock after the downloader
        # has already removed it would still block here -- confirm whether
        # a stronger handshake is needed.
        while os.path.exists(lock_path):
            time.sleep(1)
    return fullname
419+
420+
387421
def _check_exist_file_md5(filename, md5sum, url):
388422
# if md5sum is None, and file to check is weights file,
389423
# read md5sum from url and check, else check md5sum directly
@@ -461,6 +495,30 @@ def _decompress(fname):
461495
os.remove(fname)
462496

463497

498+
def _decompress_dist(fname):
    """
    Decompress fname, coordinating trainers of a distributed job.

    When 'PADDLE_TRAINERS_NUM' and 'PADDLE_TRAINER_ID' are not both set in
    the environment, or the job has at most one trainer, this simply
    delegates to _decompress. Otherwise only the trainers whose endpoint is
    returned by _get_unique_endpoints decompress the archive (presumably
    one per machine -- confirm against paddle.utils.download); every other
    trainer polls a '<fname>.decompress.lock' marker until it disappears.

    Args:
        fname (str): path of the archive to decompress.
    """
    env = os.environ
    distributed = 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env
    if not distributed or int(env['PADDLE_TRAINERS_NUM']) <= 1:
        _decompress(fname)
        return

    lock_path = fname + '.decompress.lock'

    from paddle.distributed import ParallelEnv
    parallel_env = ParallelEnv()
    unique_endpoints = _get_unique_endpoints(
        parallel_env.trainer_endpoints[:])

    with open(lock_path, 'w'):  # touch
        os.utime(lock_path, None)

    if parallel_env.current_endpoint in unique_endpoints:
        # Always release the lock, even when decompression fails, so the
        # trainers waiting below are not left polling forever.
        try:
            _decompress(fname)
        finally:
            os.remove(lock_path)
    else:
        # NOTE(review): a waiter that touches the lock after the worker has
        # already removed it would still block here -- confirm whether a
        # stronger handshake is needed.
        while os.path.exists(lock_path):
            time.sleep(1)
520+
521+
464522
def _move_and_merge_tree(src, dst):
465523
"""
466524
Move src directory to dst, if dst is already exists,

0 commit comments

Comments
 (0)