Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
LitAPIStatus,
LoopResponseType,
WorkerSetupStatus,
add_ssl_context_from_env,
call_after_stream,
configure_logging,
is_package_installed,
Expand Down Expand Up @@ -1294,8 +1295,16 @@ def run(
if host not in ["0.0.0.0", "127.0.0.1", "::"]:
raise ValueError(host_msg)

kwargs = add_ssl_context_from_env(kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥


configure_logging(log_level, use_rich=pretty_logs)
config = uvicorn.Config(app=self.app, host=host, port=port, log_level=log_level, **kwargs)
config = uvicorn.Config(
app=self.app,
host=host,
port=port,
log_level=log_level,
**kwargs,
)
sockets = [config.bind_socket()]

if num_api_servers is None:
Expand Down
60 changes: 59 additions & 1 deletion src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import dataclasses
import importlib.util
import logging
import os
import pdb
import pickle
import sys
import tempfile
import time
import uuid
import warnings
from abc import ABCMeta
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, AsyncIterator, TextIO, Union
from pathlib import Path
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, TextIO, Union

from fastapi import HTTPException

Expand Down Expand Up @@ -279,3 +282,58 @@ def __call__(cls, *args, **kwargs):
)

return instance


def add_ssl_context_from_env(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Loads SSL context from base64-encoded environment variables.

This function checks for the presence of `LIGHTNING_CERT_PEM` and
`LIGHTNING_KEY_FILE` environment variables. It expects these variables
to contain the SSL certificate and private key, respectively, as
base64-encoded PEM strings.

If both variables are found, it decodes them and writes the content to
secure, temporary files. The paths to these files are returned in a
dictionary suitable for direct use as keyword arguments in libraries
that require SSL file paths (like `uvicorn` or `requests`).

Note:
The temporary files are not automatically deleted (`delete=False`).
The calling application is responsible for cleaning up these files
after the SSL context is no longer needed to prevent leaving
sensitive data on disk.

Returns:
Dict[str, Any]: A dictionary containing `ssl_certfile` and `ssl_keyfile`
keys with `pathlib.Path` objects pointing to the temporary files.
If either of the required environment variables is missing, it
returns an empty dictionary.

"""

if "ssl_keyfile" in kwargs and "ssl_certfile" in kwargs:
return kwargs

cert_pem_b64 = os.getenv("LIGHTNING_CERT_PEM", "")
cert_key_b64 = os.getenv("LIGHTNING_KEY_FILE", "")

if cert_pem_b64 == "" or cert_key_b64 == "":
return kwargs

# Decode the base64 strings to get the actual PEM content
cert_pem = base64.b64decode(cert_pem_b64).decode("utf-8")
cert_key = base64.b64decode(cert_key_b64).decode("utf-8")

# Write to temporary files
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as cert_file, tempfile.NamedTemporaryFile(
mode="w+", delete=False
) as key_file:
cert_file.write(cert_pem)
cert_file.flush()
key_file.write(cert_key)
key_file.flush()

logger.info("Loading TLS Certificates \n")

# Return a dictionary with Path objects to the created files
return {"ssl_keyfile": Path(key_file.name), "ssl_certfile": Path(cert_file.name), **kwargs}
47 changes: 47 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import base64
import logging
import os
import pickle
import sys
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException

from litserve.utils import (
add_ssl_context_from_env,
call_after_stream,
configure_logging,
dump_exception,
Expand Down Expand Up @@ -91,3 +94,47 @@ def test_set_trace_if_debug_not_set(mock_forked_pdb):
def test_is_package_installed():
assert is_package_installed("pytest")
assert not is_package_installed("nonexistent_package")


def test_add_ssl_context_from_env_with_env_vars():
"""Tests that the SSL context is loaded correctly when environment variables are set."""
dummy_cert = "dummy certificate"
dummy_key = "dummy key"

b64_cert = base64.b64encode(dummy_cert.encode("utf-8")).decode("utf-8")
b64_key = base64.b64encode(dummy_key.encode("utf-8")).decode("utf-8")

with mock.patch.dict(os.environ, {"LIGHTNING_CERT_PEM": b64_cert, "LIGHTNING_KEY_FILE": b64_key}):
ssl_context = add_ssl_context_from_env({})

assert ssl_context

assert "ssl_certfile" in ssl_context
assert "ssl_keyfile" in ssl_context
assert isinstance(ssl_context["ssl_certfile"], Path)
assert isinstance(ssl_context["ssl_keyfile"], Path)

with open(ssl_context["ssl_certfile"]) as f:
assert f.read() == dummy_cert
with open(ssl_context["ssl_keyfile"]) as f:
assert f.read() == dummy_key

os.remove(ssl_context["ssl_certfile"])
os.remove(ssl_context["ssl_keyfile"])


def test_add_ssl_context_from_env_without_env_vars():
"""Tests that an empty dictionary is returned when environment variables are not set."""
with mock.patch.dict(os.environ, {}, clear=True):
ssl_context = add_ssl_context_from_env({})
assert ssl_context == {}


def test_add_ssl_context_from_env_with_one_env_var_missing():
"""Tests that an empty dictionary is returned when one of the environment variables is missing."""
dummy_cert = "dummy certificate"
b64_cert = base64.b64encode(dummy_cert.encode("utf-8")).decode("utf-8")

with mock.patch.dict(os.environ, {"LIGHTNING_CERT_PEM": b64_cert}, clear=True):
ssl_context = add_ssl_context_from_env({})
assert ssl_context == {}
Loading