
Conversation

@piood
Contributor

@piood piood commented Oct 27, 2025

Purpose

Extend SigLIP model support to SigLIP2, building upon the existing SigLIP embedding architecture.

  • Added handling for model configs missing the architectures field by automatically inferring the architecture from MODEL_MAPPING_NAMES
  • Dynamically select dummy_token_id by finding the first non-special token ID, replacing the previous hardcoded value; this ensures compatibility with both SigLIP and SigLIP2 tokenizers across different versions (see the sketch after this list)
  • Maintains the existing SigLIP embedding architecture: in get_input_embeddings, text inputs only go through token_embedding, while image inputs go through the full vision encoder
    This PR extends multimodal embedding capabilities to SigLIP2 models, which are widely used for vision-language tasks.
    It resolves [New Model]: Google SigLip 2 #13663 and [Feature]: Support serving of CLIP/SigLIP embeddings #25581, and builds upon SigLIP embedding support ([Model] Siglip Embedding Support #27324) to add SigLIP2 support.
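
A minimal sketch of the dummy-token selection described above (an illustration, not the exact vLLM code; the model name is just an example):

# Choose a dummy token ID that is valid for both SigLIP and SigLIP2 tokenizers
# by taking the first non-special token instead of a hardcoded constant.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
special_ids = set(tokenizer.all_special_ids)
dummy_token_id = next(
    token_id for token_id in range(tokenizer.vocab_size) if token_id not in special_ids
)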

Test Plan

  • Added google/siglip2-base-patch16-224 tests in tests/models/multimodal/pooling/test_siglip.py
  • Updated model registry to include SigLIP2 embedding support
  • Updated supported_models.md documentation with SigLIP2 model examples
  • Verified with local test runs

Test Result

  • All SigLIP2-specific tests pass
  • Model registry correctly recognizes SigLIP2 embedding models
  • Both text and image embedding generation work as expected
  • Config processing logic correctly handles models without architectures field

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@piood piood requested a review from noooop as a code owner October 27, 2025 10:51
@mergify

mergify bot commented Oct 27, 2025

Documentation preview: https://vllm--27566.org.readthedocs.build/en/27566/

@mergify mergify bot added documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) labels Oct 27, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends vLLM's capabilities to support SigLIP2 models, building upon the existing SigLIP architecture. The changes include updating the dummy token ID for tokenizer compatibility, adding SigLIP2 to the test suite and documentation, and enhancing the model configuration loading to automatically infer model architecture when it's not explicitly defined. My review focuses on the robustness of this new configuration logic. I've identified one high-severity issue where the check for a missing architecture field could be more robust to handle cases like an empty list, which would otherwise cause model loading to fail.

@piood piood mentioned this pull request Oct 27, 2025
1 task
piood added 2 commits October 27, 2025 11:03
Signed-off-by: piood <[email protected]>
Signed-off-by: piood <[email protected]>
@piood
Contributor Author

piood commented Oct 27, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends the SigLIP model support to SigLIP2, including handling model configurations and updating the dummy token ID. The changes involve modifications to the documentation, test files, and model configuration files. The review focuses on ensuring the correctness of the dummy token ID and the robustness of the configuration handling.

Comment on lines 623 to 627
if not config.architectures:
    if config.model_type not in MODEL_MAPPING_NAMES:
        raise ValueError(f"Model type {config.model_type} not supported")
    model_type = MODEL_MAPPING_NAMES[config.model_type]
    config.update({"architectures": [model_type]})

Severity: high

The added logic to automatically determine the architecture for models without an explicit architectures field is good for robustness. However, raising a ValueError if the model_type is not found in MODEL_MAPPING_NAMES might be too strict. A more graceful fallback could involve attempting to load the model with trust_remote_code=True or issuing a warning and proceeding with a default architecture. This could prevent the system from failing completely when encountering a new or less common model type. Also, consider adding a log message to indicate when this automatic architecture detection is being used, which can help with debugging and understanding the system's behavior.

Suggested change
Original:
if not config.architectures:
    if config.model_type not in MODEL_MAPPING_NAMES:
        raise ValueError(f"Model type {config.model_type} not supported")
    model_type = MODEL_MAPPING_NAMES[config.model_type]
    config.update({"architectures": [model_type]})
Suggested:
if not config.architectures:
    if config.model_type not in MODEL_MAPPING_NAMES:
        logger.warning(f"Model type {config.model_type} not found in MODEL_MAPPING_NAMES. Attempting to proceed without explicit architecture.")
        # Optionally, try loading with trust_remote_code or assign a default architecture here
        config.update({"architectures": ["AutoModel"]})  # Example: setting a default architecture
    else:
        model_type = MODEL_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [model_type]})

@piood piood changed the title from [Model] Siglip Model Support to [Model] Siglip2 Model Support Oct 27, 2025
Signed-off-by: piood <[email protected]>
Signed-off-by: piood <[email protected]>
Member

@DarkLight1337 DarkLight1337 left a comment


LGTM, thanks

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 27, 2025 11:32
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 27, 2025
@vllm-bot vllm-bot merged commit 4f882be into vllm-project:main Oct 27, 2025
54 of 57 checks passed
@ywang96
Member

ywang96 commented Oct 27, 2025

FYI this is breaking entrypoints/test_chat_utils.py::test_resolve_content_format_fallbacks[deepseek-ai/deepseek-vl2-tiny-string] on main

@njhill
Member

njhill commented Oct 27, 2025

Curious why this was merged with CI failures?

@DarkLight1337
Member

Oh, I didn't think this PR was related to chat utils, so I assumed the failure was unrelated

@njhill
Member

njhill commented Oct 28, 2025

Really, we should never assume this no matter how unlikely it seems, unless the same failure has already been seen on main. This keeps happening: breakage due to an incorrect assumption, which then has a much wider blast radius.

@DarkLight1337
Member

I agree in principle; unfortunately the entrypoints tests have been quite flaky these past few weeks, so it's easy to accidentally miss these true positives...

@DarkLight1337
Member

Does the HF implementation of the model support dynamic image size? In any case I doubt that has anything to do with the issue you faced since the issue occurs even with text only prompt.

@sleepwalker2017
Contributor

Does the HF implementation of the model support dynamic image size? In any case I doubt that has anything to do with the issue you faced since the issue occurs even with text only prompt.

I'm not talking about the bug. I'm using Siglip2, and the output of Processor contains

(Pdb) p output
{'pixel_values': tensor([[[ 1., -1., -1.,  ...,  1., -1., -1.],
         [ 1., -1., -1.,  ...,  1., -1., -1.],
         [ 1., -1., -1.,  ...,  1., -1., -1.],
         ...,
         [ 1., -1., -1.,  ...,  1., -1., -1.],
         [ 1., -1., -1.,  ...,  1., -1., -1.],
         [ 1., -1., -1.,  ...,  1., -1., -1.]]]),
 'pixel_attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32),  # 256 ones (16 x 16 patches)
 'spatial_shapes': tensor([[16, 16]]),
 'input_ids': tensor([[0, 0, 0,  ..., 0, 0, 0]])}  # 64 zeros

But when it comes to inference, the model only gets pixel_values; the others are discarded.

@DarkLight1337
Member

It looks like on transformers side, SigLIP2 supports resizing image embeddings whereas SigLIP does not

ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
@piood
Contributor Author

piood commented Nov 9, 2025

I have pushed #28365 to fix the SigLIP batch text output error. @DarkLight1337 @sleepwalker2017

rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
@apoorv-stackav

Hi @piood @DarkLight1337,
I was recently working with the newly added SigLIP2 model google/siglip2-base-patch16-224 and found the following shortcomings:

  • The text embeddings produced by google/siglip2-base-patch16-224 do not correspond to padding='max_length'. Because of this, the model can't be used for downstream tasks like zero-shot classification. Can we somehow force padding='max_length' over the entire sequence length (64) as an exception for this model?
  • I tried to load siglip2-large-patch16-256, but this results in output embeddings of shape 768, while the Hugging Face model returns a 1024-dimensional embedding. Where can I look for the root cause of this?
  • I also tried to load google/siglip2-so400m-patch14-224 model, but am seeing this error :
    (EngineCore_DP0 pid=3066502) ERROR 11-15 00:28:55 [core.py:855] raise NotImplementedError(
    (EngineCore_DP0 pid=3066502) ERROR 11-15 00:28:55 [core.py:855] NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for TritonAttentionImpl

Any tips on how we can add these variants?

@piood
Contributor Author

piood commented Nov 16, 2025

@apoorv-stackav For the second and third problems, I can't reproduce them locally: in my environment the embedding dimension is correct for siglip2-large-patch16-256, and running siglip2-so400m-patch14-224 doesn't raise an error. Please check your environment.

@piood
Contributor Author

piood commented Nov 16, 2025

The max_length parameter is implemented by the transformers tokenizer, not the model itself. Padding is applied by the tokenizer before the token IDs are passed to the model. Since vLLM uses the transformers tokenizer to get the token IDs, padding should be configured at the tokenizer level.
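
To make this concrete, a hedged sketch of the tokenizer-level padding being described (model name and max_length value are illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
batch = tok(
    ["a photo of a cat"],
    padding="max_length",  # pad every sequence up to max_length
    max_length=64,
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([1, 64]); padding happens before the model sees the IDs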

Is it easy to implement max_length in vLLM? @DarkLight1337

@DarkLight1337
Member

Indeed the padding should be done by the tokenizer. You should be able to do this by passing mm_processor_kwargs?
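
A hedged sketch of that suggestion (whether SigLIP2's processor honors these particular keys is an assumption, and note the limitation about text-only prompts raised later in this thread):

from vllm import LLM

# mm_processor_kwargs forwards extra kwargs to the HF processor for multimodal inputs;
# the keys below are illustrative.
llm = LLM(
    model="google/siglip2-base-patch16-224",
    mm_processor_kwargs={"padding": "max_length", "max_length": 64},
)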

@piood
Contributor Author

piood commented Nov 17, 2025

@apoorv-stackav, I believe the SigLIP model only uses the hidden state of the last token for the sequence embedding. Therefore, padding shouldn't affect the final result. Please correct me if I'm wrong.

@zac-wang-nv

@piood, thanks for the great work adding support for SigLIP2. I tested it with google/siglip2-so400m-patch16-naflex and got the following error:
AttributeError: 'Siglip2Processor' object has no attribute '_get_num_multimodal_tokens'
Is that expected? How can I get it fixed?

@DarkLight1337
Member

Which version of transformers are you using? You might have to update it.

@DarkLight1337
Member

Also this PR doesn't cover the NaViT variant

@piood
Contributor Author

piood commented Nov 27, 2025

@zac-wang-nv Siglip2NaViT is implemented in vllm/model_executor/models/siglip2navit.py, but it currently only provides vision embedding; text embedding still needs to be supported. I will try to add Siglip2NaViT embedding support over the weekend.

@FlyingYanglu

FlyingYanglu commented Nov 28, 2025

Hi, @piood @DarkLight1337
I'm new to vllm and I’m trying to serve google/siglip2-giant-opt-patch16-384 for image embedding.
I noticed that image-only embedding requests always come back with 1152-dimensional vectors, even though the vision tower and HF reference implementation produce 1536-dimensional embeddings.

Model config

{
  "initializer_factor": 1.0,
  "model_type": "siglip",
  "text_config": {
    "hidden_size": 1152,
    "projection_size": 1536,
    ...
  },
  "vision_config": {
    "hidden_size": 1536,
    ...
  }
}

So the text encoder projects 1152 → 1536, while the vision encoder already operates at 1536.

Serving command

CUDA_VISIBLE_DEVICES=1,2 vllm serve \
  /data1/common_models/hub/models--google--siglip2-giant-opt-patch16-384/snapshots/a713301b217d38485fb2204c808367d10bc3cc40/ \
  --runner pooling \
  --tensor-parallel-size 2 \
  --gpu-memory-utilization 0.7 \
  --enforce-eager \
  --chat-template tmp/siglip_chat_template.jinja \
  --mm-processor-cache-gb 0 \
  --disable-mm-preprocessor-cache \
  --port 9120

Request

POST http://localhost:9120/v1/embeddings
{
  "model": "<resolved model id>",
  "messages": [
    {
      "role": "user",
      "content": [
        { "type": "image_url", "image_url": { "url": "<image>" } }
      ]
    }
  ]
}

Behavior

  • The server returns a 1152-length embedding array.
  • Debug logs show the vision tower outputs torch.Size([1, 1, 1536]), but by the time SiglipEmbeddingModel.forward() runs, inputs_embeds has been coerced to torch.Size([1, 1152]).
  • That suggests _merge_multimodal_embeddings (or upstream prompt assembly) truncates the image embedding to the text hidden size before the model forward, so the pooler only ever sees 1152 dims.

Question

Is there a configuration/task flag I’m missing to bypass the text prompt path for image-only requests, or is this indeed a bug in the merge logic for multimodal embeddings? I’d love guidance on whether I should pad the text embeddings to 1536 before merging or if there’s an existing vision-only embedding pathway I can use.

Thanks!

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
@jtruetsch

Same problem here (v0.11.2): google/siglip2-giant-opt-patch16-384 returns text and image embeddings with mismatched dimensions. 1536-dimensional text embeddings (expected size) but 1152-dimensional image embeddings.

@piood
Contributor Author

piood commented Nov 30, 2025

@FlyingYanglu @jtruetsch There are some tricky issues here. For siglip2-giant-opt-patch16-384, text_config.hidden_size is 1152, text_config.projection_size is 1536, and vision_config.hidden_size is 1536.
However, vLLM's core gpu_model_runner.py allocates inputs_embeds based on hidden_size (for multimodal models, text_config.hidden_size) and, when there are multimodal embeds, fills them into the text embeds at the placeholder positions. So the current SigLIP implementation is correct, but during scheduling the embeddings get truncated to hidden_size=1152. This happens because gpu_model_runner.py assumes the final output hidden size equals the model config's hidden_size; for siglip2-giant-opt-patch16-384 that assumption does not hold, which causes the issue.

Since fixing this might require many changes to gpu_model_runner.py, @DarkLight1337 do you have any suggestions?

Maybe we can keep input_embeds as is, but introduce an output_embeds to handle cases where the final embedding size differs from config.hidden_size due to projection_size.
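
A toy illustration of the size mismatch described above (shapes only; variable names are hypothetical, not the actual gpu_model_runner.py code):

import torch

text_hidden_size = 1152    # text_config.hidden_size
vision_hidden_size = 1536  # vision_config.hidden_size == text_config.projection_size

# The runner pre-allocates its embedding buffer with the text hidden size...
inputs_embeds = torch.zeros(64, text_hidden_size)
# ...while the vision tower produces 1536-dim embeddings for the image placeholders,
image_embeds = torch.randn(16, vision_hidden_size)
# so scattering image_embeds into inputs_embeds must either fail or truncate to 1152 dims.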

@DarkLight1337
Member

Let me open a PR to facilitate that

@DarkLight1337
Member

Opened #29741

@FlyingYanglu

Thank you so much for the help! @piood @DarkLight1337

@FlyingYanglu

Indeed the padding should be done by the tokenizer. You should be able to do this by passing mm_processor_kwargs?

Hi team, @piood @DarkLight1337 — sorry for the trouble again.

It looks like padding-to-max_length is the officially expected behavior for SigLIP2 text preprocessing (as noted in the HF docs), so I’m trying to keep my tokenizer/preprocessor aligned with that.

However, I found a couple of issues:

  • mm_processor_kwargs only applies when multi_modal_data is present, so text-only prompts never receive those overrides.
  • There is a tokenization_kwargs parameter on LLM.encode(...), but in the current implementation it’s never used — the code reconstructs an empty dict before calling the tokenizer.

Because of this, the only reliable way to match HF behavior right now is to pass in pre-tokenized IDs manually. It would be super helpful if tokenization_kwargs were actually plumbed through so we can set things like padding="max_length" and truncation=True, and keep SigLIP’s text path consistent with the expected HF preprocessing.
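
For reference, a sketch of the call this comment would like to work (the tokenization_kwargs parameter name comes from the comment above; treat the exact signature and defaults as assumptions):

from vllm import LLM

llm = LLM(model="google/siglip2-base-patch16-224")
outputs = llm.encode(
    ["a photo of a cat"],
    tokenization_kwargs={"padding": "max_length", "max_length": 64, "truncation": True},
)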

@DarkLight1337
Member

@noooop can you help with this?

@piood
Contributor Author

piood commented Dec 1, 2025

Opened #29794 to support tokenization_kwargs override.

@piood piood mentioned this pull request Dec 2, 2025
4 tasks

Labels

documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed


Development

Successfully merging this pull request may close these issues.

[Feature]: Support serving of CLIP/SigLIP embeddings [New Model]: Google SigLip 2