|
20 | 20 | import os.path as osp |
21 | 21 | import sys |
22 | 22 | import yaml |
| 23 | +import time |
23 | 24 | import shutil |
24 | 25 | import requests |
25 | 26 | import tqdm |
|
29 | 30 | import tarfile |
30 | 31 | import zipfile |
31 | 32 |
|
| 33 | +from paddle.utils.download import _get_unique_endpoints |
32 | 34 | from ppdet.core.workspace import BASE_KEY |
33 | 35 | from .logger import setup_logger |
34 | 36 | from .voc_utils import create_list |
@@ -147,8 +149,8 @@ def get_config_path(url): |
147 | 149 | cfg_url = parse_url(cfg_url) |
148 | 150 |
|
149 | 151 | # 3. download and decompress |
150 | | - cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME)) |
151 | | - _decompress(cfg_fullname) |
| 152 | + cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME)) |
| 153 | + _decompress_dist(cfg_fullname) |
152 | 154 |
|
153 | 155 | # 4. check config file existing |
154 | 156 | if os.path.isfile(path): |
@@ -284,12 +286,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): |
284 | 286 | else: |
285 | 287 | os.remove(fullpath) |
286 | 288 |
|
287 | | - fullname = _download(url, root_dir, md5sum) |
| 289 | + fullname = _download_dist(url, root_dir, md5sum) |
288 | 290 |
|
289 | 291 | # new weights format which postfix is 'pdparams' not |
290 | 292 | # need to decompress |
291 | 293 | if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']: |
292 | | - _decompress(fullname) |
| 294 | + _decompress_dist(fullname) |
293 | 295 |
|
294 | 296 | return fullpath, False |
295 | 297 |
|
@@ -384,6 +386,38 @@ def _download(url, path, md5sum=None): |
384 | 386 | return fullname |
385 | 387 |
|
386 | 388 |
|
def _download_dist(url, path, md5sum=None):
    """
    Download ``url`` into ``path``, coordinating between trainers when
    several of them run at the same time.

    When the ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID`` environment
    variables are not both set, or only one trainer is configured, this
    simply delegates to ``_download``. Otherwise only the trainers whose
    endpoint is listed by ``_get_unique_endpoints`` download the file; the
    other trainers poll a ``<fullname>.download.lock`` file until it is
    removed.

    Args:
        url (str): address of the file to download.
        path (str): directory the file is stored in.
        md5sum (str|None): forwarded unchanged to ``_download``.

    Returns:
        The value returned by ``_download`` in the single-trainer case,
        otherwise ``osp.join(path, <last component of url>)``.
    """
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        return _download(url, path, md5sum)

    if int(env['PADDLE_TRAINERS_NUM']) <= 1:
        return _download(url, path, md5sum)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    lock_path = fullname + '.download.lock'

    try:
        os.makedirs(path)
    except OSError:
        # Several trainers may reach this point together; losing the race
        # to create the directory is fine as long as it exists afterwards.
        if not osp.isdir(path):
            raise

    if not osp.exists(fullname):
        from paddle.distributed import ParallelEnv
        parallel_env = ParallelEnv()
        unique_endpoints = _get_unique_endpoints(
            parallel_env.trainer_endpoints[:])

        # Every trainer touches the lock so that a waiting trainer always
        # finds it in place before it starts polling.
        # NOTE(review): a trainer that gets here only after the downloader
        # has already removed the lock re-creates it and then waits
        # forever -- this protocol does not guard against that ordering.
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)

        if parallel_env.current_endpoint in unique_endpoints:
            try:
                _download(url, path, md5sum)
            finally:
                # Release the waiting trainers even when the download
                # fails, otherwise they would poll the lock forever.
                try:
                    os.remove(lock_path)
                except OSError:
                    # Another downloading endpoint that shares this
                    # directory may already have removed the lock.
                    if osp.exists(lock_path):
                        raise
        else:
            while osp.exists(lock_path):
                time.sleep(1)
    return fullname
| 419 | + |
| 420 | + |
387 | 421 | def _check_exist_file_md5(filename, md5sum, url): |
388 | 422 | # if md5sum is None, and file to check is weights file, |
389 | 423 | # read md5um from url and check, else check md5sum directly |
@@ -461,6 +495,30 @@ def _decompress(fname): |
461 | 495 | os.remove(fname) |
462 | 496 |
|
463 | 497 |
|
def _decompress_dist(fname):
    """
    Decompress ``fname``, coordinating between trainers when several of
    them run at the same time.

    When the ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID`` environment
    variables are not both set, or only one trainer is configured, this
    simply delegates to ``_decompress``. Otherwise only the trainers whose
    endpoint is listed by ``_get_unique_endpoints`` decompress the file;
    the other trainers poll a ``<fname>.decompress.lock`` file until it is
    removed.

    Args:
        fname (str): path of the archive, forwarded to ``_decompress``.
    """
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        _decompress(fname)
        return

    if int(env['PADDLE_TRAINERS_NUM']) <= 1:
        _decompress(fname)
        return

    lock_path = fname + '.decompress.lock'
    from paddle.distributed import ParallelEnv
    parallel_env = ParallelEnv()
    unique_endpoints = _get_unique_endpoints(
        parallel_env.trainer_endpoints[:])

    # Every trainer touches the lock so that a waiting trainer always
    # finds it in place before it starts polling.
    # NOTE(review): a trainer that gets here only after the lock has
    # already been removed re-creates it and then waits forever -- this
    # protocol does not guard against that ordering.
    with open(lock_path, 'w'):  # touch
        os.utime(lock_path, None)

    if parallel_env.current_endpoint in unique_endpoints:
        try:
            _decompress(fname)
        finally:
            # Release the waiting trainers even when decompression fails,
            # otherwise they would poll the lock forever.
            try:
                os.remove(lock_path)
            except OSError:
                # Another decompressing endpoint that shares this
                # directory may already have removed the lock.
                if os.path.exists(lock_path):
                    raise
    else:
        while os.path.exists(lock_path):
            time.sleep(1)
| 520 | + |
| 521 | + |
464 | 522 | def _move_and_merge_tree(src, dst): |
465 | 523 | """ |
466 | 524 | Move src directory to dst, if dst is already exists, |
|
0 commit comments