Merged · Changes from 3 commits
4 changes: 4 additions & 0 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -72,6 +72,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         # hack: ggufs have a different name than transformers
         if model_type == "cohere":
             model_type = "command-r"
+        if model_type == "gemma3_text":
+            # Gemma3 models use "gemma3_text" in HuggingFace but
+            # "gemma3" in GGUF architecture naming
+            model_type = "gemma3"
         if model_type in ("deepseek_v3", "deepseek_v2"):
             model_type = "deepseek2"
         # GGUF layer map assumes that we will have a merged expert weights
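For context, here is a minimal sketch of the HF-to-GGUF architecture-name remapping that the hunk above extends. The dictionary and helper name below are hypothetical stand-ins for illustration; the actual loader uses the chain of `if` statements shown in the diff.

```python
# Hypothetical stand-in (not vLLM code) for the remapping chain in
# _get_gguf_weights_map: HF model_type values that differ from GGUF arch names.
HF_TO_GGUF_ARCH = {
    "cohere": "command-r",
    "gemma3_text": "gemma3",  # HF text-only Gemma3 config vs. GGUF "gemma3" arch
    "deepseek_v2": "deepseek2",
    "deepseek_v3": "deepseek2",
}


def remap_model_type(model_type: str) -> str:
    """Return the GGUF architecture name for a HuggingFace model_type."""
    return HF_TO_GGUF_ARCH.get(model_type, model_type)


assert remap_model_type("gemma3_text") == "gemma3"
assert remap_model_type("llama") == "llama"  # unmapped types pass through unchanged
```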
37 changes: 36 additions & 1 deletion vllm/model_executor/models/gemma3.py
@@ -44,6 +44,7 @@
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
@@ -442,6 +443,20 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Apply GGUF-specific RMSNorm weight correction for Gemma3
+            # This must happen BEFORE any transformations (transpose, etc.)
+            # GemmaRMSNorm computes: output = x * (1 + weight)
+            # GGUF stores full weight values (for standard x * weight)
+            # but vLLM's GemmaRMSNorm expects (weight - 1) since it adds 1
+            # during the forward pass.
+            if (
+                self.quant_config is not None
+                and self.quant_config.get_name() == "gguf"
+                and "norm" in name
+                and len(loaded_weight.shape) == 1
+            ):
+                loaded_weight = loaded_weight - 1.0
+
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
             ):
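To make the `weight - 1.0` correction above concrete, here is a small numerical check, assuming (as the diff comments state) that GGUF stores the full RMSNorm scale while GemmaRMSNorm multiplies the normalized input by `1 + weight`. Tensor names are illustrative only.

```python
import torch

# Sketch: scaling with the GGUF-stored gamma directly is equivalent to
# GemmaRMSNorm's (1 + weight) form once gamma - 1.0 is loaded as the weight.
x_norm = torch.randn(8)      # stand-in for RMS-normalized activations
gamma_gguf = torch.randn(8)  # full-scale norm weight as stored in a GGUF file

reference = x_norm * gamma_gguf                  # standard RMSNorm scaling
corrected = x_norm * (1.0 + (gamma_gguf - 1.0))  # GemmaRMSNorm with corrected weight

assert torch.allclose(reference, corrected)
```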
@@ -485,6 +500,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip GGUF qweight_type metadata for layers that don't have it
+                # (e.g., embedding layers). These are handled by GGUF
+                # quantization layers.
+                if name.endswith(".qweight_type") and name not in params_dict:
+                    continue
+
+                # Handle GGUF qweight for embedding and other non-merged layers
+                # GGUF uses .qweight for quantized weights, but some layers
+                # (like VocabParallelEmbedding) expect .weight
+                if name.endswith(".qweight") and name not in params_dict:
+                    # Try to load as regular weight instead
+                    name = name.replace(".qweight", ".weight")
+                    if name not in params_dict:
+                        continue
+
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
                 if name is None:
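The two new skips above amount to a small name-resolution rule for GGUF tensors. The helper below is a hypothetical sketch (not vLLM code) showing how `.qweight` / `.qweight_type` names are reconciled with parameters that only register a plain `.weight`, such as the embedding.

```python
from typing import Optional


def resolve_gguf_param_name(name: str, param_names: set[str]) -> Optional[str]:
    """Sketch: return the parameter name to load into, or None to skip."""
    if name.endswith(".qweight_type") and name not in param_names:
        return None  # quantization metadata with no matching parameter
    if name.endswith(".qweight") and name not in param_names:
        name = name.replace(".qweight", ".weight")  # fall back to plain weight
        if name not in param_names:
            return None
    return name


params = {
    "model.embed_tokens.weight",                  # embedding registers .weight
    "model.layers.0.self_attn.qkv_proj.qweight",  # GGUF-quantized linear layer
}
assert resolve_gguf_param_name("model.embed_tokens.qweight", params) == "model.embed_tokens.weight"
assert resolve_gguf_param_name("model.embed_tokens.qweight_type", params) is None
assert (
    resolve_gguf_param_name("model.layers.0.self_attn.qkv_proj.qweight", params)
    == "model.layers.0.self_attn.qkv_proj.qweight"
)
```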
@@ -519,6 +549,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        # Store model config for quantization access
+        self.model_config = vllm_config.model_config
         # currently all existing Gemma models have `tie_word_embeddings` enabled
         assert config.tie_word_embeddings
         self.quant_config = quant_config
@@ -551,8 +583,11 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
+        logits = self.logits_processor(
+            self.model.embed_tokens, hidden_states, sampling_metadata
+        )
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
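Taken together, a hedged end-to-end usage sketch of what these changes enable: loading a Gemma3 GGUF checkpoint through vLLM's offline `LLM` API. The file path and tokenizer repo are placeholders; substitute your own GGUF file and the matching HuggingFace tokenizer.

```python
from vllm import LLM, SamplingParams

# Placeholder path/repo: point these at your own Gemma3 GGUF checkpoint and
# the HuggingFace tokenizer that matches it.
llm = LLM(
    model="/models/gemma-3-4b-it-Q4_K_M.gguf",
    tokenizer="google/gemma-3-4b-it",
)

outputs = llm.generate(
    ["Why is the sky blue?"],
    SamplingParams(temperature=0.7, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```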