Skip to content

Commit d7618b1

Browse files
committed
[fix] support loading model from local path
Signed-off-by: 齐保元 <[email protected]>
1 parent 86841b6 commit d7618b1

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

vllm_omni/model_executor/models/qwen2_5_omni.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -819,11 +819,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
819819
# Load token2wav weights (if any)
820820
if token2wav_weights and self.token2wav is not None:
821821
# download weights from huggingface for spk_dict.pt
822-
hf_model_folder = download_weights_from_hf_specific(
823-
self.vllm_config.model_config.model,
824-
self.vllm_config.load_config.download_dir,
825-
allow_patterns=["*.pt"],
826-
)
822+
model_path = self.vllm_config.model_config.model
823+
download_dir = self.vllm_config.load_config.download_dir
824+
if os.path.exists(model_path):
825+
hf_model_folder = model_path
826+
else:
827+
hf_model_folder = download_weights_from_hf_specific(
828+
model_path,
829+
download_dir,
830+
allow_patterns=["*.pt"],
831+
)
827832
self._init_token2wav_model(hf_model_folder)
828833
t2w_loaded = self.token2wav.load_weights(
829834
token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt")

0 commit comments

Comments
 (0)