Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docs/models/pooling_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ shown in the table below.
| Architecture | `--convert` | Supported pooling tasks |
|-------------------------------------------------|-------------|---------------------------------------|
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` |
| `*ForRewardModeling`, `*RewardModel` | `embed` | `token_embed`, `embed` |
| `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` |
| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` |

!!! tip
You can explicitly set `--convert <type>` to specify how to convert the model.
Expand Down Expand Up @@ -70,7 +70,6 @@ the pooler assigned to each task has the following attributes by default:

| Task | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |

Expand Down Expand Up @@ -318,3 +317,10 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
### Remove softmax from PoolingParams

We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.

### as_reward_model

Pooling models now default support all pooling, you can use it without any settings.

- Extracting hidden states prefers using `token_embed` task.
- Reward models prefers using `token_classify` task.
9 changes: 1 addition & 8 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -581,16 +581,9 @@ These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |

<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.

If your model is not in the above list, we will try to automatically convert the model using
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.

!!! important
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
Expand Down
71 changes: 71 additions & 0 deletions examples/pooling/token_embed/jina_embeddings_v4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import LLM
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image

# Initialize model
model = LLM(
model="jinaai/jina-embeddings-v4-vllm-text-matching",
runner="pooling",
max_model_len=1024,
gpu_memory_utilization=0.8,
)

# Create text prompts
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")

text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")

# Create image prompt
image = fetch_image(
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
)
image_prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501
multi_modal_data={"image": image},
)

# Encode all prompts
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")


def get_embeddings(outputs):
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

embeddings = []
for output in outputs:
if VISION_START_TOKEN_ID in output.prompt_token_ids:
# Gather only vision tokens
img_start_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
)[0][0]
img_end_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
)[0][0]
embeddings_tensor = output.outputs.data.detach().clone()[
img_start_pos : img_end_pos + 1
]
else:
# Use all tokens for text-only prompts
embeddings_tensor = output.outputs.data.detach().clone()

# Pool and normalize embeddings
pooled_output = (
embeddings_tensor.sum(dim=0, dtype=torch.float32)
/ embeddings_tensor.shape[0]
)
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
return embeddings


embeddings = get_embeddings(outputs)

for embedding in embeddings:
print(embedding.shape)
2 changes: 0 additions & 2 deletions tests/models/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from vllm.model_executor.models.adapters import (
as_embedding_model,
as_reward_model,
as_seq_cls_model,
)
from vllm.model_executor.models.registry import (
Expand Down Expand Up @@ -46,7 +45,6 @@ def test_registry_imports(model_arch):
# All vLLM models should be convertible to a pooling model
assert is_pooling_model(as_seq_cls_model(model_cls))
assert is_pooling_model(as_embedding_model(model_cls))
assert is_pooling_model(as_reward_model(model_cls))

if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_update_config():
("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
("openai/whisper-small", "generate", "none", "transcription"),
],
)
Expand Down
10 changes: 7 additions & 3 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,11 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
if task == "classify":
return "classify"
if task == "reward":
return "reward"
logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score":
new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed"
Expand Down Expand Up @@ -1899,8 +1903,8 @@ def get_served_model_name(model: str, served_model_name: str | list[str] | None)
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),
("ClassificationModel", ("pooling", "classify")),
("ForRewardModeling", ("pooling", "reward")),
("RewardModel", ("pooling", "reward")),
("ForRewardModeling", ("pooling", "embed")),
("RewardModel", ("pooling", "embed")),
# Let other `*Model`s take priority
("Model", ("pooling", "embed")),
]
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import (
as_embedding_model,
as_reward_model,
as_seq_cls_model,
try_create_mm_pooling_model_cls,
)
Expand Down Expand Up @@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward":
logger.debug_once("Converting to reward model.")
model_cls = as_reward_model(model_cls)
else:
assert_never(convert_type)

Expand Down
38 changes: 0 additions & 38 deletions vllm/model_executor/models/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,44 +346,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return ModelForSequenceClassification # type: ignore


def as_reward_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support reward modeling.

By default, we return the hidden states of each token directly.

Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing reward models
if is_pooling_model(cls):
return cls

# Lazy import
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

from .interfaces_base import default_pooling_type

@default_pooling_type("ALL")
class ModelForReward(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
)
}
)

ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")

return ModelForReward # type: ignore


class SequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
Expand Down