diff --git a/requirements-common.txt b/requirements-common.txt
index b7c94cbdba8b..07c925cdb92c 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines == 0.1.11
-lark == 1.2.2 
+lark == 1.2.2
 xgrammar == 0.1.11; platform_machine == "x86_64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
@@ -36,3 +36,4 @@ einops # Required for Qwen2-VL.
 compressed-tensors == 0.9.1 # required for compressed-tensors
 depyf==0.18.0 # required for profiling and debugging with compilation config
 cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
+watchfiles # required for http server to monitor the updates of TLS files
diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py
new file mode 100644
index 000000000000..23ce7a679f3e
--- /dev/null
+++ b/tests/entrypoints/test_ssl_cert_refresher.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+import tempfile
+from pathlib import Path
+from ssl import SSLContext
+from typing import Optional
+
+import pytest
+
+from vllm.entrypoints.ssl import SSLCertRefresher
+
+
+class MockSSLContext(SSLContext):
+    """SSLContext stand-in that only counts reload calls."""
+
+    def __init__(self):
+        self.load_cert_chain_count = 0
+        self.load_ca_count = 0
+
+    def load_cert_chain(
+        self,
+        certfile,
+        keyfile=None,
+        password=None,
+    ):
+        self.load_cert_chain_count += 1
+
+    def load_verify_locations(
+        self,
+        cafile=None,
+        capath=None,
+        cadata=None,
+    ):
+        self.load_ca_count += 1
+
+
+def create_file(directory: Optional[str] = None) -> str:
+    """Create an empty file and return its path.
+
+    The file is created in ``directory`` or, when omitted, in the default
+    temporary directory chosen by ``tempfile``.
+    """
+    with tempfile.NamedTemporaryFile(dir=directory, delete=False) as f:
+        return f.name
+
+
+def touch_file(path: str) -> None:
+    Path(path).touch()
+
+
+@pytest.mark.asyncio
+async def test_ssl_refresher(tmp_path):
+    ssl_context = MockSSLContext()
+    # Create the watched files under pytest's per-test directory instead of
+    # a hard-coded '/tmp', which does not exist on every platform.
+    key_path = create_file(str(tmp_path))
+    cert_path = create_file(str(tmp_path))
+    ca_path = create_file(str(tmp_path))
+    ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
+    try:
+        await asyncio.sleep(1)
+        assert ssl_context.load_cert_chain_count == 0
+        assert ssl_context.load_ca_count == 0
+
+        touch_file(key_path)
+        await asyncio.sleep(1)
+        assert ssl_context.load_cert_chain_count == 1
+        assert ssl_context.load_ca_count == 0
+
+        touch_file(cert_path)
+        touch_file(ca_path)
+        await asyncio.sleep(1)
+        assert ssl_context.load_cert_chain_count == 2
+        assert ssl_context.load_ca_count == 1
+
+        ssl_refresher.stop()
+
+        touch_file(cert_path)
+        touch_file(ca_path)
+        await asyncio.sleep(1)
+        assert ssl_context.load_cert_chain_count == 2
+        assert ssl_context.load_ca_count == 1
+    finally:
+        # Cancel the watcher tasks even when an assertion above fails.
+        ssl_refresher.stop()
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 00793d4b9677..9ad3c3786166 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -128,6 +128,7 @@ async def run_server(args: Namespace,
     shutdown_task = await serve_http(
         app,
         sock=None,
+        enable_ssl_refresh=args.enable_ssl_refresh,
         host=args.host,
         port=args.port,
         log_level=args.log_level,
@@ -152,6 +153,11 @@ async def run_server(args: Namespace,
                         type=str,
                         default=None,
                         help="The CA certificates file")
+    parser.add_argument(
+        "--enable-ssl-refresh",
+        action="store_true",
+        default=False,
+        help="Refresh SSL Context when SSL certificate files change")
     parser.add_argument(
         "--ssl-cert-reqs",
         type=int,
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 79946a498dad..b09ee526f14a 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -12,13 +12,16 @@
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
 
 logger = init_logger(__name__)
 
 
-async def serve_http(app: FastAPI, sock: Optional[socket.socket],
+async def serve_http(app: FastAPI,
+                     sock: Optional[socket.socket],
+                     enable_ssl_refresh: bool = False,
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
         logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
 
     config = uvicorn.Config(app, **uvicorn_kwargs)
+    config.load()
     server = uvicorn.Server(config)
     _add_shutdown_handlers(app, server)
 
@@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
     server_task = loop.create_task(
         server.serve(sockets=[sock] if sock else None))
 
+    ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
+        ssl_context=config.ssl,
+        key_path=config.ssl_keyfile,
+        cert_path=config.ssl_certfile,
+        ca_path=config.ssl_ca_certs)
+
     def signal_handler() -> None:
         # prevents the uvicorn signal handler to exit early
         server_task.cancel()
+        if ssl_cert_refresher:
+            ssl_cert_refresher.stop()
 
     async def dummy_shutdown() -> None:
         pass
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index ad391d6737bf..5376633515c9 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -940,6 +940,7 @@ def _listen_addr(a: str) -> str:
         shutdown_task = await serve_http(
             app,
             sock=sock,
+            enable_ssl_refresh=args.enable_ssl_refresh,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 3054958f3c8a..ba953c219708 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         type=nullable_str,
                         default=None,
                         help="The CA certificates file.")
+    parser.add_argument(
+        "--enable-ssl-refresh",
+        action="store_true",
+        default=False,
+        help="Refresh SSL Context when SSL certificate files change")
     parser.add_argument(
         "--ssl-cert-reqs",
         type=int,
diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py
new file mode 100644
index 000000000000..dba916b8bf13
--- /dev/null
+++ b/vllm/entrypoints/ssl.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+from ssl import SSLContext
+from typing import Callable, Optional
+
+from watchfiles import Change, awatch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class SSLCertRefresher:
+    """A class that monitors SSL certificate files and
+    reloads them when they change.
+
+    The watcher tasks are created with ``asyncio.create_task`` in
+    ``__init__``, so an instance must be constructed while an event loop
+    is running. Call ``stop()`` to cancel the watcher tasks.
+    """
+
+    def __init__(self,
+                 ssl_context: SSLContext,
+                 key_path: Optional[str] = None,
+                 cert_path: Optional[str] = None,
+                 ca_path: Optional[str] = None) -> None:
+        self.ssl = ssl_context
+        self.key_path = key_path
+        self.cert_path = cert_path
+        self.ca_path = ca_path
+
+        # Setup certification chain watcher
+        def update_ssl_cert_chain(change: Change, file_path: str) -> None:
+            logger.info("Reloading SSL certificate chain")
+            assert self.key_path and self.cert_path
+            self.ssl.load_cert_chain(self.cert_path, self.key_path)
+
+        self.watch_ssl_cert_task = None
+        if self.key_path and self.cert_path:
+            self.watch_ssl_cert_task = asyncio.create_task(
+                self._watch_files([self.key_path, self.cert_path],
+                                  update_ssl_cert_chain))
+
+        # Setup CA files watcher
+        def update_ssl_ca(change: Change, file_path: str) -> None:
+            logger.info("Reloading SSL CA certificates")
+            assert self.ca_path
+            self.ssl.load_verify_locations(self.ca_path)
+
+        self.watch_ssl_ca_task = None
+        if self.ca_path:
+            self.watch_ssl_ca_task = asyncio.create_task(
+                self._watch_files([self.ca_path], update_ssl_ca))
+
+    async def _watch_files(self, paths: list[str],
+                           fun: Callable[[Change, str], None]) -> None:
+        """Watch multiple file paths asynchronously.
+
+        Calls ``fun(change, file_path)`` for every reported change. An
+        exception raised by ``fun`` is logged and the watcher keeps running.
+        """
+        logger.info("SSLCertRefresher monitors files: %s", paths)
+        async for changes in awatch(*paths):
+            for change, file_path in changes:
+                logger.info("File change detected: %s - %s", change.name,
+                            file_path)
+                # Guard each callback separately so that one failed reload
+                # does not skip the remaining changes of the same batch.
+                try:
+                    fun(change, file_path)
+                except Exception as e:
+                    # logger.exception records the traceback as well.
+                    logger.exception(
+                        "SSLCertRefresher failed taking action on file "
+                        "change. Error: %s", e)
+
+    def stop(self) -> None:
+        """Stop watching files."""
+        if self.watch_ssl_cert_task:
+            self.watch_ssl_cert_task.cancel()
+            self.watch_ssl_cert_task = None
+        if self.watch_ssl_ca_task:
+            self.watch_ssl_ca_task.cancel()
+            self.watch_ssl_ca_task = None