
Conversation

@lucianommartins (Contributor) commented on Oct 29, 2025

[Model] Add Gemma3 GGUF multimodal support

Summary

This PR enables full multimodal inference support for Gemma3 models in GGUF format, allowing users to run quantized Gemma3 multimodal models with both text-only and image+text prompts. The implementation adds automatic detection and loading of mmproj.gguf files for vision tower weights while maintaining complete isolation to ensure zero impact on existing Gemma3 HuggingFace models or other model architectures.

Resolves: #22497


New Functionality

1. Automatic GGUF Multimodal Detection

The model configuration now automatically detects when a Gemma3 GGUF model has an accompanying mmproj.gguf file and switches to the multimodal architecture:

  • Location: vllm/config/model.py
  • Mechanism: Helper function _detect_gguf_multimodal_gemma3() checks for mmproj-*.gguf in model directory
  • Behavior: Sets model_cls = "Gemma3ForConditionalGeneration" when multimodal projector is found
  • Fallback: Uses Gemma3ForCausalLM for text-only GGUF models

Example usage:

from vllm import LLM

# Automatically detects mmproj.gguf and loads multimodal architecture
llm = LLM(
    model="/path/to/gemma-3-4b-it-q4_0.gguf",
    tokenizer="google/gemma-3-4b-it",
)
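
For illustration, here is a minimal sketch of the detection step described above; the function name and glob pattern are placeholders for the actual _detect_gguf_multimodal_gemma3() helper in vllm/config/model.py:

from pathlib import Path

def has_mmproj_sidecar(gguf_path: str) -> bool:
    # Hypothetical sketch: a GGUF checkpoint is treated as multimodal when an
    # mmproj*.gguf projector file sits next to the main .gguf file.
    model_dir = Path(gguf_path).parent
    return any(model_dir.glob("mmproj*.gguf"))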

2. GGUF Vision Tower Weight Loading

Enhanced GGUF loader to handle multimodal weights from separate mmproj.gguf files:

  • Location: vllm/model_executor/model_loader/gguf_loader.py, weight_utils.py
  • Tensor Mapping: Maps GGUF vision encoder tensors to vLLM's SigLIP vision tower
    • v.*.attn.{q,k,v,out} → Vision attention weights
    • v.*.mlp.{fc1,fc2} → Vision FFN weights
    • v.post_ln.{weight,bias} → Post-layernorm weights
    • mm.0.{weight,bias} → Multimodal projector weights
  • Quantization: Vision weights loaded as unquantized float16/bfloat16 (GGUF spec)
  • Isolation: Only activated when is_gguf_weight = True
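
As a rough illustration of the tensor-name translation listed above (both the GGUF source names and the vLLM/SigLIP target names below are simplified assumptions, not the exact strings used by the loader, which relies on gguf-py's name maps):

import re

# Illustrative GGUF -> vLLM vision tensor-name rules; the real mapping lives in
# gguf_loader.py / weight_utils.py.
_VISION_NAME_RULES = [
    (re.compile(r"^v\.(\d+)\.attn\.(q|k|v|out)\.(weight|bias)$"),
     r"vision_tower.encoder.layers.\1.self_attn.\2_proj.\3"),
    (re.compile(r"^v\.(\d+)\.mlp\.(fc1|fc2)\.(weight|bias)$"),
     r"vision_tower.encoder.layers.\1.mlp.\2.\3"),
    (re.compile(r"^v\.post_ln\.(weight|bias)$"),
     r"vision_tower.post_layernorm.\1"),
    (re.compile(r"^mm\.0\.(weight|bias)$"),
     r"multi_modal_projector.\1"),
]

def map_vision_tensor(gguf_name: str) -> str | None:
    for pattern, template in _VISION_NAME_RULES:
        if pattern.match(gguf_name):
            return pattern.sub(template, gguf_name)
    return None  # not a vision tensor covered by this sketch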

3. GGUF Processor Loading

Processors now correctly load from tokenizer path for GGUF models:

  • Location: vllm/transformers_utils/processor.py
  • Fix: Converts model repo paths (including GGUF files) to directory paths before loading processor
  • Benefit: Avoids "not a directory" errors for .gguf file paths
  • Impact: Applies to all processor types
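
A minimal sketch of the path conversion described above (the helper name is hypothetical; the actual fix lives in vllm/transformers_utils/processor.py):

from pathlib import Path

def processor_load_path(model: str) -> str:
    # If the model argument points at a .gguf file, load the processor from
    # the containing directory (or the separate tokenizer repo) instead of
    # handing AutoProcessor a file path.
    path = Path(model)
    if path.suffix == ".gguf":
        return str(path.parent)
    return model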

4. V1 Engine Multimodal Support

Added multimodal embedding gathering in V1 GPU model runner:

  • Location: vllm/v1/worker/gpu_model_runner.py
  • Function: _gather_mm_embeddings() extracts vision embeddings after encoder execution
  • Integration: Passes embeddings to model's get_multimodal_embeddings() for proper merging
  • Compatibility: Works with both GGUF and HuggingFace multimodal models
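
The merge step can be pictured with the following self-contained sketch (not vLLM's actual runner code): vision embeddings returned by the encoder replace the placeholder image-token positions in the input embeddings before the language model runs.

import torch

def merge_mm_embeddings(
    inputs_embeds: torch.Tensor,   # [num_tokens, hidden_size]
    mm_embeds: torch.Tensor,       # [num_image_tokens, hidden_size]
    is_image_token: torch.Tensor,  # [num_tokens] boolean mask
) -> torch.Tensor:
    # Scatter the gathered vision embeddings into the image-token slots.
    merged = inputs_embeds.clone()
    merged[is_image_token] = mm_embeds.to(merged.dtype)
    return merged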

Isolation & Safety Guarantees

All changes are strictly scoped to Gemma3 GGUF multimodal models only. No other models or formats are affected.

Isolation Verification

  • GGUF Loader: if quantization == "gguf"
  • Weight Utils: if quantization in ("gguf", "inc")
  • Model Config: if self.model.endswith('.gguf')
  • Gemma3 MM: if self._is_gguf
  • Processor: only for GGUF model paths
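
For example, the model-config guard boils down to a check like the following (simplified; the exact condition in vllm/config/model.py may differ):

def is_gguf_checkpoint(model: str, quantization: str | None) -> bool:
    # GGUF-specific handling only activates for .gguf paths or gguf quantization.
    return quantization == "gguf" or model.endswith(".gguf")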

Non-Regression Tests

Comprehensive testing confirms zero impact on other model types:

  • Gemma3 HF text-only (1B, 4B)
  • Gemma3 HF multimodal (4B)
  • Gemma3 GGUF text-only (4B)
  • Gemma3 GGUF multimodal (4B)

No changes to:

  • Other GGUF models (Llama, Mistral, etc.)
  • Other multimodal models (Phi3V, Qwen2-VL, etc.)
  • V0 engine
  • Kernel implementations

Testing

Test Environment

  • Hardware: NVIDIA A100 80GB GPU
  • Models Tested:
    • google/gemma-3-1b-it (HF text-only)
    • google/gemma-3-4b-it (HF text-only & multimodal)
    • google/gemma-3-4b-it-q4_0-gguf (GGUF text-only & multimodal)

Backward Compatibility

100% backward compatible; no breaking changes to existing functionality:

  • Existing Gemma3 HF models work identically
  • Existing GGUF text-only models unaffected
  • No API changes
  • No behavior changes for non-GGUF models
  • No performance regressions

Checklist

  • PR title follows format: [Model] Add Gemma3 GGUF multimodal support
  • Code follows Google Python style guide
  • No regressions verified through testing
  • Isolated to Gemma3 GGUF only - zero impact on other models

Additional Context

This PR builds upon the recently merged text-only GGUF support for Gemma3 and extends it to support multimodal inference. The implementation carefully preserves all upstream bugfixes, including the recent newline token ordering fix (#27538).

Tested on: vLLM main branch (commit b5d90f740, Oct 29 2025)

Signed-off-by: Luciano Martins [email protected]

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for Gemma3 GGUF multimodal models, which is a significant feature enhancement. The implementation correctly handles the detection of multimodal GGUF files and extends the model loader to accommodate vision tower weights from separate mmproj.gguf files. The changes are well-isolated to minimize impact on existing models. However, there are several areas for improvement. A critical issue was found where a function is missing a return statement, which could lead to runtime errors. Additionally, there are multiple instances of hardcoded configurations for the vision tower and image processor, as well as code duplication for model detection logic. These reduce maintainability and make the implementation brittle to future changes. I've also identified a bug in the runai_safetensors_weights_iterator that will cause an incorrect progress bar. Addressing these points will improve the robustness and maintainability of this new feature.

@mergify (bot) commented on Oct 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lucianommartins.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Oct 30, 2025
@Isotr0py (Member) left a comment

I think this PR introduces too much gemma3-mm GGUF-specific hardcoded code.

Besides this, I'm a bit confused why we need to use GGUFQuantMethod for a totally unquantized mm_proj checkpoint as well.

@Isotr0py (Member) left a comment

Unfortunately, I think this PR is still poorly designed, because there is too much gemma3-specific code outside the model implementation...

Isotr0py and others added 8 commits November 17, 2025 18:32
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
This commit enhances GGUF model loading with two key improvements:

1. Dynamic vocab_size extraction and override:
   - Added extract_vocab_size_from_gguf() to read vocab_size from
     token_embd.weight tensor shape
   - Updated maybe_patch_hf_config_from_gguf() to automatically override
     vocab_size from GGUF file for all GGUF models
   - Enables loading models with extended vocabularies (e.g., Unsloth)
     without manual config editing

2. Improved automatic weight mapping for multimodal models:
   - Enhanced find_hf_name_in_tensor_map() to handle naming convention
     differences between HF and gguf-py:
     * Strips 'language_model.' prefix for multimodal models
     * Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
   - Removed hardcoded Gemma3-specific weight mappings in favor of
     automatic mapping via gguf-py's vision_name_map
   - Removed redundant prefix fallback logic
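
A rough sketch of the normalization described in point 2 (the helper name here is hypothetical; the real logic sits inside find_hf_name_in_tensor_map()):

def normalize_hf_name(hf_name: str) -> str:
    # Multimodal HF checkpoints nest the text model under "language_model.".
    if hf_name.startswith("language_model."):
        hf_name = hf_name[len("language_model."):]
    # Some Gemma3 parameters use a "_weight" suffix instead of ".weight".
    if hf_name.endswith("_weight"):
        hf_name = hf_name[: -len("_weight")] + ".weight"
    return hf_name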

This will:
- add automatic compatibility with models using extended vocabularies
- make the code more maintainable, with no model-specific hardcoded mappings
- introduce robust handling of HF/gguf-py naming convention mismatches
- work for all multimodal GGUF models, not just Gemma3

Technical details:
- GGUF embedding tensor format: [hidden_size, vocab_size]
- vocab_size extracted from shape[1] as source of truth
- Uses get_text_config() to handle both regular and multimodal configs
- Graceful fallback if vocab_size extraction fails
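
A sketch of the extraction described above, assuming gguf-py's GGUFReader (the function name mirrors this commit, but the body is illustrative):

from gguf import GGUFReader

def extract_vocab_size_from_gguf(gguf_path: str) -> int | None:
    # token_embd.weight is stored as [hidden_size, vocab_size] in GGUF, so the
    # vocabulary size is read from the second dimension of its shape.
    reader = GGUFReader(gguf_path)
    for tensor in reader.tensors:
        if tensor.name == "token_embd.weight":
            return int(tensor.shape[1])
    return None  # graceful fallback: keep the vocab_size already in the config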

Signed-off-by: Luciano Martins <[email protected]>
- Change gemma3-transformers test to use model_impl='auto' instead of 'transformers'
  to avoid missing generate_attention_masks() method in TransformersMultiModalForCausalLM
- Change quantized models test to use dtype='bfloat16' instead of 'half'
  since gemma3_text models don't support float16 due to numerical instability

These are test configuration fixes for pre-existing bugs unrelated to GGUF changes.

Signed-off-by: Luciano Martins <[email protected]>
Remove extract_vocab_size_from_gguf() and vocab_size override logic
as transformers now handles this internally via modeling_gguf_pytorch_utils.

Also fixed the tests/models/multimodal/generation/test_common.py to use
HuggingFace implementation for Gemma3 testing.

Signed-off-by: Luciano Martins <[email protected]>
Prevent AttributeError when calling generate_attention_masks() on models
that don't implement this method (e.g., TransformersMultiModalForCausalLM).

This fixes the errors on multi-modal-models-test CI test.

The condition now checks:
1. uses_custom_attention_masks flag is set
2. multimodal features are present
3. model has generate_attention_masks method

This ensures the method is only called on models that support custom
attention masks (e.g., Gemma3ForConditionalGeneration).
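
In sketch form, the guard looks roughly like this (attribute and parameter names are illustrative, not the runner's exact code):

def maybe_generate_attention_masks(model, uses_custom_attention_masks, mm_features, **kwargs):
    # Only call generate_attention_masks() on models that both opt in via the
    # flag and actually implement the method (e.g. Gemma3ForConditionalGeneration).
    if (
        uses_custom_attention_masks
        and mm_features
        and hasattr(model, "generate_attention_masks")
    ):
        return model.generate_attention_masks(**kwargs)
    return None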

Signed-off-by: Luciano Martins <[email protected]>
@lucianommartins (Contributor, Author) commented:

Hey @Isotr0py, all tests are passing, with two exceptions not related to Gemma3 or this PR's work:

  1. buildkite/ci/pr/entrypoints-integration-test-api-server

It is failing with this access error when fetching the test image:

[2025-11-17T20:38:37Z] FAILED entrypoints/openai/test_metrics.py::test_metrics_exist[-multimodal] - openai.InternalServerError: Error code: 500 - {'error': {'message': "403, message='Forbidden', url='https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg'", 'type': 'Internal Server Error', 'param': None, 'code': 500}}
  2. buildkite/ci/pr/examples-test

It is failing for EAGLE-LLAMA3.1-Instruct-8B (before even getting to any Gemma3 test):

[2025-11-17T20:31:00Z] + python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
[2025-11-17T20:31:07Z] [2025-11-17 12:31:07] INFO config.py:54: PyTorch version 2.9.0+cu129 available.
[2025-11-17T20:31:07Z] [2025-11-17 12:31:07] INFO config.py:66: Polars version 1.29.0 available.
[2025-11-17T20:31:10Z] INFO 11-17 12:31:10 [utils.py:253] non-default args: {'trust_remote_code': True, 'max_model_len': 2048, 'gpu_memory_utilization': 0.8, 'limit_mm_per_prompt': {'image': 5}, 'enable_chunked_prefill': True, 'disable_chunked_mm_input': True, 'speculative_config': {'method': 'eagle', 'model': 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B', 'num_speculative_tokens': 3}, 'model': 'meta-llama/Llama-3.1-8B-Instruct'}

And it is failing due to GPU OOM errors:

[2025-11-17T20:33:01Z] (EngineCore_DP0 pid=4274) INFO 11-17 12:33:01 [backends.py:631] Using cache directory: /root/.cache/vllm/torch_compile_cache/bb408db9f5/rank_0_0/eagle_head for vLLM's torch.compile
[2025-11-17T20:33:01Z] (EngineCore_DP0 pid=4274) INFO 11-17 12:33:01 [backends.py:647] Dynamo bytecode transform time: 0.39 s
[2025-11-17T20:33:06Z] (EngineCore_DP0 pid=4274) INFO 11-17 12:33:06 [backends.py:282] Compiling a graph for dynamic shape takes 4.61 s
[2025-11-17T20:33:06Z] (EngineCore_DP0 pid=4274) INFO 11-17 12:33:06 [monitor.py:34] torch.compile takes 17.58 s in total
[2025-11-17T20:33:13Z] (EngineCore_DP0 pid=4274) INFO 11-17 12:33:13 [gpu_worker.py:361] Available KV cache memory: -0.26 GiB
[2025-11-17T20:33:14Z] (EngineCore_DP0 pid=4274) ERROR 11-17 12:33:14 [core.py:855] EngineCore failed to start.

<...>

[2025-11-17T20:33:14Z] (EngineCore_DP0 pid=4274)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 686, in check_enough_kv_cache_memory
[2025-11-17T20:33:14Z] (EngineCore_DP0 pid=4274)     raise ValueError(
[2025-11-17T20:33:14Z] (EngineCore_DP0 pid=4274) ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.

can we proceed with the merge?

@DarkLight1337 (Member) commented:

Yeah they are not related, let's merge

Isotr0py enabled auto-merge (squash) on November 18, 2025 13:49
vllm-bot merged commit c261237 into vllm-project:main on Nov 18, 2025
86 of 88 checks passed
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
Signed-off-by: Luciano Martins <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
Signed-off-by: Luciano Martins <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: jiang1.li <[email protected]>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
Signed-off-by: Luciano Martins <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Luciano Martins <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: Luciano Martins <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Dec 3, 2025
Due to the latest changes from upstream, gemma3 is failing to compile on
HPU
vllm-project/vllm#27772
vllm-project/vllm#28842

- replace unfold with view/reshape
- replace text embedding to avoid dynamic shapes
- remove the merge_multimodal replacement since the masked_scatter issue is fixed
- re-enable the gemma3 model test

---------

Signed-off-by: Jimin Ha <[email protected]>

Labels: ci/build, multi-modality, ready, v1
