
Commit 66b3023

Add Gemma3 GGUF multimodal support
- Enable GGUF multimodal detection in ModelConfig
- Implement GGUF loader for vision tower and projector weights
- Add processor support for GGUF Gemma3 models
- Support multimodal embedding gathering in V1 engine

Signed-off-by: Luciano Martins <[email protected]>
1 parent b5d90f7 commit 66b3023
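
For orientation, a minimal usage sketch (not part of this commit): the detection added here keys off an mmproj GGUF sitting next to the main model file, so loading looks the same as for any other GGUF checkpoint. Paths and file names below are hypothetical, and depending on the vLLM version the tokenizer may still need to be supplied separately.

# Hypothetical directory layout the new detection expects:
#   /models/gemma3/gemma-3-4b-it-Q4_K_M.gguf
#   /models/gemma3/mmproj-gemma-3-4b-it-F16.gguf
from vllm import LLM

# Pointing at the text GGUF is enough; the sibling mmproj file is found by
# globbing mmproj*.gguf in the same directory (sketch, paths hypothetical).
llm = LLM(model="/models/gemma3/gemma-3-4b-it-Q4_K_M.gguf")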

File tree

11 files changed, +725 -69 lines


vllm/config/model.py

Lines changed: 86 additions & 3 deletions
@@ -69,6 +69,35 @@
 
 logger = init_logger(__name__)
 
+
+def _detect_gguf_multimodal_gemma3(model: str) -> bool:
+    """Check if GGUF model has multimodal projector file for Gemma3.
+
+    Args:
+        model: Model path string
+
+    Returns:
+        True if this is a Gemma3 GGUF model with mmproj file, False otherwise
+    """
+    if not model.endswith('.gguf'):
+        return False
+
+    try:
+        from pathlib import Path
+        model_path = Path(model)
+        if not model_path.is_file():
+            return False
+
+        model_dir = model_path.parent
+        mmproj_patterns = ['mmproj.gguf', 'mmproj-*.gguf', '*mmproj*.gguf']
+        for pattern in mmproj_patterns:
+            if list(model_dir.glob(pattern)):
+                return True
+        return False
+    except Exception:
+        return False
+
+
 RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
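
A quick sanity check of the helper above (a sketch, not part of the commit; the private import path is assumed from the file location):

import tempfile
from pathlib import Path

from vllm.config.model import _detect_gguf_multimodal_gemma3

with tempfile.TemporaryDirectory() as d:
    model = Path(d) / "gemma-3-4b-it-Q4_K_M.gguf"
    model.touch()
    # No projector next to the model -> treated as text-only
    assert _detect_gguf_multimodal_gemma3(str(model)) is False
    # Drop an mmproj file into the same directory -> detected
    (Path(d) / "mmproj-gemma-3-4b-it-F16.gguf").touch()
    assert _detect_gguf_multimodal_gemma3(str(model)) is True
    # Non-GGUF paths are rejected outright
    assert _detect_gguf_multimodal_gemma3("model.safetensors") is False
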
@@ -560,6 +589,45 @@ def __post_init__(
 
         architectures = self.architectures
         registry = self.registry
+
+        # GGUF multimodal: Force Gemma3ForConditionalGeneration architecture
+        # when mmproj file is present, before model resolution
+        if _detect_gguf_multimodal_gemma3(self.model):
+            is_gemma3 = any(
+                'gemma3' in str(arch).lower()
+                for arch in architectures
+            )
+            if is_gemma3:
+                architectures = ["Gemma3ForConditionalGeneration"]
+                self.hf_config.architectures = architectures
+                logger.info(
+                    "Detected Gemma3 GGUF with mmproj.gguf, "
+                    "forcing Gemma3ForConditionalGeneration")
+
+                # Initialize vision_config if not present
+                if not hasattr(self.hf_config, 'vision_config') or \
+                        self.hf_config.vision_config is None:
+                    from transformers import SiglipVisionConfig
+                    self.hf_config.vision_config = SiglipVisionConfig(
+                        hidden_size=1152,
+                        intermediate_size=4304,
+                        num_hidden_layers=27,
+                        num_attention_heads=16,
+                        num_channels=3,
+                        image_size=896,
+                        patch_size=14,
+                        layer_norm_eps=1e-6,
+                        attention_dropout=0.0,
+                        num_image_tokens=256,
+                        # Disable pooling head for Gemma3
+                        vision_use_head=False,
+                    )
+                self.hf_config.mm_tokens_per_image = 256
+                self.hf_config.image_token_index = 262144
+                # DO NOT set boi_token_index - let
+                # gemma3_mm.py fall back to 262143
+                self.hf_config.eoi_token_index = 256000
+
         is_generative_model = registry.is_text_generation_model(architectures, self)
         is_pooling_model = registry.is_pooling_model(architectures, self)
 
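
The hard-coded SiglipVisionConfig above mirrors the published Gemma 3 vision tower. As an illustrative check of the token accounting (not part of the commit): 896×896 inputs with 14×14 patches give 64×64 = 4096 SigLIP patch embeddings, which Gemma 3's projector pools down to the 256 tokens set via mm_tokens_per_image.

image_size, patch_size = 896, 14
patches = (image_size // patch_size) ** 2    # 64 * 64 = 4096 patch embeddings
mm_tokens_per_image = 256                    # value forced in the config above
assert patches // mm_tokens_per_image == 16  # 4x4 average pooling in the projector
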
@@ -722,8 +790,25 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
+
+        # GGUF multimodal: Set flag to initialize multimodal_config
+        # when Gemma3 mmproj file is present
+        is_gguf_multimodal = False
+        if _detect_gguf_multimodal_gemma3(self.model):
+            is_gemma3 = any(
+                'gemma3' in str(arch).lower()
+                for arch in self.architectures
+            )
+            if is_gemma3:
+                is_gguf_multimodal = True
+                logger.info(
+                    "Detected Gemma3 GGUF multimodal model "
+                    "with mmproj.gguf, initializing "
+                    "multimodal_config"
+                )
+
         # Init multimodal config if needed
-        if self._model_info.supports_multimodal:
+        if self._model_info.supports_multimodal or is_gguf_multimodal:
             if (
                 mm_encoder_tp_mode == "data"
                 and not self._model_info.supports_multimodal_encoder_tp_data
@@ -909,8 +994,6 @@ def _get_default_runner_type(
                 _, (runner_type, _) = match
                 return runner_type
 
-        return "generate"
-
     def _get_runner_type(
         self,
         architectures: list[str],

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 39 additions & 12 deletions
@@ -36,9 +36,11 @@
 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
-    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
+    def __init__(self, unquantized_modules: list[str] | None = None,
+                 model_arch: str = "") -> None:
         super().__init__()
         self.unquantized_modules = unquantized_modules or []
+        self.model_arch = model_arch
 
     def __repr__(self) -> str:
         return "GGUFConfig()"
@@ -59,17 +61,19 @@ def get_config_filenames(cls) -> list[str]:
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
-        return cls()
+        # Extract model_arch from config if available
+        model_arch = config.get("model_arch", "")
+        return cls(model_arch=model_arch)
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                 return UnquantizedLinearMethod()
-            return GGUFLinearMethod(self)
+            return GGUFLinearMethod(self, self.model_arch)
         elif isinstance(layer, VocabParallelEmbedding):
-            return GGUFEmbeddingMethod(self)
+            return GGUFEmbeddingMethod(self, self.model_arch)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None
@@ -115,7 +119,8 @@ def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
 
 
 def _fused_mul_mat_gguf(
-    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int,
+    target_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     if qweight_type in IMATRIX_QUANT_TYPES:
         mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
@@ -125,9 +130,15 @@
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
         return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
-    # there is no need to call any kernel for fp16/bf16
+    # Unquantized weights (F16/BF16) use direct matrix multiplication
    if qweight_type in UNQUANTIZED_TYPES:
-        return x @ qweight.T
+        # GGUF stores some weights as F16, but the model may use bfloat16 for
+        # computation. Convert weight dtype to match target_dtype (typically
+        # params_dtype=bfloat16) to ensure consistency in mixed-precision.
+        weight = qweight
+        if target_dtype is not None and weight.dtype != target_dtype:
+            weight = weight.to(target_dtype)
+        return x @ weight.T
     # enable MMVQ in contiguous batching with batch_size=1
     if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
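
The target_dtype plumbing above addresses the usual mixed-precision pitfall: GGUF ships unquantized tensors as F16 while the rest of the model computes in bfloat16, and PyTorch matmul does not promote across those dtypes. A standalone illustration (a sketch, independent of vLLM):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)         # activations in params_dtype
qweight = torch.randn(16, 8, dtype=torch.float16)   # GGUF stores the weight as F16

try:
    _ = x @ qweight.T                 # bf16 @ fp16 -> RuntimeError (dtype mismatch)
except RuntimeError as err:
    print("mixed-dtype matmul failed:", err)

target_dtype = torch.bfloat16         # what the patch passes in as params_dtype
y = x @ qweight.to(target_dtype).T    # cast the weight first, as the patch does
assert y.dtype == torch.bfloat16
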
@@ -138,7 +149,9 @@
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        # Use target_dtype if provided, otherwise fall back to x.dtype
+        dequant_dtype = target_dtype if target_dtype is not None else x.dtype
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, dequant_dtype)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
@@ -153,6 +166,7 @@ def _fused_mul_mat_gguf_fake(
     x: torch.Tensor,
     qweight: torch.Tensor,
     qweight_type: int,
+    target_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
 
@@ -311,7 +325,13 @@ def _apply_gguf_embedding(
     dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in UNQUANTIZED_TYPES:
-        return torch.embedding(qweight, x)
+        # torch.embedding preserves weight tensor dtype (F16 for GGUF).
+        # Convert result to model's computation dtype (typically bfloat16)
+        # to maintain consistency throughout forward pass.
+        result = torch.embedding(qweight, x)
+        if dtype is not None and result.dtype != dtype:
+            result = result.to(dtype)
+        return result
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         x_flat = x.flatten()
@@ -355,7 +375,9 @@ class GGUFLinearMethod(LinearMethodBase):
         quant_config: The GGUF quantization config.
     """
 
-    def __init__(self, quant_config: GGUFConfig):
+    def __init__(self, quant_config: GGUFConfig, model_arch: str = ""):
+        self.model_arch = model_arch.lower()
+        self.is_gemma3 = "gemma3" in self.model_arch
         self.quant_config = quant_config
 
     def create_weights(
@@ -369,6 +391,9 @@ def create_weights(
         **extra_weight_attrs,
     ):
         self.params_dtype = params_dtype
+        # Gemma3 isolation: only apply dtype conversion for Gemma3 models
+        self.target_dtype = params_dtype if self.is_gemma3 else None
+
         output_size_per_partition = sum(output_partition_sizes)
 
         tensor_shape = (output_size_per_partition, input_size_per_partition)
@@ -467,14 +492,16 @@ def apply(
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(
                     fused_mul_mat_gguf(
-                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                        x, qweight[start:end, :offset].contiguous(), qweight_type,
+                        self.target_dtype
                     )
                 )
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+            # Gemma3 isolation: only apply dtype conversion for Gemma3
+            out = fused_mul_mat_gguf(x, qweight, qweight_type, self.target_dtype)
         if bias is not None:
             out.add_(bias)
         return out
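
End to end, the new model_arch parameter flows from the quantization config into the linear method, which is what keeps the dtype handling scoped to Gemma3. A sketch based only on the code in this diff (import path assumed from the file location):

from vllm.model_executor.layers.quantization.gguf import (GGUFConfig,
                                                          GGUFLinearMethod)

# A Gemma3 GGUF load would build the config with the detected architecture...
cfg = GGUFConfig.from_config({"model_arch": "gemma3"})

# ...and each LinearBase layer then gets a GGUFLinearMethod whose is_gemma3
# flag makes create_weights set target_dtype = params_dtype, so
# fused_mul_mat_gguf casts or dequantizes into the model's compute dtype.
method = GGUFLinearMethod(cfg, cfg.model_arch)
assert method.is_gemma3
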
