Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
ServerStatus,
assert_pkg_version,
configure_logger,
get_zmq_socket,
Expand All @@ -73,6 +74,7 @@
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
report_health,
set_prometheus_multiproc_dir,
set_ulimit,
)
Expand Down Expand Up @@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs):
def sigchld_handler(signum, frame):
    """SIGCHLD handler: reap one exited child and report a crash.

    Uses WNOHANG so the handler never blocks; a nonzero exit status
    means a child process died unexpectedly, so the server is marked
    Crashed via the /health endpoint before logging the failure.
    """
    pid, exitcode = os.waitpid(0, os.WNOHANG)
    if exitcode != 0:
        # Bug fix: use the instance `server_args.port`, not the class
        # attribute `ServerArgs.port` — the class attribute is only the
        # default value and is wrong whenever a custom port is configured.
        report_health(ServerStatus.Crashed, server_args.host, server_args.port)
        logger.warning(
            f"Child process unexpectedly failed with {exitcode=}. {pid=}"
        )
Expand All @@ -674,6 +677,7 @@ def sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
kill_process_tree(os.getpid())

signal.signal(signal.SIGQUIT, sigquit_handler)
Expand Down
56 changes: 47 additions & 9 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
ServerStatus,
add_api_key_middleware,
add_prometheus_middleware,
delete_directory,
Expand Down Expand Up @@ -220,8 +221,31 @@ async def validate_json_request(raw_request: Request):

@app.get("/health")
async def health() -> Response:
    """Report the server status as JSON.

    Returns 200 only when the server is fully Up; any other status
    (Starting, UnHealthy, Crashed) yields 503 so load balancers take
    the instance out of rotation. The body always carries the current
    status string.
    """
    status = _global_state.tokenizer_manager.server_status
    code = HTTPStatus.OK.value if status == ServerStatus.Up else HTTPStatus.SERVICE_UNAVAILABLE.value
    return Response(
        status_code=code,
        content=json.dumps({"status": status.value}),
        # Declare the content type explicitly; a bare Response defaults
        # to text, which misleads JSON-expecting clients.
        media_type="application/json",
    )


@app.post("/health")
async def health_update(obj: ReportHealthInput, request: Request) -> Response:
    """Update the server status reported by GET /health.

    `obj.status` must be one of the ServerStatus values. A non-Up status
    is stored and acknowledged with 503 (echoing `obj.msg`) so the caller
    sees the degraded state; an unknown status string is rejected with 400.
    """
    try:
        server_status = ServerStatus(obj.status)
    except ValueError:
        # Previously any failure was swallowed by a broad `except` and a
        # misleading 200 returned; an unknown status is a client error.
        logger.error("Invalid server status reported: %s", obj.status)
        return Response(status_code=HTTPStatus.BAD_REQUEST.value)
    _global_state.tokenizer_manager.server_status = server_status
    if server_status != ServerStatus.Up:
        return Response(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
        )
    return Response(status_code=HTTPStatus.OK.value)


@app.get("/health_generate")
Expand Down Expand Up @@ -256,7 +280,7 @@ async def gen():
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = False
_global_state.tokenizer_manager.server_status = ServerStatus.Up
return Response(status_code=200)

task.cancel()
Expand All @@ -270,7 +294,7 @@ async def gen():
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = True
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
return Response(status_code=503)


Expand Down Expand Up @@ -1022,9 +1046,13 @@ def _execute_server_warmup(
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
if res.status_code == 200:
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
logger.info(f"{res}")
else:
logger.info(f"Start of prefill warmup ...")
logger.info(f"Start of prefill/decode warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
Expand All @@ -1046,15 +1074,25 @@ def _execute_server_warmup(
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
if res.status_code == 200:
logger.info(
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
)
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
res.status_code
)
)
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy

except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
_global_state.tokenizer_manager.server_status = ServerStatus.Crashed
kill_process_tree(os.getpid())
return False

Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,3 +1083,9 @@ class LoRAUpdateResult:


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


@dataclass
class ReportHealthInput:
    """Request payload for POST /health: a component's self-reported status."""

    # Must match a ServerStatus enum value (e.g. "Up", "Crashed");
    # the endpoint converts it with ServerStatus(status).
    status: str
    # Optional human-readable detail, echoed back in the 503 response body
    # when a non-Up status is reported.
    msg: Optional[str] = ""
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
ServerStatus,
broadcast_pyobj,
configure_gc_logger,
configure_logger,
Expand All @@ -154,6 +155,7 @@
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
report_health,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
Expand Down Expand Up @@ -2953,4 +2955,5 @@ def run_scheduler_process(
except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
parent_process.send_signal(signal.SIGQUIT)
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
ServerStatus,
dataclass_to_string_truncated,
get_bool_env_var,
get_zmq_socket,
Expand Down Expand Up @@ -173,6 +174,9 @@ def __init__(
server_args: ServerArgs,
port_args: PortArgs,
):
# Server Status
self.server_status = ServerStatus.Starting

# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
Expand Down Expand Up @@ -251,7 +255,6 @@ def __init__(
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.dump_requests_folder = "" # By default do not dump
Expand Down Expand Up @@ -1324,7 +1327,7 @@ async def sigterm_watchdog(self):
while True:
remain_num_req = len(self.rid_to_state)

if self.health_check_failed:
if not self.server_status.is_healthy():
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
Expand Down
16 changes: 16 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@
HIP_FP8_E4M3_FNUZ_MAX = 224.0


class ServerStatus(Enum):
    """Lifecycle/health states of the server, exposed via /health."""

    Up = "Up"  # warmup finished, serving normally
    Starting = "Starting"  # process created, warmup not yet complete
    UnHealthy = "UnHealthy"  # health-check or warmup failure detected
    Crashed = "Crashed"  # a child process died or a fatal exception occurred

    def is_healthy(self) -> bool:
        # Only Up counts as healthy; Starting is deliberately not healthy yet.
        return self == ServerStatus.Up


def report_health(status: ServerStatus, host: str, http_port: int, msg: str = "") -> None:
    """Best-effort POST of *status* to the server's /health endpoint.

    This is called from crash and signal-handler paths, so it must never
    raise or hang: a dead or unreachable HTTP server would otherwise mask
    the original failure. The request is bounded by a timeout and any
    network error is swallowed.
    """
    try:
        requests.post(
            f"http://{host}:{http_port}/health",
            json={"status": status.value, "msg": msg},
            timeout=5,
        )
    except requests.RequestException:
        # The HTTP server may already be down (e.g. during a crash);
        # reporting is best-effort, so ignore delivery failures.
        pass


# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
return torch.version.hip is not None
Expand Down
Loading