From fd20d54adfa1808b407e84aba9d06da04fd27a06 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Wed, 12 Mar 2025 22:17:00 +0000
Subject: [PATCH] Disable outlines cache by default

https://github.com/vllm-project/vllm/security/advisories/GHSA-mgrm-fgjv-mhv8

Outlines provides a cache for its compiled grammars on the local
filesystem. This cache has been on by default in vLLM. Outlines is also
available by default through the OpenAI compatible API server.

A malicious user can send a stream of very short decoding requests with
unique schemas, resulting in an addition to the cache for each request.
This can result in a Denial of Service if the filesystem runs out of
space.

Note that even if vLLM was configured to use a different backend by
default, it is still possible to choose outlines on a per-request basis
using the `guided_decoding_backend` key of the `extra_body` field of
the request.

This issue applies to the V0 engine only. The V1 engine is not
affected.

Signed-off-by: Russell Bryant
---
 vllm/envs.py                                          |  7 +++++++
 .../guided_decoding/outlines_logits_processors.py     | 10 +++++++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index a36d20a4f8b5..0b1bcd9eb358 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -95,6 +95,7 @@
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
+    VLLM_V0_USE_OUTLINES_CACHE: bool = False
 
 
 def get_default_cache_root():
@@ -623,6 +624,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
     "VLLM_MARLIN_USE_ATOMIC_ADD":
     lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
+
+    # Whether to turn on the outlines cache for V0
+    # This cache is unbounded and on disk, so it's not safe to use in
+    # an environment with potentially malicious users.
+    "VLLM_V0_USE_OUTLINES_CACHE":
+    lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
 }
 
 # end-env-vars-definition

diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index de24eaa1fb6a..8b2a0f4cfe64 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -24,7 +24,7 @@
 import numpy as np
 import torch
 from outlines import grammars
-from outlines.caching import cache
+from outlines.caching import cache, disable_cache
 from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
                                 RegexGuide, Write)
 from outlines.fsm.parsing import PartialLark
@@ -32,12 +32,20 @@
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
+if envs.VLLM_V0_USE_OUTLINES_CACHE:
+    logger.warning("Enabling outlines cache. This is an unbounded on-disk "
+                   "cache. It may consume a lot of disk space and should "
+                   "not be used with untrusted clients.")
+else:
+    disable_cache()
+
 
 class BaseLogitsProcessor:
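
For illustration (not part of the patch): a minimal client-side sketch of
the per-request backend selection described in the commit message. It
assumes a vLLM OpenAI-compatible server at http://localhost:8000/v1 and
uses a placeholder model name; `guided_json` and `guided_decoding_backend`
are vLLM extensions passed through the `extra_body` field of the standard
openai client. Each request carries a structurally unique JSON schema,
which is the request pattern that previously added a new on-disk cache
entry per request.

    # Hypothetical client sketch; server URL and model name are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    for i in range(3):
        # A unique schema per request; with the outlines backend, each
        # unique schema compiles a new grammar.
        schema = {
            "type": "object",
            "properties": {f"field_{i}": {"type": "string"}},
            "required": [f"field_{i}"],
        }
        client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "Respond in JSON."}],
            max_tokens=8,
            extra_body={
                "guided_json": schema,
                # Selects outlines even if the server's default guided
                # decoding backend is different (V0 engine only).
                "guided_decoding_backend": "outlines",
            },
        )

With this change applied, such requests no longer write entries to the
on-disk cache unless the operator opts back in by setting
VLLM_V0_USE_OUTLINES_CACHE=1 in the server environment.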