Skip to content

Commit 8db1b9d

Browse files
authored
Support SSL Key Rotation in HTTP Server (#13495)
1 parent 2382ad2 commit 8db1b9d

File tree

7 files changed

+173
-2
lines changed

7 files changed

+173
-2
lines changed

requirements-common.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
2020
tiktoken >= 0.6.0 # Required for DBRX tokenizer
2121
lm-format-enforcer >= 0.10.9, < 0.11
2222
outlines == 0.1.11
23-
lark == 1.2.2
23+
lark == 1.2.2
2424
xgrammar == 0.1.11; platform_machine == "x86_64"
2525
typing_extensions >= 4.10
2626
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
@@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
3737
compressed-tensors == 0.9.2 # required for compressed-tensors
3838
depyf==0.18.0 # required for profiling and debugging with compilation config
3939
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
40+
watchfiles # required for http server to monitor the updates of TLS files
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import asyncio
3+
import tempfile
4+
from pathlib import Path
5+
from ssl import SSLContext
6+
7+
import pytest
8+
9+
from vllm.entrypoints.ssl import SSLCertRefresher
10+
11+
12+
class MockSSLContext(SSLContext):
    """SSLContext stand-in that counts reload calls instead of loading files.

    ``load_cert_chain`` and ``load_verify_locations`` are overridden so that
    no certificate material is read; each call only bumps a counter that the
    test inspects.
    """

    def __init__(self):
        # Counters start at zero; they are the only state the test reads.
        self.load_ca_count = 0
        self.load_cert_chain_count = 0

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """Record one certificate-chain reload; arguments are ignored."""
        self.load_cert_chain_count = self.load_cert_chain_count + 1

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """Record one CA reload; arguments are ignored."""
        self.load_ca_count = self.load_ca_count + 1
33+
34+
35+
def create_file() -> str:
    """Create an empty temporary file and return its path.

    The file is placed in the platform's default temporary directory
    (honouring ``TMPDIR`` and friends) instead of a hard-coded ``/tmp``, so
    the helper also works on systems without that directory.

    The file is created with ``delete=False`` and therefore outlives this
    call; the caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(delete=False) as f:
        return f.name
38+
39+
40+
def touch_file(path: str) -> None:
    """Update the modification time of ``path``, creating it if missing."""
    target = Path(path)
    target.touch()
42+
43+
44+
@pytest.mark.asyncio
async def test_ssl_refresher():
    """Check that SSLCertRefresher reloads on file changes until stopped.

    The scenario, with a one second pause after each step so the file
    watcher has time to report:

    1. No change            -> nothing is reloaded.
    2. Key file touched     -> the certificate chain is reloaded once.
    3. Cert and CA touched  -> chain and CA are each reloaded once more.
    4. After ``stop()``     -> further touches trigger no reloads.

    The refresher is always stopped and the temporary files are always
    removed, even when an assertion fails part-way through.
    """
    ssl_context = MockSSLContext()
    key_path = create_file()
    cert_path = create_file()
    ca_path = create_file()
    ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
    try:
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 0
        assert ssl_context.load_ca_count == 0

        touch_file(key_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 1
        assert ssl_context.load_ca_count == 0

        touch_file(cert_path)
        touch_file(ca_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 2
        assert ssl_context.load_ca_count == 1

        ssl_refresher.stop()

        # Changes made after stop() must not be picked up.
        touch_file(cert_path)
        touch_file(ca_path)
        await asyncio.sleep(1)
        assert ssl_context.load_cert_chain_count == 2
        assert ssl_context.load_ca_count == 1
    finally:
        # stop() is safe to call twice: it clears the task handles it cancels.
        ssl_refresher.stop()
        for path in (key_path, cert_path, ca_path):
            Path(path).unlink(missing_ok=True)

vllm/entrypoints/api_server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ async def run_server(args: Namespace,
128128
shutdown_task = await serve_http(
129129
app,
130130
sock=None,
131+
enable_ssl_refresh=args.enable_ssl_refresh,
131132
host=args.host,
132133
port=args.port,
133134
log_level=args.log_level,
@@ -152,6 +153,11 @@ async def run_server(args: Namespace,
152153
type=str,
153154
default=None,
154155
help="The CA certificates file")
156+
parser.add_argument(
157+
"--enable-ssl-refresh",
158+
action="store_true",
159+
default=False,
160+
help="Refresh SSL Context when SSL certificate files change")
155161
parser.add_argument(
156162
"--ssl-cert-reqs",
157163
type=int,

vllm/entrypoints/launcher.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
from vllm import envs
1313
from vllm.engine.async_llm_engine import AsyncEngineDeadError
1414
from vllm.engine.multiprocessing import MQEngineDeadError
15+
from vllm.entrypoints.ssl import SSLCertRefresher
1516
from vllm.logger import init_logger
1617
from vllm.utils import find_process_using_port
1718

1819
logger = init_logger(__name__)
1920

2021

21-
async def serve_http(app: FastAPI, sock: Optional[socket.socket],
22+
async def serve_http(app: FastAPI,
23+
sock: Optional[socket.socket],
24+
enable_ssl_refresh: bool = False,
2225
**uvicorn_kwargs: Any):
2326
logger.info("Available routes are:")
2427
for route in app.routes:
@@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
3134
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
3235

3336
config = uvicorn.Config(app, **uvicorn_kwargs)
37+
config.load()
3438
server = uvicorn.Server(config)
3539
_add_shutdown_handlers(app, server)
3640

@@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
3943
server_task = loop.create_task(
4044
server.serve(sockets=[sock] if sock else None))
4145

46+
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
47+
ssl_context=config.ssl,
48+
key_path=config.ssl_keyfile,
49+
cert_path=config.ssl_certfile,
50+
ca_path=config.ssl_ca_certs)
51+
4252
def signal_handler() -> None:
4353
# prevents the uvicorn signal handler to exit early
4454
server_task.cancel()
55+
if ssl_cert_refresher:
56+
ssl_cert_refresher.stop()
4557

4658
async def dummy_shutdown() -> None:
4759
pass

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ def _listen_addr(a: str) -> str:
960960
shutdown_task = await serve_http(
961961
app,
962962
sock=sock,
963+
enable_ssl_refresh=args.enable_ssl_refresh,
963964
host=args.host,
964965
port=args.port,
965966
log_level=args.uvicorn_log_level,

vllm/entrypoints/openai/cli_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
164164
type=nullable_str,
165165
default=None,
166166
help="The CA certificates file.")
167+
parser.add_argument(
168+
"--enable-ssl-refresh",
169+
action="store_true",
170+
default=False,
171+
help="Refresh SSL Context when SSL certificate files change")
167172
parser.add_argument(
168173
"--ssl-cert-reqs",
169174
type=int,

vllm/entrypoints/ssl.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import asyncio
4+
from ssl import SSLContext
5+
from typing import Callable, Optional
6+
7+
from watchfiles import Change, awatch
8+
9+
from vllm.logger import init_logger
10+
11+
logger = init_logger(__name__)
12+
13+
14+
class SSLCertRefresher:
    """A class that monitors SSL certificate files and
    reloads them when they change.

    Watcher tasks are started from the constructor via
    ``asyncio.create_task``, so an instance must be created while an event
    loop is running. Call :meth:`stop` to cancel the watchers.
    """

    def __init__(self,
                 ssl_context: SSLContext,
                 key_path: Optional[str] = None,
                 cert_path: Optional[str] = None,
                 ca_path: Optional[str] = None) -> None:
        """Start watchers for the given files.

        Args:
            ssl_context: Context whose certificates are reloaded in place.
            key_path: Private key file. The certificate chain is watched
                only when both ``key_path`` and ``cert_path`` are set.
            cert_path: Certificate file (see ``key_path``).
            ca_path: CA certificates file; watched only when set.
        """
        self.ssl = ssl_context
        self.key_path = key_path
        self.cert_path = cert_path
        self.ca_path = ca_path

        # Setup certification chain watcher
        def update_ssl_cert_chain(change: Change, file_path: str) -> None:
            logger.info("Reloading SSL certificate chain")
            assert self.key_path and self.cert_path
            self.ssl.load_cert_chain(self.cert_path, self.key_path)

        self.watch_ssl_cert_task = None
        if self.key_path and self.cert_path:
            self.watch_ssl_cert_task = asyncio.create_task(
                self._watch_files([self.key_path, self.cert_path],
                                  update_ssl_cert_chain))

        # Setup CA files watcher
        def update_ssl_ca(change: Change, file_path: str) -> None:
            logger.info("Reloading SSL CA certificates")
            assert self.ca_path
            self.ssl.load_verify_locations(self.ca_path)

        self.watch_ssl_ca_task = None
        if self.ca_path:
            self.watch_ssl_ca_task = asyncio.create_task(
                self._watch_files([self.ca_path], update_ssl_ca))

    async def _watch_files(self, paths, fun: Callable[[Change, str],
                                                      None]) -> None:
        """Watch multiple file paths asynchronously.

        ``fun`` is invoked once for every reported change. Failures are
        logged and swallowed so the watcher keeps running.
        """
        logger.info("SSLCertRefresher monitors files: %s", paths)
        async for changes in awatch(*paths):
            for change, file_path in changes:
                # Guard each change separately: a failed reload for one
                # change must not skip the remaining changes in this batch.
                try:
                    logger.info("File change detected: %s - %s", change.name,
                                file_path)
                    fun(change, file_path)
                except Exception as e:
                    logger.error(
                        "SSLCertRefresher failed taking action on file change. "
                        "Error: %s", e)

    def stop(self) -> None:
        """Stop watching files.

        Cancels any running watcher task and clears its handle, so calling
        this method more than once is harmless.
        """
        if self.watch_ssl_cert_task:
            self.watch_ssl_cert_task.cancel()
            self.watch_ssl_cert_task = None
        if self.watch_ssl_ca_task:
            self.watch_ssl_ca_task.cancel()
            self.watch_ssl_ca_task = None

0 commit comments

Comments
 (0)