
Conversation

Member

@Isotr0py Isotr0py commented Nov 11, 2025

Purpose

Discussion: https://vllm-dev.slack.com/archives/C08U97ZRC0J/p1762826798388249

  • _load_custom_logitsprocs can cause a significant performance regression when validating sampling params, because it tries to import the custom logits processors on every call.
  • This PR uses lru_cache to cache the loaded custom logits processors and avoid the duplicated imports; a minimal sketch of the idea follows this list.
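
A minimal sketch of the approach, with hypothetical simplified names (the real loader is _load_custom_logitsprocs in vLLM's logits-processor module); the import-by-FQCN body here is only an illustration:

from collections.abc import Sequence
from functools import lru_cache
from importlib import import_module


@lru_cache(maxsize=None)
def _cached_load(fqcns: tuple[str, ...] | None) -> tuple[type, ...]:
    # The expensive step: import each processor by its fully-qualified class
    # name. Caching means repeated sampling-params validation does not
    # re-import on every request.
    if not fqcns:
        return ()
    loaded = []
    for fqcn in fqcns:
        module_name, _, class_name = fqcn.rpartition(".")
        loaded.append(getattr(import_module(module_name), class_name))
    return tuple(loaded)


def validate_logits_processors(fqcns: Sequence[str] | None) -> tuple[type, ...]:
    # lru_cache keys must be hashable, so a list coming from the serving
    # layer is normalized to a tuple before hitting the cache.
    return _cached_load(tuple(fqcns) if fqcns is not None else None)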

Test Plan

vllm serve Qwen/Qwen3-0.6B --enforce-eager
vllm bench serve \
  --backend vllm \
  --model Qwen/Qwen3-0.6B \
  --endpoint /v1/completions \
  --dataset-name sharegpt \
  --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 1000

Test Result

Main

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Benchmark duration (s):                  40.25     
Total input tokens:                      217393    
Total generated tokens:                  201847    
Request throughput (req/s):              24.85     
Output token throughput (tok/s):         5015.16   
Peak output token throughput (tok/s):    22751.00  
Peak concurrent requests:                1000.00   
Total Token throughput (tok/s):          10416.59  
---------------Time to First Token----------------
Mean TTFT (ms):                          13626.51  
Median TTFT (ms):                        14087.92  
P99 TTFT (ms):                           28310.02  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          40.77     
Median TPOT (ms):                        37.58     
P99 TPOT (ms):                           156.15    
---------------Inter-token Latency----------------
Mean ITL (ms):                           35.98     
Median ITL (ms):                         20.86     
P99 ITL (ms):                            363.41    
==================================================

PR

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Benchmark duration (s):                  39.53     
Total input tokens:                      217393    
Total generated tokens:                  201847    
Request throughput (req/s):              25.29     
Output token throughput (tok/s):         5105.56   
Peak output token throughput (tok/s):    7680.00   
Peak concurrent requests:                1000.00   
Total Token throughput (tok/s):          10604.35  
---------------Time to First Token----------------
Mean TTFT (ms):                          11136.58  
Median TTFT (ms):                        9902.20   
P99 TTFT (ms):                           27619.44  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.25     
Median TPOT (ms):                        41.89     
P99 TPOT (ms):                           65.23     
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.48     
Median ITL (ms):                         35.05     
P99 ITL (ms):                            77.71     
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Nov 11, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by caching the results of loading custom logits processors using lru_cache. This is a valuable improvement to reduce overhead. However, I've identified a critical issue where using lru_cache with a potentially unhashable list argument will lead to a TypeError at runtime. I have provided a review comment with a suggested fix to ensure the argument is always hashable.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 219 to 226

  cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)


  def validate_logits_processors_parameters(
      logits_processors: Sequence[str | type[LogitsProcessor]] | None,
      sampling_params: SamplingParams,
  ):
-     for logits_procs in _load_custom_logitsprocs(logits_processors):
+     for logits_procs in cached_load_custom_logitsprocs(logits_processors):


P1: Cache assumes hashable logits processor sequences

The new lru_cache wrapper is invoked with logits_processors as the cache key, but the callers in the OpenAI serving paths pass the model’s logits_processors list directly. Lists are unhashable, so cached_load_custom_logitsprocs(self.logits_processors) will now raise TypeError: unhashable type: 'list' as soon as a request is validated when custom processors are configured (e.g. via --logits-processors). This previously worked because _load_custom_logitsprocs accepted any sequence. Consider normalizing the argument to a tuple (or similar) before caching to avoid runtime failures for users supplying custom logits processors.
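
For illustration, a hedged sketch of the suggested normalization (the wrapper name follows the diff above; the loader body is a stand-in, not the merged code):

from functools import lru_cache


@lru_cache
def cached_load_custom_logitsprocs(logits_processors: tuple | None) -> tuple:
    # Stand-in for the real import logic; only the hashable tuple key matters here.
    return logits_processors or ()


def validate_logits_processors_parameters(logits_processors, sampling_params=None):
    # Normalize the (possibly list-typed) argument to a hashable tuple so the
    # lru_cache lookup cannot raise TypeError: unhashable type: 'list'.
    key = tuple(logits_processors) if logits_processors is not None else None
    for logits_proc in cached_load_custom_logitsprocs(key):
        pass  # per-processor validation happens here in the real code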


Signed-off-by: Isotr0py <[email protected]>
@Isotr0py
Member Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces caching for loading custom logits processors to improve performance, which is a valuable optimization. The implementation correctly uses lru_cache and handles the non-hashable Sequence type by converting it to a tuple. My review includes one high-severity suggestion to further improve the caching efficiency by applying it more granularly to avoid redundant work when loading plugins.

cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
Contributor


Severity: high

While using lru_cache on _load_custom_logitsprocs is a good performance improvement, it can be made more efficient. _load_custom_logitsprocs internally calls _load_logitsprocs_plugins(), which does not depend on any arguments. With the current implementation, _load_logitsprocs_plugins() will be re-executed for every cache miss of cached_load_custom_logitsprocs (i.e., for each new logits_processors value).

To avoid this repeated work, _load_logitsprocs_plugins() should be cached independently. The ideal solution would be to apply @lru_cache directly to _load_logitsprocs_plugins and _load_logitsprocs_by_fqcns. This would require modifying those functions, which are outside the current diff.

For example:

@lru_cache(maxsize=None)
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    # ... function body

@lru_cache(maxsize=None)
def _load_logitsprocs_by_fqcns(
    logits_processors: tuple[str | type[LogitsProcessor], ...] | None,
) -> list[type[LogitsProcessor]]:
    # ... function body

Then _load_custom_logitsprocs can call these cached functions, and validate_logits_processors_parameters can call _load_custom_logitsprocs directly without an extra caching layer. This would be the most efficient implementation.
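
A sketch of that layout, assuming the function names from the review (the bodies below are placeholders for the real plugin discovery and FQCN import logic):

from functools import lru_cache


@lru_cache(maxsize=None)
def _load_logitsprocs_plugins() -> tuple:
    # Placeholder for the argument-free entry-point/plugin scan; cached so it
    # runs only once per process.
    return ()


@lru_cache(maxsize=None)
def _load_logitsprocs_by_fqcns(logits_processors: tuple | None) -> tuple:
    # Placeholder for importing processors by fully-qualified class name,
    # cached per distinct tuple of names.
    return logits_processors or ()


def _load_custom_logitsprocs(logits_processors: tuple | None) -> tuple:
    # Composes the two cached helpers, so callers can use this function
    # directly without an extra caching layer on top.
    return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)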

@Isotr0py Isotr0py changed the title [Performance] Cache custom logitsprocs loading results to avoid overheads [Performance] Cache loaded custom logitsprocs to avoid overheads Nov 11, 2025
@mgoin mgoin added this to the v0.11.1 milestone Nov 11, 2025
@mgoin mgoin added performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed labels Nov 11, 2025
@njhill njhill merged commit 3f770f4 into vllm-project:main Nov 12, 2025
48 checks passed
@Isotr0py Isotr0py deleted the cache-valid-logitsproc branch November 12, 2025 02:42
fangyuchu pushed a commit to fangyuchu/vllm that referenced this pull request Nov 12, 2025
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
khluu pushed a commit that referenced this pull request Nov 16, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025