Merged

53 commits
6cb4b22
Add Gemma3 GGUF multimodal support
lucianommartins Oct 29, 2025
f9f1db1
Remove deprecated V0 compatibility code
lucianommartins Oct 30, 2025
387d5ed
Address architectural feedback:
lucianommartins Oct 30, 2025
dd162d6
Address feedbacks
lucianommartins Oct 30, 2025
be3fa26
Fix ReadTheDocs type annotations in utils.py
lucianommartins Oct 30, 2025
67a6006
Fix type annotation imports - move to module level
lucianommartins Oct 30, 2025
1b36cb5
fix: restore model_arch parameter for GGUF dtype handling
lucianommartins Oct 31, 2025
2e90620
fix: restore model_arch parameter for GGUF dtype handling
lucianommartins Oct 31, 2025
2a55c43
Address reviewer feedback: remove hardcoded values, apply fail-fast
lucianommartins Nov 1, 2025
ee439e4
refactor: eliminate code duplication and generalize multimodal GGUF s…
lucianommartins Nov 1, 2025
2b91abf
refactor: address code review feedback for Gemma3 GGUF
lucianommartins Nov 3, 2025
18276d4
Addressing reviews/feedbacks.
lucianommartins Nov 3, 2025
0ce2423
reverting cosmetic changes.
lucianommartins Nov 4, 2025
50a66c8
refactor: reorganize GGUF utilities and extract config patching
lucianommartins Nov 4, 2025
a480c72
Add Gemma3 GGUF multimodal generation tests
lucianommartins Nov 5, 2025
e86ae43
test: split Gemma3 tests into separate GGUF and HF files
lucianommartins Nov 6, 2025
7cc3406
clean
Isotr0py Nov 9, 2025
7c42bf3
avoid reading GGUF multiple times
Isotr0py Nov 9, 2025
4401fa5
better hf_config patch
Isotr0py Nov 9, 2025
ba1536b
clean test
Isotr0py Nov 9, 2025
556a044
remove redundant vibe coding test
Isotr0py Nov 9, 2025
d3a69b9
update autom tensor mapping
Isotr0py Nov 10, 2025
f6f48ec
gguf: implement automatic mmproj weight mapping with filtering
lucianommartins Nov 10, 2025
dcd2d5e
fix(gguf): resolve Gemma3 multimodal parameter prefix mismatch
lucianommartins Nov 12, 2025
ea8a03c
fix(gguf): resolve Gemma3 multimodal parameter prefix mismatch
lucianommartins Nov 12, 2025
2f5eac0
bump gguf version
Isotr0py Nov 13, 2025
4579031
refactor(gguf): use official GGUF constants for vision config extraction
lucianommartins Nov 13, 2025
05ce7f3
clean
Isotr0py Nov 14, 2025
19539e5
fix Gemma3 GGUF multimodal test with correct processor loading
lucianommartins Nov 14, 2025
8c7c735
unify unqunatized weights handling
Isotr0py Nov 15, 2025
466f1ef
revert unnecessary changes
Isotr0py Nov 15, 2025
a7934c9
update gemma3mm embed_input_ids
Isotr0py Nov 15, 2025
8365694
fix
Isotr0py Nov 15, 2025
b699a6d
clean model config
Isotr0py Nov 15, 2025
f915c21
clean processor to load processor from tokenizer
Isotr0py Nov 15, 2025
6d0ae74
feat(gguf): Finish gguf loader code cleanup
lucianommartins Nov 15, 2025
9e0e6f5
move tokenizer validation to model_config
Isotr0py Nov 15, 2025
9846458
fix test
Isotr0py Nov 15, 2025
2f2de1c
revert gemma3 text backbone
Isotr0py Nov 15, 2025
211131f
compatability with unsloth gguf
Isotr0py Nov 16, 2025
f0b0b03
fix broken inc quant
Isotr0py Nov 16, 2025
803a166
revert unnecessary changes
Isotr0py Nov 16, 2025
be831a7
revert unnecessary changes
Isotr0py Nov 16, 2025
12a26cf
clean and correct some comments
Isotr0py Nov 16, 2025
5caf036
update test
Isotr0py Nov 16, 2025
9077fc4
fix deadlock
Isotr0py Nov 16, 2025
8d6cfd4
gemini
Isotr0py Nov 16, 2025
86a99d2
Add vocab_size override and improve automatic weight mapping
lucianommartins Nov 16, 2025
9c071b9
fix: Update CI test configurations for Gemma3 compatibility
lucianommartins Nov 16, 2025
5a245bd
refactor(gguf): remove redundant vocab_size extraction
lucianommartins Nov 17, 2025
ef3a0ab
fix(v1): add hasattr check before calling generate_attention_masks
lucianommartins Nov 17, 2025
ea1525a
revert
Isotr0py Nov 18, 2025
46d8746
Merge branch 'main' into main
Isotr0py Nov 18, 2025
4 changes: 2 additions & 2 deletions requirements/common.txt
@@ -30,8 +30,8 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
 partial-json-parser # used for parsing partial JSON outputs
 pyzmq >= 25.0.0
 msgspec
-gguf >= 0.13.0
-mistral_common[image] >= 1.8.5
+gguf >= 0.17.0
+mistral_common[image,audio] >= 1.8.5
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
115 changes: 115 additions & 0 deletions tests/models/multimodal/generation/test_multimodal_gguf.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal, NamedTuple

import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator

from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.utils.torch_utils import set_default_torch_num_threads

from ....conftest import PromptImageInput, VllmRunner
from ...utils import check_logprobs_close


class GGUFMMTestConfig(NamedTuple):
    original_model: str
    gguf_repo: str
    gguf_backbone: str
    gguf_mmproj: str
    prompt: list[str]
    mm_data: dict[Literal["images"], PromptImageInput]
    max_model_len: int = 4096
    marks: list[MarkDecorator] = []

    @property
    def gguf_model(self):
        hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj)
        return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)


GEMMA3_CONFIG = GGUFMMTestConfig(
    original_model="google/gemma-3-4b-it",
    gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
    gguf_backbone="gemma-3-4b-it-q4_0.gguf",
    gguf_mmproj="mmproj-model-f16-4B.gguf",
    prompt=["<start_of_image>Describe this image in detail:"],
    mm_data={"images": [ImageAsset("stop_sign").pil_image]},
    marks=[pytest.mark.core_model],
)

MODELS_TO_TEST = [GEMMA3_CONFIG]


def run_multimodal_gguf_test(
    vllm_runner: type[VllmRunner],
    model: GGUFMMTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
):
    # Run gguf model.
    with (
        set_default_torch_num_threads(1),
        vllm_runner(
            model_name=model.gguf_model,
            enforce_eager=True,
            tokenizer_name=model.original_model,
            dtype=dtype,
            max_model_len=model.max_model_len,
        ) as gguf_model,
    ):
        gguf_outputs = gguf_model.generate_greedy_logprobs(
            prompts=model.prompt,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            **model.mm_data,
        )

    # Run unquantized model.
    with vllm_runner(
        model_name=model.original_model,
        enforce_eager=True,  # faster tests
        dtype=dtype,
        max_model_len=model.max_model_len,
    ) as original_model:
        original_outputs = original_model.generate_greedy_logprobs(
            prompts=model.prompt,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            **model.mm_data,
        )

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )


@pytest.mark.skipif(
    not is_quant_method_supported("gguf"),
    reason="gguf is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(test_config, marks=test_config.marks)
        for test_config in MODELS_TO_TEST
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
    vllm_runner: type[VllmRunner],
    model: GGUFMMTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
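For context, a minimal offline-inference sketch of the workflow this test covers: downloading the GGUF backbone plus the mmproj file and loading them with vLLM while pointing the tokenizer at the original Hugging Face model (now required for GGUF checkpoints, see the vllm/config/model.py change below). The engine arguments and the multi_modal_data prompt format follow the general vLLM offline API and are assumptions, not part of this diff.

# Sketch only: assumes vLLM with this PR plus huggingface_hub installed.
from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

backbone = hf_hub_download(
    "google/gemma-3-4b-it-qat-q4_0-gguf", filename="gemma-3-4b-it-q4_0.gguf"
)
# Download the mmproj too, as the test does; the loader presumably picks it
# up from the same cache location as the backbone.
hf_hub_download(
    "google/gemma-3-4b-it-qat-q4_0-gguf", filename="mmproj-model-f16-4B.gguf"
)

llm = LLM(
    model=backbone,
    tokenizer="google/gemma-3-4b-it",  # mandatory for GGUF checkpoints
    max_model_len=4096,
    enforce_eager=True,
)
outputs = llm.generate(
    {
        "prompt": "<start_of_image>Describe this image in detail:",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)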
9 changes: 8 additions & 1 deletion tests/models/quantization/test_gguf.py
@@ -78,13 +78,20 @@ def gguf_model(self):
     gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
 )
 
+GEMMA3_CONFIG = GGUFTestConfig(
+    original_model="google/gemma-3-270m-it",
+    gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF",
+    gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf",
+)
+
 MODELS = [
     # LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
     STABLELM_CONFIG,
     DOLPHIN_CONFIG,
+    GEMMA3_CONFIG,
     # STARCODER_CONFIG, # broken
 ]
 
@@ -148,7 +155,7 @@ def check_model_outputs(
     "model",
     [pytest.param(test_config, marks=test_config.marks) for test_config in MODELS],
 )
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1])
20 changes: 19 additions & 1 deletion vllm/config/model.py
@@ -33,10 +33,14 @@
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
+    uses_custom_attention_masks,
     uses_mrope,
 )
+from vllm.transformers_utils.gguf_utils import (
+    maybe_patch_hf_config_from_gguf,
+)
 from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import maybe_model_redirect
+from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
 from vllm.utils.import_utils import LazyLoader
 from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -450,6 +454,12 @@ def __post_init__(
         self.model = maybe_model_redirect(self.model)
         # The tokenizer is consistent with the model by default.
         if self.tokenizer is None:
+            if check_gguf_file(self.model):
+                raise ValueError(
+                    "Using a tokenizer is mandatory when loading a GGUF model. "
+                    "Please specify the tokenizer path or name using the "
+                    "--tokenizer argument."
+                )
             self.tokenizer = self.model
         if self.tokenizer_revision is None:
             self.tokenizer_revision = self.revision
@@ -508,6 +518,10 @@ def __post_init__(
             hf_overrides_kw=hf_overrides_kw,
             hf_overrides_fn=hf_overrides_fn,
         )
+        hf_config = maybe_patch_hf_config_from_gguf(
+            self.model,
+            hf_config,
+        )
 
         self.hf_config = hf_config
         if dict_overrides:
@@ -1605,6 +1619,10 @@ def uses_alibi(self) -> bool:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
 
+    @property
+    def uses_custom_attention_masks(self) -> bool:
+        return uses_custom_attention_masks(self.hf_config)
+
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
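The fail-fast tokenizer check above hinges on recognizing a GGUF checkpoint path early. As a rough illustration of that kind of detection (not vLLM's actual check_gguf_file implementation), a GGUF file can be identified by its 4-byte b"GGUF" magic:

from pathlib import Path


def looks_like_gguf(model: str) -> bool:
    # Illustrative stand-in for check_gguf_file: a GGUF checkpoint is a single
    # local file whose first four bytes are the b"GGUF" magic.
    path = Path(model)
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(4) == b"GGUF"


# looks_like_gguf("gemma-3-4b-it-q4_0.gguf") -> True for a real GGUF file;
# a Hugging Face repo id like "google/gemma-3-4b-it" is not a local file -> False.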
67 changes: 62 additions & 5 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
+from types import MappingProxyType
 from typing import Any, Optional
 
 import gguf
@@ -26,7 +27,11 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils.torch_utils import direct_register_custom_op
 
@@ -65,18 +70,70 @@ def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
+                return UnquantizedEmbeddingMethod()
             return GGUFEmbeddingMethod(self)
         elif isinstance(layer, FusedMoE):
            return GGUFMoEMethod(self, layer.moe_config)
         return None
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        if self.unquantized_modules is not None:
+            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
+                self.unquantized_modules
+            )
+
+
+def is_layer_skipped_gguf(
+    prefix: str,
+    unquantized_modules: list[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+):
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    proj_name = prefix.split(".")[-1]
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = any(
+                shard_prefix in module_name for module_name in unquantized_modules
+            )
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = any(module_name in prefix for module_name in unquantized_modules)
+
-def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
-    return any(module_name in prefix for module_name in unquantized_modules)
+    assert is_skipped is not None
+    return is_skipped
 
 
 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
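To make the fused-shard handling in is_layer_skipped_gguf concrete, a small usage sketch; the packed mapping and module names below are illustrative rather than taken from a real checkpoint:

# Assumes the is_layer_skipped_gguf signature introduced in this PR.
from vllm.model_executor.layers.quantization.gguf import is_layer_skipped_gguf

fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
unquantized = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# The fused prefix is expanded to its unfused shard names before matching,
# so the whole qkv_proj layer is treated as unquantized here.
assert is_layer_skipped_gguf(
    "model.layers.0.self_attn.qkv_proj", unquantized, fused_mapping
)

# If only some shards were listed (e.g. just q_proj), the shards would disagree
# and the function would raise a ValueError instead of silently guessing.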