Skip to content

Conversation

@lucianommartins
Copy link
Contributor

@lucianommartins lucianommartins commented Nov 21, 2025

Summary

This PR restores custom attention mask generation for Gemma3 GGUF multimodal models that was partially reverted in #28995. The implementation uses robust GGUF-only file format guards to ensure the feature exclusively applies to GGUF models and does not affect HuggingFace models.

Resolves: #28995 (HF model regression)
Restores functionality from: #27772

Background

PR #27772 initially added Gemma3 GGUF multimodal support, enabling users to run quantized Gemma3 multimodal models with both text-only and image+text prompts. However, it was partially reverted in #28995 because the custom attention mask logic incorrectly triggered for HuggingFace models, causing test failures.

Root cause of #28995: The original implementation lacked file format guards, causing the custom attention mask generation to activate for both GGUF and HF models.

Solution

This PR addresses the regression by implementing a 3-layer defense-in-depth guard mechanism:

Layer 1: Model Format Check (Primary Guard)

def uses_custom_attention_masks(config: PretrainedConfig, model_path: str) -> bool:
    """Only return True for GGUF Gemma3 multimodal models."""
    architectures = getattr(config, "architectures", [])
    is_gemma3 = "Gemma3ForConditionalGeneration" in architectures
    is_gguf = check_gguf_file(model_path)  # ← Critical GGUF guard
    return is_gemma3 and is_gguf

Layer 2: Multimodal Feature Check

has_mm_features = any(
    req_state.mm_features for req_state in self.requests.values()
)

Layer 3: Method Existence Check

hasattr(self.model, "generate_attention_masks")

Result: HF models never have uses_custom_attention_masks = True, preventing the issue that caused #28995.

Changes

Files Modified (4)

  1. vllm/transformers_utils/config.py

    • Add uses_custom_attention_masks() utility function
    • Implements GGUF file format check using check_gguf_file()
  2. vllm/config/model.py

    • Add uses_custom_attention_masks property to ModelConfig
    • Delegates to utility function with model path for GGUF detection
  3. vllm/v1/worker/gpu_model_runner.py

    • Initialize uses_custom_attention_masks attribute in GPUModelRunner
    • Apply 3-layer guard before calling custom attention mask generation
  4. vllm/model_executor/models/gemma3_mm.py

    • Restore generate_attention_masks() method
    • Generates custom masks enabling bidirectional attention between image tokens
    • Handles sliding window attention for GGUF compatibility

Test Plan

GGUF Model Validation

Tested with multiple quantized Gemma3 GGUF models to ensure functionality across different model sizes:

Text-Only Inference:

  • Gemma3 1B GGUF (Q4_0 quantization)
  • Gemma3 4B GGUF (Q4_0 quantization)
  • Gemma3 270M GGUF (Q4_0 quantization)

Multimodal Inference:

  • Gemma3 4B GGUF with mmproj.gguf vision tower
  • Input: Multi-image prompts (2 images: kitten photo + autumn forest scene)
  • Expected: Accurate image descriptions with proper bidirectional attention between image tokens

HuggingFace Model Regression Testing

Executed the full vLLM multimodal test suite to verify zero impact on HF models:

pytest -s -v tests/models/multimodal/generation/test_common.py -k "gemma3-test"

This ensures the GGUF guards prevent any unintended activation of custom attention mask logic for HuggingFace models.

Test Results

GGUF Model Results (All Pass)

Model Quantization Test Type Status Details
Gemma3-1B Q4_0 Text-only PASSED Generates coherent haiku about coding
Gemma3-4B Q4_0 Text-only PASSED Generates coherent haiku about coding
Gemma3-270M Q4_0 Text-only PASSED Generates coherent haiku about coding
Gemma3-4B Q4_0 Multimodal PASSED Correctly describes both images with accurate details

Multimodal Output Example:

Image 1: A close-up shot of a tabby kitten with striking blue eyes. The kitten 
is lying on a green surface, likely a rug or carpet. Its fur is a mix of brown 
and black stripes, and it has a slightly melancholic expression...

Image 2: A breathtaking landscape photograph of a dense pine forest bathed in 
the warm, golden light of a late autumn or early winter day...

HuggingFace Model Regression Test (All Pass)

pytest -s -v tests/models/multimodal/generation/test_common.py -k gemma3-test
# Result: 8 passed, 335 deselected, 23 warnings in 915.69s (15m 15s)

Test Coverage:

  • Single-image inference (3 test cases)
  • Multi-image inference (3 test cases)
  • Various prompt templates (2 test cases)
  • All 8 test cases pass - confirms zero regression for HF models

Verification of Fix for #28995

The failing test from #28995 (pytest gemma3-test) now passes completely:

Why it works now:

  • uses_custom_attention_masks returns False for HF models (no .gguf file detected)
  • Custom attention mask generation never executes for HF models
  • HF models use their native attention mechanism without interference

Isolation & Safety Guarantees

How HF Models Are Protected:

  1. File Format Check:

    is_gguf = check_gguf_file(model_path)  # Returns False for HF models
  2. Short-Circuit Logic:

    return is_gemma3 and is_gguf  # Requires BOTH conditions
  3. Runtime Guard:

    if self.uses_custom_attention_masks:  # Always False for HF
        # Custom mask generation (never executed for HF)

What Changed from #27772:

Aspect Original PR #27772 This PR (Fixed)
Guard Mechanism Architecture check only Architecture + GGUF file format check
HF Model Impact Incorrectly triggered Never triggers
GGUF Multimodal Works Works

Code Quality

  • All linting checks pass (ruff check, ruff format, mypy)
  • All pre-commit hooks pass
  • Follows Google Python style guide
  • Comprehensive docstrings with clear GGUF-only notes
  • Defense-in-depth guard pattern

Backward Compatibility

  • Zero impact on existing HuggingFace models (verified via pytest)
  • Zero impact on other model architectures (GGUF check prevents any non-Gemma3 activation)
  • Restores functionality from [Model] Add Gemma3 GGUF multimodal support #27772 for GGUF users
  • No API changes, no breaking changes

Documentation

No user-facing documentation changes required. The feature is transparent to users - GGUF Gemma3 multimodal models work automatically without configuration.

Release Notes

This fix should be included in release notes as:

[Model] Restored Gemma3 GGUF multimodal support - Re-enables custom attention mask generation for Gemma3 GGUF multimodal models with robust GGUF-only guards, fixing the regression introduced in #28995 while maintaining zero impact on HuggingFace models.


Checklist


Related PRs:

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively restores multimodal support for Gemma3 GGUF models by introducing robust file-format-based guards. The approach is sound and the defense-in-depth mechanism is a good practice to prevent regressions on HuggingFace models. My review focuses on performance optimizations within the newly restored generate_attention_masks method, suggesting more idiomatic and efficient PyTorch constructs to improve performance on this critical path.

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@lucianommartins
Copy link
Contributor Author

Hi @Isotr0py / @DarkLight1337,

It is a quick one - reintroducing #27772, but now with guardrails to avoid the problems that caused the PR to be reverted via #28995.

It is pretty much all reviewed (as not much changed since #27772) and ready to go :)

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 22, 2025
@lucianommartins
Copy link
Contributor Author

hey @DarkLight1337,

The failing test is not related to my PR:

[2025-11-22T02:36:20Z] FAILED v1/entrypoints/openai/test_multi_api_servers.py::test_single_completion[hmellor/tiny-random-LlamaForCausalLM] - AssertionError: assert 0 >= 1
[2025-11-22T02:36:20Z]  +  where 0 = len('')
[2025-11-22T02:36:20Z]  +    where '' = CompletionChoice(finish_reason='stop', index=0, logprobs=None, text='', stop_reason=None, token_ids=None, prompt_logprobs=None, prompt_token_ids=None).text
[2025-11-22T02:36:20Z] !!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!

can we move forward?

@mergify
Copy link

mergify bot commented Nov 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lucianommartins.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@chatgpt-codex-connector
Copy link

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@lucianommartins lucianommartins force-pushed the main branch 2 times, most recently from a54cb9c to f467714 Compare November 27, 2025 12:37
@Isotr0py
Copy link
Member

Isotr0py commented Nov 27, 2025

Can you update the GGUF multimodal test to align HF format gemma3 test's input for validation? I observed it failed with custom attention mask before:

    "gemma3": VLMTestInfo(
        models=["google/gemma-3-4b-it"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
        single_image_prompts=IMAGE_ASSETS.prompts(
            {
                "stop_sign": "<start_of_image>What's the content in the center of the image?",  # noqa: E501
                "cherry_blossom": "<start_of_image>What is the season?",
            }
        ),
        multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.",  # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForImageTextToText,
        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
        num_logprobs=10,
        image_size_factors=[(0.25, 0.5, 1.0)],
    ),

@lucianommartins lucianommartins force-pushed the main branch 2 times, most recently from f664627 to cb81c6a Compare November 27, 2025 19:01
@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Nov 27, 2025
@lucianommartins
Copy link
Contributor Author

Can you update the GGUF multimodal test to align HF format gemma3 test's input for validation? I observed it failed with custom attention mask before:

    "gemma3": VLMTestInfo(
        models=["google/gemma-3-4b-it"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
        single_image_prompts=IMAGE_ASSETS.prompts(
            {
                "stop_sign": "<start_of_image>What's the content in the center of the image?",  # noqa: E501
                "cherry_blossom": "<start_of_image>What is the season?",
            }
        ),
        multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.",  # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForImageTextToText,
        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
        num_logprobs=10,
        image_size_factors=[(0.25, 0.5, 1.0)],
    ),

hey @Isotr0py -

I've added standardized GGUF multimodal tests following the test_common.py pattern.

Changes in test_multimodal_gguf.py:

  1. Updated GGUFMMTestConfig with mm_processor_kwargs support for per-model configuration
  2. Added two test configurations: GEMMA3_CONFIG: Regular multimodal with Google's QAT Q4_0 GGUF and GEMMA3_CONFIG_PAN_AND_SCAN: Pan-and-scan with Unsloth's unquantized BF16 GGUF
  3. Aligned prompts with the existing gemma3 entry in test_common.py:
  • stop_sign: "What's the content in the center of the image?"
  • cherry_blossom: "What is the season?"
  1. Added @create_new_process_for_each_test() decorator to ensure proper V1 engine cleanup between model loads

The test compares GGUF vLLM output against HF safetensors vLLM output using check_logprobs_close(), following the same pattern as tests/models/quantization/test_gguf.py.

@lucianommartins
Copy link
Contributor Author

hey @Isotr0py,

can you re-run the [buildkite/ci/pr/multi-modal-models-test-standard](https://buildkite.com/vllm/ci/builds/40961#019ac6b2-891a-416a-bbc9-28fe513118e3) test? it failed with an odd behavior (like if it was cancelled):

[2025-11-27T20:45:16Z] INFO 11-27 12:45:16 [llm.py:346] Supported tasks: ['generate']
Adding requests:   0% 0/2 [00:00<?, ?it/s]# Received cancellation signal, interrupting
[2025-11-27T22:31:03Z] 🚨 Error: The command exited with status -1

@Isotr0py
Copy link
Member

The test compares GGUF vLLM output against HF safetensors vLLM output using check_logprobs_close(), following the same pattern as tests/models/quantization/test_gguf.py.

Hmmm, we need to compare against HF safetensors HF output instead of vLLM output.

@lucianommartins
Copy link
Contributor Author

hey @Isotr0py -do you mean running against HF models with HFRunner instead of the vllm_runner?

just double checking because I modeled the test after the existing tests/models/quantization/test_gguf.py, which also compares GGUF (vLLM) against HF safetensors (vLLM) rather than native HuggingFace (HfRunner).

Should I update the test_multimodal_gguf.py to use a reference HF model with HfRunner for consistency with test_common.py, or is the test_gguf.py pattern acceptable for GGUF-specific tests?

@Isotr0py
Copy link
Member

do you mean running against HF models with HFRunner instead of the vllm_runner?

Yes, I meant updating test_multimodal_gguf.py to test against hf runner, especially image_size_factors=[(0.25, 0.5, 1.0)], condition with max_num_seqs=2 case

@lucianommartins
Copy link
Contributor Author

done, @Isotr0py - now the whole test_multimodal_gguf.py tests the GGUF models against HF ones using HFRunner.

@lucianommartins
Copy link
Contributor Author

@Isotr0py - do you know the memory limits of the containers running the tests? my multimodal test failed with segment fault when trying to load the gemma-3-4b-it for HFRunner:

Processed prompts:   0% 0/2 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:  50% 1/2 [00:02<00:02,  2.46s/it, est. speed input: 114.04 toks/s, output: 12.99 toks/s]
Processed prompts: 100% 2/2 [00:02<00:00,  2.46s/it, est. speed input: 220.28 toks/s, output: 25.40 toks/s]
Processed prompts: 100% 2/2 [00:02<00:00,  1.26s/it, est. speed input: 220.28 toks/s, output: 25.40 toks/s]
[2025-11-28T20:58:29Z] [rank0]:[W1128 12:58:29.763342309 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[2025-11-28T20:58:31Z] Fatal Python error: Segmentation fault

with the stack pointing to:

File "/vllm-workspace/tests/models/multimodal/generation/test_multimodal_gguf.py", line 121 in run_multimodal_gguf_test
File "/vllm-workspace/tests/conftest.py", line 355 in _init
File "/vllm-workspace/tests/conftest.py", line 293 in __init__
File "/usr/local/lib/python3.12/dist-packages/transformers/models/gemma3/modeling_gemma3.py", line 998 in __init__

(line 121 from test_multimodal_gguf.py is with hf_runner(...)...

Also it is suspicious that I have a larger GPU on my test environment and this test runned with no errors:

$ pytest tests/models/multimodal/generation/test_multimodal_gguf.py -v -s
========================================================== test session starts ===========================================================
platform linux -- Python 3.13.7, pytest-9.0.1, pluggy-1.6.0 -- /projects/gemma-vllm/fixes/P0.1b_Multimodal_GGUF/bugfix/mm-bugfix-env/bin/python3
cachedir: .pytest_cache
rootdir: /projects/gemma-vllm/fixes/P0.1b_Multimodal_GGUF/bugfix/my-vllm
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 2 items                                                                                                

<...>

============================================================ warnings summary ============================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================== 2 passed, 2 warnings in 525.68s (0:08:45) ================================================

@lucianommartins
Copy link
Contributor Author

just for reference: the entrypoints-integration-test-api-server is not related to my Gemma3 work:

[2025-11-28T20:55:33Z] =================================== FAILURES ===================================
[2025-11-28T20:55:33Z] _________________________ test_async_serving_chat_init _________________________
[2025-11-28T20:55:33Z]
[2025-11-28T20:55:33Z]     def test_async_serving_chat_init():
[2025-11-28T20:55:33Z] >       serving_completion = asyncio.run(_async_serving_chat_init())
[2025-11-28T20:55:33Z]
[2025-11-28T20:55:33Z] entrypoints/openai/test_serving_chat.py:422:
[2025-11-28T20:55:33Z] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[2025-11-28T20:55:33Z] /usr/lib/python3.12/asyncio/runners.py:195: in run
[2025-11-28T20:55:33Z]     return runner.run(main)
[2025-11-28T20:55:33Z] /usr/lib/python3.12/asyncio/runners.py:118: in run
[2025-11-28T20:55:33Z]     return self._loop.run_until_complete(task)
[2025-11-28T20:55:33Z] /usr/lib/python3.12/asyncio/base_events.py:691: in run_until_complete
[2025-11-28T20:55:33Z]     return future.result()
[2025-11-28T20:55:33Z] entrypoints/openai/test_serving_chat.py:409: in _async_serving_chat_init
[2025-11-28T20:55:33Z]     models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
[2025-11-28T20:55:33Z] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[2025-11-28T20:55:33Z]
[2025-11-28T20:55:33Z] self = <vllm.entrypoints.openai.serving_models.OpenAIServingModels object at 0x7fabbd3a7ce0>
[2025-11-28T20:55:33Z] engine_client = MockEngine(model_config=MockModelConfig(logits_processors=None, diff_sampling_param=None, allowed_local_media_path='',...'auto', media_io_kwargs={}), processor=<MagicMock id='140375597446160'>, io_processor=<MagicMock id='140376483496304'>)
[2025-11-28T20:55:33Z] base_model_paths = [BaseModelPath(name='openai-community/gpt2', model_path='openai-community/gpt2'), BaseModelPath(name='gpt2', model_path='gpt2')]
[2025-11-28T20:55:33Z]
[2025-11-28T20:55:33Z]     def __init__(

and the amd-docker-build-image error seems to be related to an AMD ROCm/HIP C++ compilation error:

[35/38] Building HIP object CMakeFiles/_rocm_C.dir/csrc/rocm/attention.hip.o
ninja: build stopped: subcommand failed.

@Isotr0py
Copy link
Member

Isotr0py commented Dec 1, 2025

@lucianommartins I just updated the GGUF multimodal test with image_size_factors=[(0.25, 0.5, 1.0)] to test batching various size images, but seems it still encountered same failure like hf unquantized models before. 😅

BTW, if I disabled custom attention mask with uses_custom_attention_masks = False, the test can pass with converged logprobs.

Can you please check the custom attention mask implementation?

pytest -s -v tests/models/multimodal/generation/test_multimodal_gguf.py -k core_model

use_custom_attention_mask=True:

                    assert logprobs_elem_0 is not None, fail_msg
                    assert logprobs_elem_1 is not None, fail_msg
>                   assert output_id_0 in logprobs_elem_1, fail_msg
E                   AssertionError: Test0:
E                   Matched tokens:     [8291]
E                   hf: "Here's what's in the center of the image:\n\nIt's a traditional Chinese gate or archway. It's red and gold,"   {236789: -0.023282833397388458, 236858: -3.773282766342163, 563: -10.273283004760742, 659: -12.773283004760742, 1144: -16.523283004760742, 19005: -19.148283004760742, 236764: -19.148283004760742, 625: -19.773283004760742, 2721: -20.773283004760742, 528: -20.898283004760742}
E                   gguf:       'HereHere\'s what\'s in the center of the image:\n\nIt\'s a traditional Chinese gate or archway. You can see the characters "'  {8291: Logprob(logprob=-0.7552528381347656, rank=1, decoded_token='Here'), 12794: Logprob(logprob=-1.2552528381347656, rank=2, decoded_token='HERE'), 17756: Logprob(logprob=-2.6927528381347656, rank=3, decoded_token='ERE'), 108: Logprob(logprob=-3.6927528381347656, rank=4, decoded_token='\n\n'), 20463: Logprob(logprob=-3.8177528381347656, rank=5, decoded_token='ARE'), 128522: Logprob(logprob=-3.8802528381347656, rank=6, decoded_token='Bere'), 14219: Logprob(logprob=-4.380252838134766, rank=7, decoded_token='Are'), 236788: Logprob(logprob=-4.755252838134766, rank=8, decoded_token='E'), 2209: Logprob(logprob=-5.130252838134766, rank=9, decoded_token='He'), 239344: Logprob(logprob=-5.255252838134766, rank=10, decoded_token='È')}

tests/models/utils.py:242: AssertionError
==============================================================================
FAILED tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0] - AssertionError: Test0:
==================================================================== 1 failed, 11 warnings in 159.53s (0:02:39) =====================================================================

use_custom_attention_mask=False:

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:150: UserWarning: Test0:
  Matched tokens:       [8291, 236789, 236751, 1144, 236789, 236751, 528, 506, 3988, 529, 506, 2471, 236787, 108]
  hf:   "Here's what's in the center of the image:\n\nIt's a traditional Chinese gate or archway. It's red and gold,"   {1509: -0.061625704169273376, 818: -3.0616257190704346, 236776: -4.5616254806518555, 3810: -6.0616254806518555, 236829: -8.811625480651855, 1018: -10.811625480651855, 3048: -12.311625480651855, 86909: -13.686625480651855, 3834: -13.686625480651855, 902: -13.686625480651855}
  gguf: 'Here\'s what\'s in the center of the image:\n\nThe image shows a traditional Chinese gate, with the characters "中华" (Zhōng Hu'       {818: Logprob(logprob=-1.1401469707489014, rank=1, decoded_token='The'), 1509: Logprob(logprob=-1.1401469707489014, rank=2, decoded_token='It'), 236776: Logprob(logprob=-1.3901469707489014, rank=3, decoded_token='A'), 236829: Logprob(logprob=-2.8901469707489014, rank=4, decoded_token='*'), 3810: Logprob(logprob=-2.8901469707489014, rank=5, decoded_token='There'), 1018: Logprob(logprob=-8.51514720916748, rank=6, decoded_token='**'), 3048: Logprob(logprob=-10.76514720916748, rank=7, decoded_token='You'), 3834: Logprob(logprob=-10.89014720916748, rank=8, decoded_token='At'), 2267: Logprob(logprob=-10.89014720916748, rank=9, decoded_token='An'), 902: Logprob(logprob=-10.89014720916748, rank=10, decoded_token='In')}
    check_logprobs_close(

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:150: UserWarning: Test1:
  Matched tokens:       [8291, 236789, 236751]
  hf:   "Here's a breakdown of the content in the center of the image:\n\n*   **Chinese Archway:** The most prominent feature is a large, ornate"       {496: -0.28813475370407104, 1144: -1.5381348133087158, 506: -3.538134813308716, 614: -5.038134574890137, 1041: -12.288134574890137, 3671: -16.163135528564453, 1217: -17.038135528564453, 25890: -18.663135528564453, 22454: -19.038135528564453, 29141: -19.288135528564453}
  gguf: "Here's what's in the center of the image:\n\n*   **A dark blue SUV** is parked in the street.\n*   **" {1144: Logprob(logprob=-0.217246413230896, rank=1, decoded_token=' what'), 496: Logprob(logprob=-1.717246413230896, rank=2, decoded_token=' a'), 506: Logprob(logprob=-4.4672465324401855, rank=3, decoded_token=' the'), 614: Logprob(logprob=-5.4672465324401855, rank=4, decoded_token=' an'), 1041: Logprob(logprob=-11.967246055603027, rank=5, decoded_token=' my'), 1217: Logprob(logprob=-14.967246055603027, rank=6, decoded_token=' how'), 3671: Logprob(logprob=-18.217247009277344, rank=7, decoded_token=' analysis'), 25890: Logprob(logprob=-19.967247009277344, rank=8, decoded_token=' breakdown'), 20129: Logprob(logprob=-20.217247009277344, rank=9, decoded_token=' roughly'), 22454: Logprob(logprob=-20.467247009277344, rank=10, decoded_token=' describing')}
    check_logprobs_close(

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:150: UserWarning: Test2:
  Matched tokens:       []
  hf:   "Here's a breakdown of the content in the center of the image:\n\n*   **Chinese Gate:** The most prominent feature is a large, ornate Chinese"  {8291: -0.017760176211595535, 117494: -4.517760276794434, 6481: -5.517760276794434, 19058: -6.017760276794434, 818: -8.517760276794434, 22515: -10.517760276794434, 902: -11.267760276794434, 1018: -15.267760276794434, 113106: -15.767760276794434, 100409: -17.642759323120117}
  gguf: "The content in the center of the image is a **stop sign** mounted on a metal pole. It's positioned in front of the Chinese-themed arch"        {818: Logprob(logprob=-0.3230724334716797, rank=1, decoded_token='The'), 8291: Logprob(logprob=-1.3230724334716797, rank=2, decoded_token='Here'), 117494: Logprob(logprob=-5.32307243347168, rank=3, decoded_token='Certainly'), 6481: Logprob(logprob=-6.07307243347168, rank=4, decoded_token='Let'), 19058: Logprob(logprob=-6.82307243347168, rank=5, decoded_token='Okay'), 22515: Logprob(logprob=-6.82307243347168, rank=6, decoded_token='Based'), 902: Logprob(logprob=-7.82307243347168, rank=7, decoded_token='In'), 1018: Logprob(logprob=-13.32307243347168, rank=8, decoded_token='**'), 30988: Logprob(logprob=-13.82307243347168, rank=9, decoded_token='Looking'), 100409: Logprob(logprob=-14.69807243347168, rank=10, decoded_token='Absolutely')}
    check_logprobs_close(

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:150: UserWarning: Test1:
  Matched tokens:       [22515, 580, 506, 2471, 236764, 506, 3409, 563, 5213, 5412, 84750]
  hf:   'Based on the image, the season is **spring**. \n\nThe prominent feature is the blooming cherry blossoms (sakura), which are a classic symbol of spring'        {236743: -0.06157604977488518, 108: -3.0615761280059814, 669: -4.811575889587402, 138: -6.061575889587402, 39799: -6.311575889587402, 5715: -7.561575889587402, 1174: -9.311575889587402, 1599: -11.936575889587402, 139: -12.561575889587402, 1030: -14.061575889587402}
  gguf: 'Based on the image, the season is **spring**. The abundance of pink cherry blossoms (sakura) is a strong indicator of this time of year.'      {669: Logprob(logprob=-0.7382382750511169, rank=1, decoded_token=' The'), 236743: Logprob(logprob=-0.7382382750511169, rank=2, decoded_token=' '), 39799: Logprob(logprob=-3.2382383346557617, rank=3, decoded_token=' Specifically'), 108: Logprob(logprob=-5.488238334655762, rank=4, decoded_token='\n\n'), 138: Logprob(logprob=-7.488238334655762, rank=5, decoded_token='  '), 1174: Logprob(logprob=-9.488238334655762, rank=6, decoded_token=' This'), 1030: Logprob(logprob=-9.988238334655762, rank=7, decoded_token=' It'), 4643: Logprob(logprob=-10.738238334655762, rank=8, decoded_token=' More'), 1599: Logprob(logprob=-10.988238334655762, rank=9, decoded_token=' You'), 1191: Logprob(logprob=-14.238238334655762, rank=10, decoded_token=' We')}
    check_logprobs_close(

tests/models/multimodal/generation/test_multimodal_gguf.py::test_models[10-32-bfloat16-model0]
  /home/mozf/develop-projects/vllm/tests/models/multimodal/generation/test_multimodal_gguf.py:150: UserWarning: Test2:
  Matched tokens:       [22515, 580, 506, 2471, 236764, 506, 3409, 563, 5213, 5412, 84750, 236743, 108, 818]
  hf:   'Based on the image, the season is **spring**. \n\nThe prominent feature is the blooming cherry blossoms (sakura), which are a quintessential symbol of spring' {19942: -0.43055424094200134, 6219: -1.4305542707443237, 28239: -2.680554151535034, 26444: -4.180554389953613, 1881: -5.180554389953613, 1346: -5.180554389953613, 31295: -5.430554389953613, 19659: -5.680554389953613, 2471: -5.930554389953613, 2307: -7.180554389953613}
  gguf: 'Based on the image, the season is **spring**. \n\nThe presence of blooming cherry blossoms (sakura) is a strong indicator of spring.'  {6219: Logprob(logprob=-1.2486515045166016, rank=2, decoded_token=' presence'), 19942: Logprob(logprob=-1.2486515045166016, rank=1, decoded_token=' prominent'), 26444: Logprob(logprob=-1.4986515045166016, rank=3, decoded_token=' abundance'), 28239: Logprob(logprob=-2.4986515045166016, rank=4, decoded_token=' vibrant'), 2471: Logprob(logprob=-2.4986515045166016, rank=5, decoded_token=' image'), 31295: Logprob(logprob=-4.248651504516602, rank=6, decoded_token=' abundant'), 19659: Logprob(logprob=-4.498651504516602, rank=7, decoded_token=' dominant'), 1881: Logprob(logprob=-5.248651504516602, rank=8, decoded_token=' prof'), 68852: Logprob(logprob=-6.248651504516602, rank=9, decoded_token=' blossoms'), 1346: Logprob(logprob=-6.748651504516602, rank=10, decoded_token=' most')}
    check_logprobs_close(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================== 1 passed, 16 warnings in 161.59s (0:02:41) =====================================================================

@lucianommartins
Copy link
Contributor Author

Hi @Isotr0py - there was a bug on the crops manipulation when dealing with batched requests (it is how your scenario caught the issue).

It is fixed now and the test_multimodal_gguf.py is working fine for both tests.

@lucianommartins
Copy link
Contributor Author

fyi @Isotr0py - ci/pr/docker-build-image is failing with Failed to fetch precompiled wheel metadata for variant cu129: HTTP Error 404: Not Found. is it a known issue?

Restores custom attention mask generation for Gemma3 GGUF multimodal models
that was partially reverted in vllm-project#28995. Implements robust GGUF-only guards to
ensure the feature only applies to GGUF models and does not affect HF models.

Changes:
- Add uses_custom_attention_masks() utility with GGUF file format check
- Add uses_custom_attention_masks property to ModelConfig
- Initialize uses_custom_attention_masks in GPUModelRunner
- Restore generate_attention_masks() method to Gemma3ForConditionalGeneration
- Implement 3-layer defense-in-depth guard mechanism

The implementation uses check_gguf_file() to guarantee that custom attention
mask logic only triggers for GGUF files, preventing the issue that caused
the original revert where HF models incorrectly triggered the custom logic.

Tested with GGUF models (1B, 4B, 270M) for both text-only and multimodal
inference. HF model compatibility verified via pytest multimodal test suite.

Signed-off-by: Luciano Martins <[email protected]>
Fixes seven critical issues in GGUF multimodal inference:

1. Attention scaling parameter bug (gemma3.py):
   - Fix F.scaled_dot_product_attention to use named parameters
   - Changed positional args to attn_mask=attn_mask, scale=self.scaling
   - Prevents incorrect dropout application (was 6.25% instead of 0%)

2. Custom attention mask persistence (gpu_model_runner.py):
   - Store custom_model_kwargs after mask generation
   - Merge custom_model_kwargs in _dummy_run
   - Prevents loss of attention masks during CUDA graph re-initialization

3. Pan-and-scan attention pattern (gemma3_mm.py):
   - Detect pan-and-scan mode via multimodal_config.do_pan_and_scan
   - Prevents crop isolation artifacts in sequential processing

4. GGUF unquantized weight loading (weight_utils.py):
   - Add proper dtype conversion for BF16/F16/F32 stored as uint8
   - Handle byte-to-dtype conversion (BF16: 2 bytes, F16: 2 bytes, F32: 4 bytes)
   - Add fallback handling for unexpected dtype/type combinations
   - Fixes weight loading for unquantized GGUF multimodal projector weights

5. Fix GGUF pan-and-scan attention and token handling:
   - Unified bidirectional attention for all image tokens (fixes random output)
   - Fixed pan-and-scan token splits via position discontinuity detection
   - Corrected num_image_patches to reflect actual crop count from pixel_values

6. Add GGUF multimodal integration tests (test_multimodal_gguf.py):
   - Compare GGUF (vLLM) output against native HuggingFace (HfRunner)
   - Test regular multimodal inference with QAT Q4_0 GGUF
   - Test pan-and-scan multimodal with unquantized BF16 GGUF
   - Use @create_new_process_for_each_test for subprocess isolation

7. Robust sequence detection for batched multimodal (gemma3_mm.py):
   - Replace  heuristic with
   - Pass query_start_loc from gpu_model_runner.py to generate_attention_masks()
   - Fixes multi-scale batching failures where num_computed_tokens > 0
   - Ensures correct attention mask generation for all batch configurations

Signed-off-by: Luciano Martins <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants