
Commit 8fd7c16

sheikheddy and claude committed
fix: Correct embedding dimension logic in LoRA dummy creation
Fixed incorrect fallback logic for embedding layers where dimensions were reversed.

## Problem

For embedding layers with shape [vocab_size, embedding_dim]:
- input_dim should be vocab_size (shape[0])
- output_dim should be embedding_dim (shape[1])
- embeddings_tensor_dim should be embedding_dim (shape[1])

Previous code had:
- input_dim fallback: shape[1] ❌ (was getting embedding_dim instead of vocab_size)
- output_dim fallback: shape[0] ❌ (was getting vocab_size instead of embedding_dim)
- embeddings_tensor_dim: used input_size instead of output_size ❌

## Fix

Corrected all fallback paths to use the proper dimensions for embedding layers:
- input_dim: shape[0] (vocab_size)
- output_dim: shape[1] (embedding_dim)
- embeddings_tensor_dim: shape[1] (embedding_dim)

Also fixed the elif chain to check output_size instead of input_size for embeddings_tensor_dim.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Signed-off-by: sheikheddy <[email protected]>
1 parent 2a0f94e commit 8fd7c16
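
For reference, the shape convention the commit message relies on can be checked directly against torch.nn.Embedding, whose weight is stored as [vocab_size, embedding_dim]. The snippet below is a minimal sketch with arbitrary sizes, not part of the commit:

import torch

# torch.nn.Embedding stores its weight as [vocab_size, embedding_dim],
# so shape[0] is the vocab size and shape[1] is the embedding dim.
emb = torch.nn.Embedding(num_embeddings=32000, embedding_dim=4096)

vocab_size, embedding_dim = emb.weight.shape
assert vocab_size == 32000    # -> input_dim fallback: weight.shape[0]
assert embedding_dim == 4096  # -> output_dim / embeddings_tensor_dim fallback: weight.shape[1]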

File tree

1 file changed (+12 -6 lines changed)


vllm/lora/models.py

Lines changed: 12 additions & 6 deletions
@@ -624,28 +624,34 @@ def create_dummy_lora(
                     input_dim = module.base_layer.input_size
                 elif hasattr(module.base_layer, "weight_shape"):
                     # Compressed tensors: weight_shape stores [output, input]
-                    input_dim = module.base_layer.weight_shape[1].item()
+                    # For embeddings: [vocab_size, embedding_dim]
+                    input_dim = module.base_layer.weight_shape[0].item()
                 else:
-                    input_dim = module.weight.shape[1]
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
+                    input_dim = module.weight.shape[0]
 
                 if hasattr(module.base_layer, "embedding_dim"):
                     output_dim = module.base_layer.embedding_dim
                 elif hasattr(module.base_layer, "output_size"):
                     output_dim = module.base_layer.output_size
                 elif hasattr(module.base_layer, "weight_shape"):
                     # Compressed tensors: weight_shape stores [output, input]
-                    output_dim = module.base_layer.weight_shape[0].item()
+                    # For embeddings: [vocab_size, embedding_dim]
+                    output_dim = module.base_layer.weight_shape[1].item()
                 else:
-                    output_dim = module.weight.shape[0]
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
+                    output_dim = module.weight.shape[1]
 
                 if hasattr(module.base_layer, "embedding_dim"):
                     embeddings_tensor_dim = module.base_layer.embedding_dim
-                elif hasattr(module.base_layer, "input_size"):
-                    embeddings_tensor_dim = module.base_layer.input_size
+                elif hasattr(module.base_layer, "output_size"):
+                    embeddings_tensor_dim = module.base_layer.output_size
                 elif hasattr(module.base_layer, "weight_shape"):
                     # Compressed tensors: weight_shape stores [output, input]
+                    # For embeddings: [vocab_size, embedding_dim]
                     embeddings_tensor_dim = module.base_layer.weight_shape[1].item()
                 else:
+                    # For embeddings: weight.shape = [vocab_size, embedding_dim]
                     embeddings_tensor_dim = module.weight.shape[1]
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name,
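
For illustration, the patched fallback order can be condensed into a standalone helper. This is an approximation of the logic above, not the actual vLLM code path; resolve_embedding_dims and the dummy module are made-up names, and the dummy exposes only a weight tensor so every value comes from the final shape-based fallback:

from types import SimpleNamespace

import torch


def resolve_embedding_dims(module):
    """Condensed sketch of the corrected fallback chains for embedding layers."""
    base = module.base_layer

    # input_dim: vocab_size, i.e. weight.shape[0] in the shape-based fallbacks.
    if hasattr(base, "input_size"):
        input_dim = base.input_size
    elif hasattr(base, "weight_shape"):
        input_dim = base.weight_shape[0].item()
    else:
        input_dim = module.weight.shape[0]

    # output_dim and embeddings_tensor_dim: embedding_dim, i.e. weight.shape[1].
    if hasattr(base, "embedding_dim"):
        output_dim = embeddings_tensor_dim = base.embedding_dim
    elif hasattr(base, "output_size"):
        output_dim = embeddings_tensor_dim = base.output_size
    elif hasattr(base, "weight_shape"):
        output_dim = embeddings_tensor_dim = base.weight_shape[1].item()
    else:
        output_dim = embeddings_tensor_dim = module.weight.shape[1]

    return input_dim, output_dim, embeddings_tensor_dim


# A bare module with no size attributes forces the weight.shape fallbacks.
dummy = SimpleNamespace(
    base_layer=SimpleNamespace(),
    weight=torch.empty(1000, 64),  # [vocab_size, embedding_dim]
)
assert resolve_embedding_dims(dummy) == (1000, 64, 64)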
