Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpletuner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def cmd_server(args) -> int:
# Ensure a configuration directory exists and record it for downstream services
config_dir = get_config_directory()
os.environ.setdefault("SIMPLETUNER_CONFIG_DIR", str(config_dir))
os.environ.setdefault("SIMPLETUNER_SERVER_ROOT_PID", str(os.getpid()))

try:
import uvicorn
Expand Down
2 changes: 2 additions & 0 deletions simpletuner/service_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from simpletuner.simpletuner_sdk.server.utils.paths import get_config_directory, get_static_directory, get_template_directory
from simpletuner.simpletuner_sdk.training_host import TrainingHost

# Publish this process's PID as the server root PID unless the variable is
# already set (setdefault never overwrites an existing value).
os.environ.setdefault("SIMPLETUNER_SERVER_ROOT_PID", str(os.getpid()))


# Pydantic models for request/response
class TrainerConfig(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions simpletuner/simpletuner_sdk/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
# Track if we're shutting down to avoid duplicate cleanup
_shutting_down = False

# Record the PID of the first process that imported this module so the shutdown
# endpoint can signal parent reload/watchdog processes when running in dev mode.
# setdefault keeps any value that was already placed in the environment.
os.environ.setdefault("SIMPLETUNER_SERVER_ROOT_PID", str(os.getpid()))


# These placeholders allow tests to monkeypatch heavy imports before the app factory runs.
WebInterface = None
Expand Down
135 changes: 134 additions & 1 deletion simpletuner/simpletuner_sdk/server/routes/system.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,131 @@
"""System status API endpoints."""

import asyncio
import logging
import os
import signal
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, BackgroundTasks, HTTPException

from simpletuner.simpletuner_sdk.server.services.maintenance_service import MAINTENANCE_SERVICE, MaintenanceServiceError
from simpletuner.simpletuner_sdk.server.services.system_status_service import SystemStatusService

# Router for the system endpoints; its routes are mounted under /api/system.
router = APIRouter(prefix="/api/system", tags=["system"])
# Module-level SystemStatusService instance, created once at import time.
_service = SystemStatusService()
logger = logging.getLogger("SimpleTunerSystemRoutes")

# Flipped to True by the POST /shutdown handler so that repeated requests do
# not schedule a second shutdown task.
_SHUTDOWN_IN_PROGRESS = False
# Default pause, in seconds, that _initiate_shutdown_sequence waits before
# signalling the processes to exit.
_SHUTDOWN_DELAY_SECONDS = 0.75


def _coerce_pid(value: Optional[str]) -> Optional[int]:
if not value:
return None
try:
pid = int(str(value))
except (TypeError, ValueError):
return None
if pid <= 1:
return None
return pid


async def _terminate_training_processes() -> int:
    """Request termination of every job tracked by ``process_keeper``.

    Returns the number of jobs whose ``terminate_process`` call completed
    without raising. Returns 0 when ``process_keeper`` cannot be imported or
    its process listing fails; individual termination failures are logged and
    skipped so shutdown can continue.
    """
    try:
        # Deferred import: loading this at module import time risks a circular dependency.
        from simpletuner.simpletuner_sdk import process_keeper
    except Exception as import_error:  # pragma: no cover - defensive in case module import fails
        logger.debug("process_keeper import failed during shutdown: %s", import_error)
        return 0

    try:
        tracked_jobs = process_keeper.list_processes() or {}
    except Exception as listing_error:  # pragma: no cover - best effort logging
        logger.warning("Unable to list training processes while shutting down: %s", listing_error)
        return 0

    stopped = 0
    # Snapshot the job ids up front so the loop does not depend on the live mapping.
    for job_id in tuple(tracked_jobs.keys()):
        try:
            # terminate_process is run in a worker thread so the event loop is not blocked.
            await asyncio.to_thread(process_keeper.terminate_process, job_id)
        except Exception as terminate_error:  # pragma: no cover - terminating should not crash shutdown
            logger.warning("Failed to terminate training job %s: %s", job_id, terminate_error)
        else:
            stopped += 1
    return stopped


def _signal_process_exit() -> None:
    """Signal supervising processes (including reload supervisors) to exit, then hard-exit.

    Candidate pids are gathered from the ``SIMPLETUNER_SERVER_ROOT_PID`` and
    ``SIMPLETUNER_SERVER_PARENT_PID`` environment variables and from
    ``multiprocessing.parent_process()``. Every candidate other than the
    current process is sent ``SIGINT``; ``SIGTERM`` is tried only when sending
    ``SIGINT`` raises or the signal is unavailable. This function does not
    return: it always ends with ``os._exit(0)``.
    """
    current_pid = os.getpid()
    target_pids = set()

    # Both variables are validated the same way; unusable values (unset,
    # non-numeric, <= 1) are ignored.
    for env_name in ("SIMPLETUNER_SERVER_ROOT_PID", "SIMPLETUNER_SERVER_PARENT_PID"):
        env_pid = _coerce_pid(os.environ.get(env_name))
        if env_pid:
            target_pids.add(env_pid)

    try:
        import multiprocessing

        parent_proc = multiprocessing.parent_process()
    except Exception:  # pragma: no cover - best effort reflection
        parent_proc = None

    if parent_proc and parent_proc.pid and parent_proc.pid > 1:
        target_pids.add(parent_proc.pid)

    # Never send a signal to the current worker via os.kill; we'll exit via os._exit below.
    target_pids.discard(current_pid)

    sent_signal = False
    for target_pid in target_pids:
        for sig_name in ("SIGINT", "SIGTERM"):
            sig = getattr(signal, sig_name, None)
            if sig is None:
                # This signal is not defined by the signal module here; try the next one.
                continue
            try:
                os.kill(target_pid, sig)
            except Exception as exc:  # pragma: no cover - best effort logging
                logger.debug("Sending %s to pid %s failed: %s", sig_name, target_pid, exc)
                continue
            sent_signal = True
            # One delivered signal per target is enough; skip the fallback.
            break

    if target_pids and not sent_signal:
        logger.warning("Unable to deliver shutdown signals to parent processes; forcing exit")

    if parent_proc:
        try:
            parent_proc.terminate()
        except Exception:  # pragma: no cover - best effort logging
            logger.debug("Unable to terminate parent process via multiprocessing API", exc_info=True)

    # Finally exit the current worker immediately regardless of outstanding ASGI tasks.
    os._exit(0)


async def _initiate_shutdown_sequence(delay_seconds: float = _SHUTDOWN_DELAY_SECONDS) -> None:
    """Run the shutdown sequence as a background task.

    Stops tracked training processes, waits ``delay_seconds`` (never less
    than 0.1 s), then calls ``_signal_process_exit``. Failures while stopping
    training jobs are logged and do not abort the shutdown.
    """
    logger.info("Shutdown requested via API; beginning graceful termination.")

    try:
        stopped_jobs = await _terminate_training_processes()
        if stopped_jobs:
            logger.info("Requested termination for %s training process(es) before shutdown.", stopped_jobs)
    except Exception as shutdown_error:  # pragma: no cover - defensive
        logger.warning("Error while terminating training jobs during shutdown: %s", shutdown_error)

    # The wait is clamped to a 0.1 s floor, so a zero or negative delay still pauses briefly.
    grace_period = max(delay_seconds, 0.1)
    await asyncio.sleep(grace_period)
    _signal_process_exit()


@router.get("/status")
Expand Down Expand Up @@ -36,3 +153,19 @@ async def clear_deepspeed_offload(payload: Optional[Dict[str, Any]] = None) -> D
return MAINTENANCE_SERVICE.clear_deepspeed_offload_cache(config_name=config_name)
except MaintenanceServiceError as exc:
raise HTTPException(status_code=400, detail=exc.message) from exc


@router.post("/shutdown")
async def shutdown_simpletuner(background_tasks: BackgroundTasks) -> Dict[str, Any]:
    """Schedule a graceful shutdown of the SimpleTuner process.

    The first call marks shutdown as in progress and queues
    ``_initiate_shutdown_sequence`` as a background task. Later calls only
    report that shutdown is already under way. Both cases return a payload
    with ``status`` set to ``"shutting_down"``.
    """
    global _SHUTDOWN_IN_PROGRESS

    if not _SHUTDOWN_IN_PROGRESS:
        # Set the flag before queuing so a repeat request cannot queue a second task.
        _SHUTDOWN_IN_PROGRESS = True
        background_tasks.add_task(_initiate_shutdown_sequence)
        message = "SimpleTuner is shutting down. Active training processes will be stopped."
    else:
        message = "Shutdown already in progress."

    return {"status": "shutting_down", "message": message}
8 changes: 8 additions & 0 deletions simpletuner/templates/base_htmx.html
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,14 @@
statusEl._x_dataStack[0].status = status;
statusEl._x_dataStack[0].message = message;
}
// Broadcast status updates so other components can react (e.g., re-enable buttons on reconnect)
try {
window.dispatchEvent(new CustomEvent('trainer-connection-status', {
detail: { status, message }
}));
} catch (error) {
console.warn('Failed to dispatch trainer-connection-status event', error);
}
}

// Server-Sent Events for real-time updates using SSE Manager
Expand Down
Loading