diff --git a/vllm/envs.py b/vllm/envs.py
index a36d20a4f8b5..0b1bcd9eb358 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -95,6 +95,7 @@
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
+    VLLM_V0_USE_OUTLINES_CACHE: bool = False


 def get_default_cache_root():
@@ -623,6 +624,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
     "VLLM_MARLIN_USE_ATOMIC_ADD":
     lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
+
+    # Whether to turn on the outlines cache for V0
+    # This cache is unbounded and on disk, so it's not safe to use in
+    # an environment with potentially malicious users.
+    "VLLM_V0_USE_OUTLINES_CACHE":
+    lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
 }

 # end-env-vars-definition
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index de24eaa1fb6a..8b2a0f4cfe64 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -24,7 +24,7 @@
 import numpy as np
 import torch
 from outlines import grammars
-from outlines.caching import cache
+from outlines.caching import cache, disable_cache
 from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
                                 RegexGuide, Write)
 from outlines.fsm.parsing import PartialLark
@@ -32,12 +32,20 @@
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase

+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.platforms import current_platform

 logger = init_logger(__name__)

+if envs.VLLM_V0_USE_OUTLINES_CACHE:
+    logger.warning("Enabling outlines cache. This is an unbounded on-disk "
+                   "cache. It may consume a lot of disk space and should "
+                   "not be used with untrusted clients.")
+else:
+    disable_cache()
+

 class BaseLogitsProcessor: