Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11
lark == 1.2.2
lark == 1.2.2
xgrammar == 0.1.11; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
Expand All @@ -36,3 +36,4 @@ einops # Required for Qwen2-VL.
compressed-tensors == 0.9.1 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to pin a version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The API being used is pretty standard, so it should not be very sensitive to specific versions.
I would say the latest version is preferred here.

72 changes: 72 additions & 0 deletions tests/entrypoints/test_ssl_cert_refresher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import tempfile
from pathlib import Path
from ssl import SSLContext

import pytest

from vllm.entrypoints.ssl import SSLCertRefresher


class MockSSLContext(SSLContext):
    """SSLContext stand-in that counts reload calls instead of loading files.

    ``load_cert_chain_count`` and ``load_ca_count`` record how many times
    the corresponding loader method has been invoked.
    """

    def __init__(self):
        # Both counters start at zero; the overridden loaders only bump them.
        self.load_ca_count = 0
        self.load_cert_chain_count = 0

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """Record one certificate-chain load; the arguments are ignored."""
        self.load_cert_chain_count = self.load_cert_chain_count + 1

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """Record one CA load; the arguments are ignored."""
        self.load_ca_count = self.load_ca_count + 1


def create_file() -> str:
    """Create an empty, persistent temporary file and return its path.

    The file is placed in the platform's default temporary directory (as
    chosen by ``tempfile``) rather than a hard-coded ``/tmp``, so the helper
    also works where ``/tmp`` does not exist or ``TMPDIR`` is overridden.
    It is created with ``delete=False``; the caller is responsible for
    removing it.
    """
    with tempfile.NamedTemporaryFile(delete=False) as f:
        return f.name


def touch_file(path: str) -> None:
    """Touch ``path`` so that file watchers observe a modification."""
    target = Path(path)
    target.touch()


@pytest.mark.asyncio
async def test_ssl_refresher():
    """Check that SSLCertRefresher reloads on file changes until stopped.

    The test touches the key, cert and CA files in turn and checks the
    reload counters recorded by ``MockSSLContext``. After ``stop()`` further
    touches must not trigger any reload.
    """
    ssl_context = MockSSLContext()
    key_path = create_file()
    cert_path = create_file()
    ca_path = create_file()
    ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
    try:
        # Nothing has been modified yet, so nothing should have been reloaded.
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 0
        assert ssl_context.load_ca_count == 0

        # A key change reloads the certificate chain only.
        touch_file(key_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 1
        assert ssl_context.load_ca_count == 0

        # A cert change reloads the chain again; a CA change reloads the CA.
        touch_file(cert_path)
        touch_file(ca_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 2
        assert ssl_context.load_ca_count == 1

        ssl_refresher.stop()

        # Once stopped, changes must no longer be picked up.
        touch_file(cert_path)
        touch_file(ca_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 2
        assert ssl_context.load_ca_count == 1
    finally:
        # Always cancel the watcher tasks and remove the temporary files,
        # even when an assertion above fails; stop() is safe to call twice.
        ssl_refresher.stop()
        for path in (key_path, cert_path, ca_path):
            Path(path).unlink(missing_ok=True)
6 changes: 6 additions & 0 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def run_server(args: Namespace,
shutdown_task = await serve_http(
app,
sock=None,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.log_level,
Expand All @@ -152,6 +153,11 @@ async def run_server(args: Namespace,
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
Expand Down
14 changes: 13 additions & 1 deletion vllm/entrypoints/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils import find_process_using_port

logger = init_logger(__name__)


async def serve_http(app: FastAPI, sock: Optional[socket.socket],
async def serve_http(app: FastAPI,
sock: Optional[socket.socket],
enable_ssl_refresh: bool = False,
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
Expand All @@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

config = uvicorn.Config(app, **uvicorn_kwargs)
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)

Expand All @@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))

ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
ssl_context=config.ssl,
key_path=config.ssl_keyfile,
cert_path=config.ssl_certfile,
ca_path=config.ssl_ca_certs)

def signal_handler() -> None:
# prevents the uvicorn signal handler from exiting early
server_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()

async def dummy_shutdown() -> None:
pass
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,7 @@ def _listen_addr(a: str) -> str:
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
Expand Down
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=nullable_str,
default=None,
help="The CA certificates file.")
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
Expand Down
74 changes: 74 additions & 0 deletions vllm/entrypoints/ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
from ssl import SSLContext
from typing import Callable, Optional

from watchfiles import Change, awatch

from vllm.logger import init_logger

logger = init_logger(__name__)


class SSLCertRefresher:
    """Reload SSL material into an ``SSLContext`` when its files change.

    On construction, up to two background tasks are started: one watching
    the key/certificate pair and one watching the CA file. Call ``stop()``
    to cancel them.
    """

    def __init__(self,
                 ssl_context: SSLContext,
                 key_path: Optional[str] = None,
                 cert_path: Optional[str] = None,
                 ca_path: Optional[str] = None) -> None:
        """Remember the paths and start a watcher task per configured set.

        The certificate-chain watcher is started only when both
        ``key_path`` and ``cert_path`` are given; the CA watcher only when
        ``ca_path`` is given. Tasks are created with
        ``asyncio.create_task``.
        """
        self.ssl = ssl_context
        self.key_path = key_path
        self.cert_path = cert_path
        self.ca_path = ca_path

        def reload_cert_chain(change: Change, file_path: str) -> None:
            # Any change to the key or the cert reloads the whole chain.
            logger.info("Reloading SSL certificate chain")
            assert self.key_path and self.cert_path
            self.ssl.load_cert_chain(self.cert_path, self.key_path)

        def reload_ca(change: Change, file_path: str) -> None:
            # Any change to the CA file reloads the verify locations.
            logger.info("Reloading SSL CA certificates")
            assert self.ca_path
            self.ssl.load_verify_locations(self.ca_path)

        # Certificate chain watcher
        self.watch_ssl_cert_task = None
        if self.key_path and self.cert_path:
            chain_paths = [self.key_path, self.cert_path]
            self.watch_ssl_cert_task = asyncio.create_task(
                self._watch_files(chain_paths, reload_cert_chain))

        # CA file watcher
        self.watch_ssl_ca_task = None
        if self.ca_path:
            self.watch_ssl_ca_task = asyncio.create_task(
                self._watch_files([self.ca_path], reload_ca))

    async def _watch_files(self, paths, fun: Callable[[Change, str],
                                                      None]) -> None:
        """Call ``fun`` for every change that ``awatch`` reports on ``paths``.

        An exception raised while handling a batch of changes is logged and
        the watch loop continues with the next batch.
        """
        logger.info("SSLCertRefresher monitors files: %s", paths)
        async for batch in awatch(*paths):
            try:
                for kind, changed_path in batch:
                    logger.info("File change detected: %s - %s", kind.name,
                                changed_path)
                    fun(kind, changed_path)
            except Exception as e:
                logger.error(
                    "SSLCertRefresher failed taking action on file change. "
                    "Error: %s", e)

    def stop(self) -> None:
        """Cancel any running watcher tasks and drop the references."""
        for attr in ("watch_ssl_cert_task", "watch_ssl_ca_task"):
            task = getattr(self, attr)
            if task:
                task.cancel()
                setattr(self, attr, None)