Skip to content
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
ca4efbb
Fix the lm_head in gptbigcode in lora mode
maxdebayser Jul 12, 2024
5d2cd9e
Enforce no quantization for lm_head
maxdebayser Jul 12, 2024
c5769a5
intervene in fight between yapf and isort
maxdebayser Jul 12, 2024
2d172e0
address review comments
maxdebayser Jul 13, 2024
b04d65b
address review comments
maxdebayser Jul 13, 2024
bc0bfcb
address review comments
maxdebayser Jul 13, 2024
73434eb
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 13, 2024
0153ee9
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 15, 2024
ed6d951
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 16, 2024
04abf1b
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 17, 2024
07aa7e4
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 17, 2024
daf566f
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 18, 2024
462deea
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 23, 2024
2f36f3b
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 24, 2024
3fb641b
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 24, 2024
c690724
fix superclass method signature change
maxdebayser Jul 24, 2024
4b5dbd5
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 25, 2024
1a8d282
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 29, 2024
685bd96
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 30, 2024
2d1cc82
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 30, 2024
5c1fe68
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Jul 31, 2024
31b6e93
Merge remote-tracking branch 'remotes/upstream/main' into gpt_bigcode_lora
maxdebayser Aug 1, 2024
5ac5547
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 2, 2024
b05e5c4
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 5, 2024
d11655a
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 6, 2024
7450a1d
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 7, 2024
5f6ac9e
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 10, 2024
b87657d
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 12, 2024
b6867de
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 15, 2024
b4abc03
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser Aug 20, 2024
ef9f79d
fix merge mistake
maxdebayser Aug 20, 2024
4700613
fix merge mistake
maxdebayser Aug 20, 2024
3f4037e
fix merge mistake
maxdebayser Aug 20, 2024
c809ed4
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 26, 2024
33dd909
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Aug 27, 2024
c89dcf0
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser Sep 2, 2024
e31b2b6
add default prefix to get_quant_method
maxdebayser Sep 2, 2024
ea5ea7d
skip loading of lm_head only if tie_word_embeddings is True
maxdebayser Sep 2, 2024
10246be
Make weight tie work with quantization
maxdebayser Sep 2, 2024
9cef81e
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Sep 4, 2024
98a0269
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Sep 10, 2024
fb3eef0
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Sep 24, 2024
6028806
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Sep 26, 2024
db57383
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Oct 8, 2024
77e1966
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Oct 14, 2024
0b81edf
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Oct 15, 2024
b7558fa
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser Nov 12, 2024
7d6f1e9
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Nov 14, 2024
e43d46b
Merge branch 'main' into gpt_bigcode_lora
maxdebayser Feb 28, 2025
08f3360
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Mar 3, 2025
fad78ab
address review comments
maxdebayser Mar 3, 2025
1634ef0
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Mar 13, 2025
86c195e
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Mar 17, 2025
994ec18
Merge branch 'vllm-project:main' into gpt_bigcode_lora
maxdebayser Mar 21, 2025
a6f0308
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser Apr 23, 2025
8ee4962
Simplify the fix
maxdebayser Apr 23, 2025
1395572
Revert other changes
maxdebayser Apr 23, 2025
1d9b323
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser Apr 24, 2025
311a97c
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser May 13, 2025
ac01fee
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser May 15, 2025
e28e496
Merge branch 'upstream_main' into gpt_bigcode_lora
maxdebayser May 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import itertools
from abc import abstractmethod
from typing import Optional
from typing import List, Optional, Type

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -142,6 +142,36 @@ def apply(self,
return F.linear(x, layer.weight, bias)


class TiedWeightLinearMethod(UnquantizedLinearMethod):
    """Unquantized linear method whose ``create_weights`` is a no-op.

    Meant for modules that use weight tying: because nothing is created
    here, such a module can be constructed without initializing a weight
    of its own.
    """

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Accept the usual weight-creation arguments and do nothing.

        None of the arguments are read and no parameter is created or
        registered on ``layer``.
        """
        # Deliberate no-op; attaching a weight is left to the caller.
        return None


class QuantizationConfigOverride(QuantizationConfig):
    """Quantization config that injects one specific linear method.

    Every call to ``get_quant_method`` returns a fresh instance of the
    class supplied at construction time, regardless of the layer or
    prefix it is asked about.
    """

    def __init__(self, cls: Type[LinearMethodBase]):
        """Remember the linear-method class to hand out.

        Args:
            cls: Class that is instantiated (with no arguments) on every
                ``get_quant_method`` call.
        """
        # Run the base initializer so that any state the parent config
        # sets up also exists on this instance.
        super().__init__()
        self.cls = cls

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str = "") -> Optional[LinearMethodBase]:
        """Return a new instance of the injected linear method.

        Args:
            layer: Accepted for interface compatibility; not inspected.
            prefix: Accepted for interface compatibility; not inspected.
                Defaults to the empty string so callers may omit it.

        Returns:
            A freshly constructed ``self.cls()``.
        """
        return self.cls()


# Clear the abstract-method set so the class can be instantiated even though
# it only implements ``get_quant_method``.
# NOTE(review): presumably QuantizationConfig declares further abstract
# methods that this override never needs -- confirm against the base class.
QuantizationConfigOverride.__abstractmethods__ = frozenset()


class LinearBase(torch.nn.Module):
"""Base linear layer.

Expand Down
60 changes: 47 additions & 13 deletions vllm/model_executor/models/gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
QuantizationConfigOverride,
RowParallelLinear,
TiedWeightLinearMethod)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
Expand Down Expand Up @@ -204,9 +209,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding(self.vocab_size,
self.embed_dim,
org_num_embeddings=config.vocab_size)
self.wte = VocabParallelEmbedding(
self.vocab_size,
self.embed_dim,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
Expand Down Expand Up @@ -248,6 +259,7 @@ def forward(
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}

supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
# LoRA specific attributes
embedding_modules = {
"wte": "input_embeddings",
Expand All @@ -266,16 +278,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
prefix=prefix)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)

self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

if self.config.tie_word_embeddings:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=QuantizationConfigOverride(
TiedWeightLinearMethod),
params_dtype=self.transformer.wte.weight.dtype,
)
self.lm_head.register_parameter("weight",
self.transformer.wte.weight)
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)

self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
Expand Down Expand Up @@ -318,7 +352,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "lm_head.weight" in name:
if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue
if ".attn.bias" in name:
# Skip attention mask.
Expand Down