163 changes: 162 additions & 1 deletion latencypredictor/training_server.py
@@ -105,13 +105,34 @@ class Settings:
RETRAINING_INTERVAL_SEC: int = int(os.getenv("LATENCY_RETRAINING_INTERVAL_SEC", 1800))
MIN_SAMPLES_FOR_RETRAIN_FRESH: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_RETRAIN_FRESH", 10))
MIN_SAMPLES_FOR_RETRAIN: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_RETRAIN", 1000))
MAX_TRAINING_DATA_SIZE_PER_BUCKET: int = int(os.getenv("LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET", 10000))
MAX_TRAINING_DATA_SIZE_PER_BUCKET: int = int(os.getenv("LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET", 500))
TEST_TRAIN_RATIO: float = float(os.getenv("LATENCY_TEST_TRAIN_RATIO", "0.1")) # Default 0.1 (10% test, 90% train)
MAX_TEST_DATA_SIZE: int = int(os.getenv("LATENCY_MAX_TEST_DATA_SIZE", "1000")) # Max test samples to keep
MODEL_TYPE: str = os.getenv("LATENCY_MODEL_TYPE", "xgboost") # Default to XGBoost
QUANTILE_ALPHA: float = float(os.getenv("LATENCY_QUANTILE_ALPHA", "0.9")) # p90 quantile
OBJECTIVE_TYPE: ObjectiveType = ObjectiveType(os.getenv("LATENCY_OBJECTIVE_TYPE", "quantile"))
SAMPLE_WEIGHTING_FOR_PREFIX_CACHE: bool = os.getenv("LATENCY_SAMPLE_WEIGHTING_FOR_PREFIX_CACHE", "false").lower() == "true"
ENSEMBLE_MODE: bool = os.getenv("LATENCY_ENSEMBLE_MODE", "true").lower() == "true"
MIN_SAMPLES_FOR_ENSEMBLE_SPLIT: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_ENSEMBLE_SPLIT", "200"))

# Gated ensemble model paths (each wraps noqueue + queued sub-models)
TTFT_GATED_MODEL_PATH: str = os.getenv("LATENCY_TTFT_GATED_MODEL_PATH", "/tmp/models/ttft_gated.joblib")
TPOT_GATED_MODEL_PATH: str = os.getenv("LATENCY_TPOT_GATED_MODEL_PATH", "/tmp/models/tpot_gated.joblib")
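
# Example override via environment (assumed invocation; values illustrative):
#   LATENCY_ENSEMBLE_MODE=true \
#   LATENCY_MIN_SAMPLES_FOR_ENSEMBLE_SPLIT=100 \
#   python training_server.py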


class QueueGatedModel:
"""Wraps noqueue + queued sub-models into one joblib-serializable object.

At prediction time the caller checks num_request_waiting and picks the
appropriate sub-model + scaler from inside this wrapper.
"""

def __init__(self, noqueue_model, queued_model,
noqueue_scaler=None, queued_scaler=None):
self.noqueue_model = noqueue_model
self.queued_model = queued_model
self.noqueue_scaler = noqueue_scaler
self.queued_scaler = queued_scaler
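
# A minimal sketch of the gating the docstring above describes. The helper
# name, and the assumption that `features` was already prepared for the
# matching regime (queue columns dropped for noqueue), are illustrative and
# not part of this change:
def _predict_with_gate(gated: QueueGatedModel, features, num_request_waiting: int):
    if num_request_waiting == 0:
        model, scaler = gated.noqueue_model, gated.noqueue_scaler
    else:
        model, scaler = gated.queued_model, gated.queued_scaler
    if scaler is not None:  # Bayesian Ridge sub-models carry fitted scalers
        features = scaler.transform(features)
    return model.predict(features)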


settings = Settings()
@@ -280,6 +301,11 @@ def __init__(self, model_type: str = None):
self.ttft_coefficients = None # Will store descaled coefficients as dict
self.tpot_coefficients = None # Will store descaled coefficients as dict

# Gated ensemble model wrappers (QueueGatedModel instances)
self.ttft_gated: Optional[QueueGatedModel] = None
self.tpot_gated: Optional[QueueGatedModel] = None
self.ensemble_active: bool = False

self.lock = threading.Lock()
self.last_retrain_time = None
self._shutdown_event = threading.Event()
@@ -316,6 +342,36 @@ def _get_bucket_key(self, sample: dict) -> tuple:

return (queue_bucket, cache_bucket, prefix_bucket)

def _split_samples_by_queue(self, buckets: dict) -> Tuple[list, list]:
"""Split bucket samples into noqueue (queue_bucket==0) and queued (queue_bucket>=1).

Returns:
(noqueue_samples, queued_samples)
"""
noqueue_samples = []
queued_samples = []
for (queue_bucket, cache_bucket, prefix_bucket), bucket_deque in buckets.items():
if queue_bucket == 0:
noqueue_samples.extend(bucket_deque)
else:
queued_samples.extend(bucket_deque)
return noqueue_samples, queued_samples
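
# Illustration with hypothetical buckets -- only queue_bucket matters:
#   {(0, 1, 0): deque([s1]), (2, 0, 1): deque([s2, s3])}
#   -> noqueue_samples == [s1], queued_samples == [s2, s3]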

def _prepare_features_for_ensemble(self, df: pd.DataFrame, model_type: str, queue_regime: str) -> pd.DataFrame:
"""Prepare features for ensemble sub-models.

Args:
df: DataFrame with raw features
model_type: 'ttft' or 'tpot'
queue_regime: 'noqueue' or 'queued'
Returns:
DataFrame with engineered features, with queue columns dropped for noqueue regime
"""
features = self._prepare_features_with_interaction(df, model_type)
if queue_regime == "noqueue":
features = features.drop(columns=['is_queued', 'num_request_waiting'], errors='ignore')
return features
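
# Illustrative effect: in the noqueue regime, 'is_queued' and
# 'num_request_waiting' are removed, presumably because queue depth is
# (near-)constant zero in that training slice and would carry no signal.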

def _store_descaled_coefficients(self, model, scaler, feature_names, model_name):
"""
Store descaled coefficients for Bayesian Ridge models.
@@ -835,6 +891,71 @@ def train(self):
else:
logging.warning("Not enough TPOT samples, skipping TPOT training.")

# --- Ensemble (gated) training ---
new_ttft_gated = None
new_tpot_gated = None

if settings.ENSEMBLE_MODE:
try:
with self.lock:
ttft_noqueue, ttft_queued = self._split_samples_by_queue(self.ttft_data_buckets)
tpot_noqueue, tpot_queued = self._split_samples_by_queue(self.tpot_data_buckets)

min_split = settings.MIN_SAMPLES_FOR_ENSEMBLE_SPLIT
split_counts = {
"ttft_noqueue": len(ttft_noqueue),
"ttft_queued": len(ttft_queued),
"tpot_noqueue": len(tpot_noqueue),
"tpot_queued": len(tpot_queued),
}
logging.info(f"Ensemble split counts: {split_counts}, min required: {min_split}")

if all(cnt >= min_split for cnt in split_counts.values()):
sub_models = {} # key -> (model, scaler_or_None)

for key, samples, model_name, target_col, regime in [
("ttft_noqueue", ttft_noqueue, "ttft", "actual_ttft_ms", "noqueue"),
("ttft_queued", ttft_queued, "ttft", "actual_ttft_ms", "queued"),
("tpot_noqueue", tpot_noqueue, "tpot", "actual_tpot_ms", "noqueue"),
("tpot_queued", tpot_queued, "tpot", "actual_tpot_ms", "queued"),
]:
try:
raw = pd.DataFrame(samples).dropna()
raw = raw[raw[target_col] > 0]
X = self._prepare_features_for_ensemble(raw.copy(), model_name, regime)
y = raw[target_col]
drop_q = (regime == "noqueue")
result = self._train_model_with_scaling(X, y, model_name=model_name, drop_queue_features=drop_q)
if self.model_type == ModelType.BAYESIAN_RIDGE:
sub_models[key] = result # (model, scaler)
else:
sub_models[key] = (result, None)
logging.info(f"{key} model trained on {len(raw)} samples")
except Exception:
logging.error(f"Error training {key} model", exc_info=True)

# Build gated wrappers only if all 4 sub-models trained
if len(sub_models) == 4:
new_ttft_gated = QueueGatedModel(
noqueue_model=sub_models["ttft_noqueue"][0],
queued_model=sub_models["ttft_queued"][0],
noqueue_scaler=sub_models["ttft_noqueue"][1],
queued_scaler=sub_models["ttft_queued"][1],
)
new_tpot_gated = QueueGatedModel(
noqueue_model=sub_models["tpot_noqueue"][0],
queued_model=sub_models["tpot_queued"][0],
noqueue_scaler=sub_models["tpot_noqueue"][1],
queued_scaler=sub_models["tpot_queued"][1],
)
logging.info("Ensemble training succeeded — built QueueGatedModel wrappers")
else:
logging.warning("Ensemble training failed for some sub-models, falling back to single model")
else:
logging.info("Insufficient samples for ensemble split, using single model only")
except Exception:
logging.error("Error in ensemble training block", exc_info=True)

with self.lock:
if new_ttft_model:
self.ttft_model = new_ttft_model
@@ -861,6 +982,14 @@ def train(self):
new_tpot_model, new_tpot_scaler, tpot_features, "TPOT"
)

# Update gated ensemble models
if new_ttft_gated and new_tpot_gated:
self.ttft_gated = new_ttft_gated
self.tpot_gated = new_tpot_gated
self.ensemble_active = True
else:
self.ensemble_active = False
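
# Both wrappers are swapped in together under self.lock, so readers
# never observe a mixed TTFT/TPOT ensemble generation.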

if self.is_ready:
self.last_retrain_time = datetime.now(timezone.utc)
try:
@@ -1094,6 +1223,17 @@ def _save_models_unlocked(self):
joblib.dump(self.tpot_scaler, settings.TPOT_SCALER_PATH)
logging.info("TPOT scaler saved.")

# Save gated ensemble models (2 files instead of 8)
if self.ensemble_active:
for gated, path, name in [
(self.ttft_gated, settings.TTFT_GATED_MODEL_PATH, "TTFT gated"),
(self.tpot_gated, settings.TPOT_GATED_MODEL_PATH, "TPOT gated"),
]:
if gated:
os.makedirs(os.path.dirname(path), exist_ok=True)
joblib.dump(gated, path)
logging.info(f"{name} ensemble model saved.")

except Exception as e:
logging.error(f"Error saving models: {e}", exc_info=True)

@@ -1220,6 +1360,8 @@ def get_metrics(self) -> str:
lines.append(f'model_type{{type="{self.model_type.value}"}} 1')
lines.append(f'objective_type{{type="{self.objective_type.value}"}} 1')
lines.append(f'model_quantile{{}} {self.quantile}')
lines.append(f'ensemble_active{{}} {1 if self.ensemble_active else 0}')
lines.append(f'ensemble_mode{{}} {1 if settings.ENSEMBLE_MODE else 0}')
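
# Example exposition output for the two gauges above (values illustrative):
#   ensemble_mode{} 1
#   ensemble_active{} 0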

# Helper: emit linear‐model coefs or tree importances
def emit_metrics(model, coefficients, feats, prefix):
@@ -1619,6 +1761,10 @@ async def get_data_status():
key = f"queue_{q}_cache_{c}_prefix_{p}"
bucket_distribution[key] = len(bucket)

# Compute per-regime sample counts
ttft_noqueue, ttft_queued = predictor._split_samples_by_queue(predictor.ttft_data_buckets)
tpot_noqueue, tpot_queued = predictor._split_samples_by_queue(predictor.tpot_data_buckets)

return {
"training_data": {
"ttft_samples": ttft_training_count,
@@ -1630,6 +1776,15 @@ async def get_data_status():
"tpot_samples": len(predictor.tpot_test_data),
"total_samples": len(predictor.ttft_test_data) + len(predictor.tpot_test_data)
},
"ensemble": {
"ensemble_mode": settings.ENSEMBLE_MODE,
"ensemble_active": predictor.ensemble_active,
"min_samples_for_split": settings.MIN_SAMPLES_FOR_ENSEMBLE_SPLIT,
"ttft_noqueue_samples": len(ttft_noqueue),
"ttft_queued_samples": len(ttft_queued),
"tpot_noqueue_samples": len(tpot_noqueue),
"tpot_queued_samples": len(tpot_queued),
},
"metrics": {
"ttft_scores_count": len(predictor.ttft_quantile_loss_scores),
"tpot_scores_count": len(predictor.tpot_quantile_loss_scores)
@@ -1741,6 +1896,8 @@ async def model_info(model_name: str):
"tpot": settings.TPOT_MODEL_PATH,
"ttft_scaler": settings.TTFT_SCALER_PATH,
"tpot_scaler": settings.TPOT_SCALER_PATH,
"ttft_gated": settings.TTFT_GATED_MODEL_PATH,
"tpot_gated": settings.TPOT_GATED_MODEL_PATH,
}

if model_name not in model_paths:
@@ -1774,6 +1931,8 @@ async def download_model(model_name: str):
"tpot": settings.TPOT_MODEL_PATH,
"ttft_scaler": settings.TTFT_SCALER_PATH,
"tpot_scaler": settings.TPOT_SCALER_PATH,
"ttft_gated": settings.TTFT_GATED_MODEL_PATH,
"tpot_gated": settings.TPOT_GATED_MODEL_PATH,
}

if model_name not in model_paths:
@@ -1802,6 +1961,8 @@ async def list_models():
"tpot": settings.TPOT_MODEL_PATH,
"ttft_scaler": settings.TTFT_SCALER_PATH,
"tpot_scaler": settings.TPOT_SCALER_PATH,
"ttft_gated": settings.TTFT_GATED_MODEL_PATH,
"tpot_gated": settings.TPOT_GATED_MODEL_PATH,
}

for model_name, model_path in model_paths.items():