Skip to content

Commit 1a94642

Browse files
jeejeelee authored and Akshat-Tripathi committed
[Bugfix] Add file lock for ModelScope download (vllm-project#14060)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent c875893 commit 1a94642

File tree

4 files changed

+40
-22
lines changed

4 files changed

+40
-22
lines changed

benchmarks/backend_request_func.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from transformers import (AutoTokenizer, PreTrainedTokenizer,
1515
PreTrainedTokenizerFast)
1616

17+
from vllm.model_executor.model_loader.weight_utils import get_lock
18+
1719
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
1820

1921

@@ -430,12 +432,15 @@ def get_model(pretrained_model_name_or_path: str) -> str:
430432
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
431433
from modelscope import snapshot_download
432434

433-
model_path = snapshot_download(
434-
model_id=pretrained_model_name_or_path,
435-
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
436-
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
435+
# Use file lock to prevent multiple processes from
436+
# downloading the same model weights at the same time.
437+
with get_lock(pretrained_model_name_or_path):
438+
model_path = snapshot_download(
439+
model_id=pretrained_model_name_or_path,
440+
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
441+
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
437442

438-
return model_path
443+
return model_path
439444
return pretrained_model_name_or_path
440445

441446

vllm/model_executor/model_loader/loader.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from vllm.model_executor.model_loader.weight_utils import (
5050
download_safetensors_index_file_from_hf, download_weights_from_hf,
5151
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
52-
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
52+
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
5353
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
5454
runai_safetensors_weights_iterator, safetensors_weights_iterator)
5555
from vllm.model_executor.utils import set_weight_attrs
@@ -235,13 +235,17 @@ def _maybe_download_from_modelscope(
235235
from modelscope.hub.snapshot_download import snapshot_download
236236

237237
if not os.path.exists(model):
238-
model_path = snapshot_download(
239-
model_id=model,
240-
cache_dir=self.load_config.download_dir,
241-
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
242-
revision=revision,
243-
ignore_file_pattern=self.load_config.ignore_patterns,
244-
)
238+
# Use file lock to prevent multiple processes from
239+
# downloading the same model weights at the same time.
240+
with get_lock(model, self.load_config.download_dir):
241+
model_path = snapshot_download(
242+
model_id=model,
243+
cache_dir=self.load_config.download_dir,
244+
local_files_only=huggingface_hub.constants.
245+
HF_HUB_OFFLINE,
246+
revision=revision,
247+
ignore_file_pattern=self.load_config.ignore_patterns,
248+
)
245249
else:
246250
model_path = model
247251
return model_path

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99
import time
1010
from collections import defaultdict
11+
from pathlib import Path
1112
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
1213

1314
import filelock
@@ -67,8 +68,10 @@ def __init__(self, *args, **kwargs):
6768
super().__init__(*args, **kwargs, disable=True)
6869

6970

70-
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
71+
def get_lock(model_name_or_path: Union[str, Path],
72+
cache_dir: Optional[str] = None):
7173
lock_dir = cache_dir or temp_dir
74+
model_name_or_path = str(model_name_or_path)
7275
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
7376
model_name = model_name_or_path.replace("/", "-")
7477
hash_name = hashlib.sha256(model_name.encode()).hexdigest()

vllm/transformers_utils/tokenizer.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,22 @@ def get_tokenizer(
150150
# pylint: disable=C.
151151
from modelscope.hub.snapshot_download import snapshot_download
152152

153+
# avoid circular import
154+
from vllm.model_executor.model_loader.weight_utils import get_lock
155+
153156
# Only set the tokenizer here, model will be downloaded on the workers.
154157
if not os.path.exists(tokenizer_name):
155-
tokenizer_path = snapshot_download(
156-
model_id=tokenizer_name,
157-
cache_dir=download_dir,
158-
revision=revision,
159-
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
160-
# Ignore weights - we only need the tokenizer.
161-
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
162-
tokenizer_name = tokenizer_path
158+
# Use file lock to prevent multiple processes from
159+
# downloading the same file at the same time.
160+
with get_lock(tokenizer_name, download_dir):
161+
tokenizer_path = snapshot_download(
162+
model_id=tokenizer_name,
163+
cache_dir=download_dir,
164+
revision=revision,
165+
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
166+
# Ignore weights - we only need the tokenizer.
167+
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
168+
tokenizer_name = tokenizer_path
163169

164170
if tokenizer_mode == "slow":
165171
if kwargs.get("use_fast", False):

0 commit comments

Comments (0)