From e78542f97feeb95f3b0643338c05fe8455c13def Mon Sep 17 00:00:00 2001
From: Jacob Schein
Date: Sun, 4 Aug 2024 12:55:06 -0700
Subject: [PATCH 1/2] fix: Specify device when loading LoRA and embedding
 tensors

- Add map_location="device" when loading LoRA tensors from .bin files
- Add map_location="device" when loading new embeddings from .bin files
---
 vllm/lora/models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 017a1002bb9a..2ce46204ab9f 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -248,7 +248,7 @@ def from_local_checkpoint(
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path)
+            tensors = torch.load(lora_bin_file_path, map_location="device")
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
 
@@ -257,7 +257,8 @@ def from_local_checkpoint(
             embeddings = safetensors.torch.load_file(
                 new_embeddings_tensor_path)
         elif os.path.isfile(new_embeddings_bin_file_path):
-            embeddings = torch.load(new_embeddings_bin_file_path)
+            embeddings = torch.load(new_embeddings_bin_file_path,
+                                    map_location="device")
 
         rank = config["r"]
         lora_alpha = config["lora_alpha"]

From 274c3a23c3d46b1d246c5dc4ab29d71f8f9ebbcd Mon Sep 17 00:00:00 2001
From: jischein
Date: Mon, 5 Aug 2024 09:50:36 -0700
Subject: [PATCH 2/2] Use device variable

---
 vllm/lora/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 2ce46204ab9f..279477562a94 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -248,7 +248,7 @@ def from_local_checkpoint(
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path, map_location="device")
+            tensors = torch.load(lora_bin_file_path, map_location=device)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
 
@@ -258,7 +258,7 @@ def from_local_checkpoint(
                 new_embeddings_tensor_path)
         elif os.path.isfile(new_embeddings_bin_file_path):
             embeddings = torch.load(new_embeddings_bin_file_path,
-                                    map_location="device")
+                                    map_location=device)
 
         rank = config["r"]
         lora_alpha = config["lora_alpha"]
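
For context (not part of either patch): a minimal, self-contained sketch of the behaviour the series addresses. The checkpoint path, tensor key, and device string below are hypothetical stand-ins for the names used in `from_local_checkpoint`; the point is only that `torch.load` expects a real device (the `device` variable that PATCH 2/2 passes), not the literal string `"device"` introduced in PATCH 1/2.

```python
import torch

# Hypothetical stand-ins for the names used in from_local_checkpoint().
lora_bin_file_path = "adapter_model.bin"
device = "cpu"  # in vLLM this would be the worker's target device, e.g. "cuda:0"

# Create a tiny fake LoRA checkpoint so the sketch runs on its own.
torch.save({"lora_A.weight": torch.randn(4, 4)}, lora_bin_file_path)

# Without map_location, torch.load restores each tensor onto the device it was
# saved from; with map_location it lands on the device we ask for. Passing the
# literal string "device" (as in PATCH 1/2) is not a valid device identifier,
# which is what PATCH 2/2 corrects by passing the `device` variable instead.
tensors = torch.load(lora_bin_file_path, map_location=device)
print(tensors["lora_A.weight"].device)  # -> cpu
```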