
Commit be6c09b

Add Gemma3 GGUF multimodal support
- Enable GGUF multimodal detection in ModelConfig
- Implement GGUF loader for vision tower and projector weights
- Add processor support for GGUF Gemma3 models
- Support multimodal embedding gathering in V1 engine

Signed-off-by: Luciano Martins <[email protected]>
1 parent 1994de9 commit be6c09b
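
For context, a hedged sketch of how the feature could be exercised from vLLM's offline API once this lands. The model path, image file, and prompt are placeholders (Gemma3 normally expects its image placeholder token in the prompt, and GGUF runs often pass an explicit tokenizer=), so treat this as the shape of the call rather than a recipe.

# Sketch only: paths, prompt, and image handling are illustrative, not taken from the commit.
from PIL import Image

from vllm import LLM, SamplingParams

# The mmproj*.gguf file is expected to sit next to the text-model GGUF.
llm = LLM(model="/models/gemma-3-4b-it/gemma-3-4b-it-Q4_K_M.gguf")

outputs = llm.generate(
    {
        "prompt": "Describe this image.",
        "multi_modal_data": {"image": Image.open("example.jpg")},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
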

File tree: 11 files changed, +703 -66 lines

vllm/config/model.py

Lines changed: 87 additions & 3 deletions
@@ -68,6 +68,36 @@
 
 logger = init_logger(__name__)
 
+
+def _detect_gguf_multimodal_gemma3(model: str) -> bool:
+    """Check if GGUF model has multimodal projector file for Gemma3.
+
+    Args:
+        model: Model path string
+
+    Returns:
+        True if this is a Gemma3 GGUF model with mmproj file, False otherwise
+    """
+    if not model.endswith(".gguf"):
+        return False
+
+    try:
+        from pathlib import Path
+
+        model_path = Path(model)
+        if not model_path.is_file():
+            return False
+
+        model_dir = model_path.parent
+        mmproj_patterns = ["mmproj.gguf", "mmproj-*.gguf", "*mmproj*.gguf"]
+        for pattern in mmproj_patterns:
+            if list(model_dir.glob(pattern)):
+                return True
+        return False
+    except Exception:
+        return False
+
+
 RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
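
For orientation, a minimal standalone sketch of the directory layout this helper looks for, plus an equivalent condensed check; the file names are hypothetical.

# Expected layout (hypothetical paths): the text-model GGUF next to an mmproj GGUF.
#
#   /models/gemma-3-4b-it/
#       gemma-3-4b-it-Q4_K_M.gguf   <- passed as the model path
#       mmproj-F16.gguf             <- matched by the "*mmproj*.gguf" glob
from pathlib import Path


def has_mmproj(model: str) -> bool:
    """Condensed version of the check done by _detect_gguf_multimodal_gemma3."""
    path = Path(model)
    if not (model.endswith(".gguf") and path.is_file()):
        return False
    # The broadest pattern, "*mmproj*.gguf", subsumes the other two in the diff.
    return bool(list(path.parent.glob("*mmproj*.gguf")))


print(has_mmproj("/models/gemma-3-4b-it/gemma-3-4b-it-Q4_K_M.gguf"))
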
@@ -556,6 +586,46 @@ def __post_init__(
 
         architectures = self.architectures
         registry = self.registry
+
+        # GGUF multimodal: Force Gemma3ForConditionalGeneration architecture
+        # when mmproj file is present, before model resolution
+        if _detect_gguf_multimodal_gemma3(self.model):
+            is_gemma3 = any("gemma3" in str(arch).lower() for arch in architectures)
+            if is_gemma3:
+                architectures = ["Gemma3ForConditionalGeneration"]
+                self.hf_config.architectures = architectures
+                logger.info(
+                    "Detected Gemma3 GGUF with mmproj.gguf, "
+                    "forcing Gemma3ForConditionalGeneration"
+                )
+
+                # Initialize vision_config if not present
+                if (
+                    not hasattr(self.hf_config, "vision_config")
+                    or self.hf_config.vision_config is None
+                ):
+                    from transformers import SiglipVisionConfig
+
+                    self.hf_config.vision_config = SiglipVisionConfig(
+                        hidden_size=1152,
+                        intermediate_size=4304,
+                        num_hidden_layers=27,
+                        num_attention_heads=16,
+                        num_channels=3,
+                        image_size=896,
+                        patch_size=14,
+                        layer_norm_eps=1e-6,
+                        attention_dropout=0.0,
+                        num_image_tokens=256,
+                        # Disable pooling head for Gemma3
+                        vision_use_head=False,
+                    )
+                self.hf_config.mm_tokens_per_image = 256
+                self.hf_config.image_token_index = 262144
+                # DO NOT set boi_token_index - let
+                # gemma3_mm.py fall back to 262143
+                self.hf_config.eoi_token_index = 256000
+
         is_generative_model = registry.is_text_generation_model(architectures, self)
         is_pooling_model = registry.is_pooling_model(architectures, self)
 
@@ -701,8 +771,24 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
+
+        # GGUF multimodal: Set flag to initialize multimodal_config
+        # when Gemma3 mmproj file is present
+        is_gguf_multimodal = False
+        if _detect_gguf_multimodal_gemma3(self.model):
+            is_gemma3 = any(
+                "gemma3" in str(arch).lower() for arch in self.architectures
+            )
+            if is_gemma3:
+                is_gguf_multimodal = True
+                logger.info(
+                    "Detected Gemma3 GGUF multimodal model "
+                    "with mmproj.gguf, initializing "
+                    "multimodal_config"
+                )
+
         # Init multimodal config if needed
-        if self._model_info.supports_multimodal:
+        if self._model_info.supports_multimodal or is_gguf_multimodal:
             if (
                 mm_encoder_tp_mode == "data"
                 and not self._model_info.supports_multimodal_encoder_tp_data
@@ -888,8 +974,6 @@ def _get_default_runner_type(
                 _, (runner_type, _) = match
                 return runner_type
 
-        return "generate"
-
     def _get_runner_type(
         self,
         architectures: list[str],

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 44 additions & 12 deletions
@@ -36,9 +36,12 @@
 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
-    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
+    def __init__(
+        self, unquantized_modules: list[str] | None = None, model_arch: str = ""
+    ) -> None:
         super().__init__()
         self.unquantized_modules = unquantized_modules or []
+        self.model_arch = model_arch
 
     def __repr__(self) -> str:
         return "GGUFConfig()"
@@ -59,17 +62,19 @@ def get_config_filenames(cls) -> list[str]:
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
-        return cls()
+        # Extract model_arch from config if available
+        model_arch = config.get("model_arch", "")
+        return cls(model_arch=model_arch)
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                 return UnquantizedLinearMethod()
-            return GGUFLinearMethod(self)
+            return GGUFLinearMethod(self, self.model_arch)
         elif isinstance(layer, VocabParallelEmbedding):
-            return GGUFEmbeddingMethod(self)
+            return GGUFEmbeddingMethod(self, self.model_arch)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None
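
A small illustrative sketch of how model_arch flows from the loader config into the quantization methods; the "gemma3" value is assumed for illustration, not taken from the commit.

from vllm.model_executor.layers.quantization.gguf import GGUFConfig

# from_config() now forwards "model_arch" into the config object ...
cfg = GGUFConfig.from_config({"model_arch": "gemma3"})
assert cfg.model_arch == "gemma3"
# ... and get_quant_method() passes it on to GGUFLinearMethod / GGUFEmbeddingMethod,
# whose case-insensitive "gemma3" substring check enables the dtype handling below.
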
@@ -115,7 +120,10 @@ def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
 
 
 def _fused_mul_mat_gguf(
-    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    target_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in IMATRIX_QUANT_TYPES:
         mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
@@ -125,9 +133,15 @@
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
         return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
-    # there is no need to call any kernel for fp16/bf16
+    # Unquantized weights (F16/BF16) use direct matrix multiplication
    if qweight_type in UNQUANTIZED_TYPES:
-        return x @ qweight.T
+        # GGUF stores some weights as F16, but the model may use bfloat16 for
+        # computation. Convert weight dtype to match target_dtype (typically
+        # params_dtype=bfloat16) to ensure consistency in mixed-precision.
+        weight = qweight
+        if target_dtype is not None and weight.dtype != target_dtype:
+            weight = weight.to(target_dtype)
+        return x @ weight.T
     # enable MMVQ in contiguous batching with batch_size=1
     if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
@@ -138,7 +152,9 @@
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        # Use target_dtype if provided, otherwise fall back to x.dtype
+        dequant_dtype = target_dtype if target_dtype is not None else x.dtype
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, dequant_dtype)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
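
The motivation for the cast on the unquantized path, as a standalone sketch (shapes are arbitrary): PyTorch matrix multiplication requires both operands to share a dtype, so a bfloat16 activation cannot be multiplied directly against an F16 GGUF weight.

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)  # activations in the model's compute dtype
w = torch.randn(4, 8, dtype=torch.float16)   # GGUF weight tensor stored as F16

try:
    _ = x @ w.T  # mixed-dtype matmul raises a RuntimeError
except RuntimeError as err:
    print("mixed dtypes:", err)

y = x @ w.to(torch.bfloat16).T  # cast to target_dtype first, as the patch does
print(y.dtype)  # torch.bfloat16
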
@@ -153,6 +169,7 @@ def _fused_mul_mat_gguf_fake(
     x: torch.Tensor,
     qweight: torch.Tensor,
     qweight_type: int,
+    target_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
 
@@ -311,7 +328,13 @@ def _apply_gguf_embedding(
     dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in UNQUANTIZED_TYPES:
-        return torch.embedding(qweight, x)
+        # torch.embedding preserves weight tensor dtype (F16 for GGUF).
+        # Convert result to model's computation dtype (typically bfloat16)
+        # to maintain consistency throughout forward pass.
+        result = torch.embedding(qweight, x)
+        if dtype is not None and result.dtype != dtype:
+            result = result.to(dtype)
+        return result
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         x_flat = x.flatten()
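
The embedding path has the same concern; a short sketch (arbitrary sizes) showing that the lookup result inherits the weight table's dtype rather than the model's compute dtype, hence the explicit cast.

import torch

table = torch.randn(10, 4, dtype=torch.float16)  # GGUF embedding table stored as F16
ids = torch.tensor([1, 3, 5])

emb = torch.embedding(table, ids)  # same op the patched code uses
print(emb.dtype)  # torch.float16 - follows the table's dtype

emb = emb.to(torch.bfloat16)  # cast to the model's compute dtype
print(emb.dtype)  # torch.bfloat16
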
@@ -355,7 +378,9 @@ class GGUFLinearMethod(LinearMethodBase):
         quant_config: The GGUF quantization config.
     """
 
-    def __init__(self, quant_config: GGUFConfig):
+    def __init__(self, quant_config: GGUFConfig, model_arch: str = ""):
+        self.model_arch = model_arch.lower()
+        self.is_gemma3 = "gemma3" in self.model_arch
         self.quant_config = quant_config
 
     def create_weights(
@@ -369,6 +394,9 @@ def create_weights(
         **extra_weight_attrs,
     ):
         self.params_dtype = params_dtype
+        # Gemma3 isolation: only apply dtype conversion for Gemma3 models
+        self.target_dtype = params_dtype if self.is_gemma3 else None
+
         output_size_per_partition = sum(output_partition_sizes)
 
         tensor_shape = (output_size_per_partition, input_size_per_partition)
@@ -467,14 +495,18 @@ def apply(
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(
                     fused_mul_mat_gguf(
-                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                        x,
+                        qweight[start:end, :offset].contiguous(),
+                        qweight_type,
+                        self.target_dtype,
                     )
                 )
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+            # Gemma3 isolation: only apply dtype conversion for Gemma3
+            out = fused_mul_mat_gguf(x, qweight, qweight_type, self.target_dtype)
         if bias is not None:
             out.add_(bias)
         return out