24 changes: 15 additions & 9 deletions docs/features/custom_logitsprocs.md
@@ -180,7 +180,10 @@ RequestLogitsProcessor = Union[

While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient way to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, you may need to modify it or add a wrapper layer to comply with the interface specification above.

You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped):
Contributor:
It is great that you updated the documentation to reflect these changes.

Correct me if I'm wrong, but it appears that this PR currently only updates the documentation for the special case of adapting a request-level logits processor.

I think we will probably also want to update the documentation for

  1. The "Programming Model" section of the logits processors design docs https://docs.vllm.ai/en/latest/design/logits_processors.html#logits-processor-programming-model

  2. Other sections of the custom logits processor design docs, specifically "Creating a custom logits processor" (https://docs.vllm.ai/en/latest/features/custom_logitsprocs.html#creating-a-custom-logits-processor), "Passing Custom Argument to a Custom Logits Processor" (https://docs.vllm.ai/en/latest/features/custom_logitsprocs.html#passing-custom-argument-to-a-custom-logits-processor), "Example custom logits processor implementation" (https://docs.vllm.ai/en/latest/features/custom_logitsprocs.html#example-custom-logits-processor-implementation)

- Override `AdapterLogitsProcessor.validate_params(cls, params)` to validate a request's sampling parameters.
- Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
- Override `AdapterLogitsProcessor.new_req_logits_processor(self, params)` to create a new request-level logits processor instance from a `SamplingParams` instance:

??? code "Example of Wrapping a Request-Level Logits Processor"

@@ -223,6 +226,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} is not int"
)

def new_req_logits_processor(
self,
params: SamplingParams,
@@ -240,18 +253,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)
```
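
For context, here is a hypothetical offline usage of such a wrapped processor. The class name `WrappedPerReqLogitsProcessor` is taken from the example files in this PR; the model name and token id are placeholders, so treat this as a sketch rather than part of the documented example:

```python
from vllm import LLM, SamplingParams

# Pass the wrapper class at engine construction; per-request wrapping then
# happens through new_req_logits_processor() (and, with this PR,
# validate_params() can reject bad sampling parameters up front).
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    logits_processors=[WrappedPerReqLogitsProcessor],
)

params = SamplingParams(
    temperature=0.0,
    extra_args={"target_token": 42},  # illustrative token id
)
outputs = llm.generate(["Hello, my name is"], params)
```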

15 changes: 8 additions & 7 deletions examples/offline_inference/logits_processor/custom_req.py
@@ -79,6 +79,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

def new_req_logits_processor(
self,
params: SamplingParams,
@@ -101,13 +109,6 @@ def new_req_logits_processor(
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


15 changes: 8 additions & 7 deletions examples/offline_inference/logits_processor/custom_req_init.py
@@ -86,6 +86,14 @@ def __init__(
def is_argmax_invariant(self) -> bool:
return False

@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)

def new_req_logits_processor(
self,
params: SamplingParams,
@@ -113,13 +121,6 @@ def new_req_logits_processor(
is None
):
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/logits_processors.py
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from functools import lru_cache, partial

import torch

from vllm.sampling_params import LogitsProcessor
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
_load_custom_logitsprocs,
)


class AllowedTokenIdsLogitsProcessor:
@@ -90,3 +94,15 @@ def get_logits_processors(
)

return logits_processors


def validate_logits_processors_parameters(
logits_processors: Sequence[str | LogitsProcessor] | None,
sampling_params: SamplingParams,
):
if logits_processors is None:
return None

for logits_procs in _load_custom_logitsprocs(logits_processors): # type: ignore[arg-type]
Member:
We shouldn't attempt to load the logits processors on every request.

We should also be careful about loading them at all in the front-end process, since I think we currently only do it in the worker processes.

We will probably have to do this to support request validation; it's hopefully OK if we only load the class without instantiating it (so we hopefully won't be allocating tensors, for example), but this could still cause CUDA to be initialized, which we don't want to do in this process.

Member Author:
> but this could still cause CUDA to be initialized, which we don't want to do in this process.

IIUC, this usually occurs when the file contains code that causes CUDA initialization outside the logits processor implementation, like this?

```python
class CustomLogitsProcessor:
    ...

# some code that initializes CUDA outside the processor implementation
torch._C._cuda_init()
```

Given that we're doing request validation in the front-end process, I think we can temporarily hide all GPUs when loading the custom logits processor class, so that any attempt to initialize CUDA triggers `RuntimeError: No CUDA GPUs are available` and we can refuse the request with the invalid processor (see the sketch after this diff). WDYT?

Member:
@Isotr0py that sounds like a good idea

if isinstance(logits_procs, AdapterLogitsProcessor):
logits_procs.validate_params(sampling_params)
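
Following up on the thread above, here is a minimal sketch of the proposed GPU-hiding approach, assuming the processor class is referenced by a `module:ClassName` path (the helper names and path format are hypothetical, not part of this PR):

```python
import contextlib
import importlib
import os


@contextlib.contextmanager
def hide_gpus():
    """Temporarily hide all GPUs so a stray CUDA init during import
    fails fast instead of initializing CUDA in the front-end process."""
    saved = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    finally:
        if saved is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = saved


def load_logits_processor_class(qualname: str):
    """Import a logits processor class without instantiating it."""
    module_name, _, class_name = qualname.partition(":")
    with hide_gpus():
        module = importlib.import_module(module_name)
    return getattr(module, class_name)
```

Note this only helps if CUDA has not already been initialized in the process; once `torch.cuda` is initialized, changing `CUDA_VISIBLE_DEVICES` has no effect.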
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -769,10 +769,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
description="KVTransfer parameters used for disaggregated serving.",
)

vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
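
For context, a hypothetical client-side request showing how a list-valued `vllm_xargs` entry (now permitted by the widened type above) reaches `SamplingParams.extra_args`; the model name and argument values are illustrative only:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "Hello"}],
    # Non-OpenAI fields go through extra_body; vLLM forwards vllm_xargs
    # entries into SamplingParams.extra_args.
    extra_body={
        "vllm_xargs": {
            "ngram_size": 3,
            "window_size": 100,
            "whitelist_token_ids": [128821, 128822],  # list value now allowed
        }
    },
)
```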
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
@@ -32,6 +32,9 @@
render_for_completion,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.logits_processors import (
validate_logits_processors_parameters,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
@@ -110,6 +113,9 @@ def __init__(
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs

# set up logits processors
self.logits_processors = self.model_config.logits_processors

# set up reasoning parser
self.reasoning_parser = self._get_reasoning_parser(
reasoning_parser_name=reasoning_parser
@@ -291,6 +297,10 @@ async def create_chat_completion(
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)

self._log_inputs(
request_id,
29 changes: 20 additions & 9 deletions vllm/model_executor/models/deepseek_ocr.py
@@ -139,17 +139,18 @@ def __init__(
def is_argmax_invariant(self) -> bool:
return True

def new_req_logits_processor(
self,
params: SamplingParams,
) -> RequestLogitsProcessor | None:
@classmethod
def validate_params(cls, params: SamplingParams):
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
window_size = params.extra_args and params.extra_args.get("window_size", 100)
whitelist_token_ids = params.extra_args and params.extra_args.get(
"whitelist_token_ids", None
)
# if ngram_size is not provided, skip validation because the processor
# will not be used.
if ngram_size is None:
return None

if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
@@ -163,13 +164,23 @@
whitelist_token_ids, Iterable
):
raise ValueError(
"`whitelist_token_ids` has to be a set of integers, "
"`whitelist_token_ids` has to be a sequence of integers, "
f"got {whitelist_token_ids}."
)
else:
whitelist_token_ids = (
set(whitelist_token_ids) if whitelist_token_ids else None
)

def new_req_logits_processor(
self,
params: SamplingParams,
) -> RequestLogitsProcessor | None:
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
window_size = params.extra_args and params.extra_args.get("window_size", 100)
whitelist_token_ids = params.extra_args and params.extra_args.get(
"whitelist_token_ids", None
)
if ngram_size is None:
return None

whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None
return NoRepeatNGramLogitsProcessor(
ngram_size=ngram_size,
window_size=window_size,
6 changes: 6 additions & 0 deletions vllm/transformers_utils/configs/deepseek_vl2.py
@@ -218,3 +218,9 @@ def __init__(
self.global_view_pos = global_view_pos
self.candidate_resolutions = candidate_resolutions
self.vocab_size = self.text_config.vocab_size

# update model_type for OCR model
if "DeepseekOCRForCausalLM" in (
self.architectures or kwargs.get("architectures", [])
):
self.model_type = "deepseek_ocr"
Comment on lines +222 to +226
Contributor (severity: high):

The condition `if "DeepseekOCRForCausalLM" in (self.architectures or kwargs.get("architectures", []))` could potentially be simplified by directly checking `if "DeepseekOCRForCausalLM" in self.architectures + kwargs.get("architectures", [])`. This avoids the need for the `or` operator and might improve readability.

However, it's crucial to ensure that this change doesn't alter the behavior of the code, especially if `self.architectures` or `kwargs.get("architectures", [])` could be `None` or not a list. Adding a check to ensure that these are lists before concatenation could mitigate this risk.

Suggested change
# update model_type for OCR model
if "DeepseekOCRForCausalLM" in (
self.architectures or kwargs.get("architectures", [])
):
self.model_type = "deepseek_ocr"
# update model_type for OCR model
architectures = self.architectures if self.architectures else []
architectures += kwargs.get("architectures", []) if kwargs.get("architectures", []) else []
if "DeepseekOCRForCausalLM" in architectures:
self.model_type = "deepseek_ocr"

14 changes: 14 additions & 0 deletions vllm/v1/sample/logits_processor/builtin.py
@@ -46,6 +46,12 @@ def is_argmax_invariant(self) -> bool:
"""Min-p never impacts greedy sampling"""
return True

@classmethod
def validate_params(cls, sampling_params: SamplingParams):
min_p = sampling_params.min_p
if min_p is not None and (min_p < 0.0 or min_p > 1.0):
raise ValueError("min_p should be in the range [0.0, 1.0]")

def get_min_p_by_index(self, index: int) -> float:
return float(self.min_p_cpu[index])

@@ -131,6 +137,10 @@ def is_argmax_invariant(self) -> bool:
outcome of argmax in greedy sampling."""
return False

@classmethod
def validate_params(cls, sampling_params: SamplingParams):
pass

def update_state(self, batch_update: BatchUpdate | None):
needs_update = process_dict_updates(
self.biases, batch_update, lambda params, _, __: params.logit_bias or None
@@ -183,6 +193,10 @@ def is_argmax_invariant(self) -> bool:
of the argmax operation in greedy sampling."""
return False

@classmethod
def validate_params(cls, sampling_params: SamplingParams):
pass

@staticmethod
def add_request(
params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
9 changes: 9 additions & 0 deletions vllm/v1/sample/logits_processor/interface.py
@@ -64,6 +64,15 @@
) -> None:
raise NotImplementedError

@classmethod
@abstractmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.

Raise ValueError for invalid ones.
"""
raise NotImplementedError

@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
"""Apply LogitsProcessor to batch logits tensor.