Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions latencypredictor/training_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class Settings:
SAMPLE_WEIGHTING_FOR_PREFIX_CACHE: bool = os.getenv("LATENCY_SAMPLE_WEIGHTING_FOR_PREFIX_CACHE", "false").lower() == "true"
ENSEMBLE_MODE: bool = os.getenv("LATENCY_ENSEMBLE_MODE", "true").lower() == "true"
MIN_SAMPLES_FOR_ENSEMBLE_SPLIT: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_ENSEMBLE_SPLIT", "200"))
TPOT_ZERO_TOKEN_COUNT: bool = os.getenv("LATENCY_TPOT_ZERO_TOKEN_COUNT", "true").lower() == "true"

# Gated ensemble model paths (each wraps noqueue + queued sub-models)
TTFT_GATED_MODEL_PATH: str = os.getenv("LATENCY_TTFT_GATED_MODEL_PATH", "/tmp/models/ttft_gated.joblib")
Expand Down Expand Up @@ -851,6 +852,8 @@ def train(self):
if tpot_snap:
df_tpot = pd.DataFrame(tpot_snap).dropna()
df_tpot = df_tpot[df_tpot['actual_tpot_ms'] > 0]
if settings.TPOT_ZERO_TOKEN_COUNT:
df_tpot['num_tokens_generated'] = 0
if len(df_tpot) >= settings.MIN_SAMPLES_FOR_RETRAIN:
# TPOT features - use feature preparation to add pod_type_cat
X_tpot = self._prepare_features_with_interaction(df_tpot.copy(), model_type="tpot")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,9 @@ func processTokenForLatencyPrediction(
latencyMs := float64(now.Sub(predictedLatencyCtx.lastTokenTimestamp).Milliseconds())
predictedLatencyCtx.generatedTokenCount++

// log the inter-token latency for predicted samples
if predictedLatencyCtx.generatedTokenCount == 2 || predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) { // tricky logic, since next sample token is always +1 from current token
// record sampled TPOT observations (avgTPOT is computed in ResponseComplete from e2e latency)
if predictedLatencyCtx.generatedTokenCount == 2 || predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {
predictedLatencyCtx.tpotObservations = append(predictedLatencyCtx.tpotObservations, latencyMs)
predictedLatencyCtx.avgTPOT = calculateRunningAverage(predictedLatencyCtx.avgTPOT, latencyMs, len(predictedLatencyCtx.tpotObservations))
}
if predictedLatencyCtx.generatedTokenCount == 2 {
// debug log actual and predicted tpot
Expand All @@ -289,26 +288,14 @@ func processTokenForLatencyPrediction(
"predicted_tpot_ms", predictedLatencyCtx.avgPredictedTPOT)
}

// TPOT training is now done once per request in ResponseComplete using avgTPOT

m, err := getLatestMetricsForProfile(predictedLatencyCtx, "")
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping TPOT training due to missing metrics or schedulingResult",
logger.V(logutil.DEBUG).Info("Skipping TPOT prediction due to missing metrics or schedulingResult",
"error", err)
return
}
entry := buildTrainingEntry(
endpointRoleLabel,
targetEndpointMetadata,
m,
predictedLatencyCtx.promptText,
0, // TTFT not recorded for TPOT
latencyMs,
now,
predictedLatencyCtx.generatedTokenCount-1,
0, // TPOT does not use prefix cache score
)
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
}

// Sampled predict
if predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
)

const (
Expand Down Expand Up @@ -247,13 +248,37 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
}
}

// Compute avgTPOT as (e2e - ttft) / (tokens - 1) for a more accurate overall average
if predictedLatencyCtx.ttft > 0 && predictedLatencyCtx.generatedTokenCount > 1 {
e2eMs := float64(now.Sub(predictedLatencyCtx.requestReceivedTimestamp).Milliseconds())
predictedLatencyCtx.avgTPOT = (e2eMs - predictedLatencyCtx.ttft) / float64(predictedLatencyCtx.generatedTokenCount-1)
}

if predictedLatencyCtx.avgTPOT > 0 {
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", predictedLatencyCtx.avgTPOT, "avgPredictedTPOT", predictedLatencyCtx.avgPredictedTPOT)
metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
metrics.RecordRequestPredictedTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgPredictedTPOT/1000)
if predictedLatencyCtx.avgTPOTSLO > 0 {
metrics.RecordRequestTPOTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT, predictedLatencyCtx.avgTPOTSLO)
}

// Record one TPOT training entry per request using avgTPOT and dispatch-time metrics
if m, err := getLatestMetricsForProfile(predictedLatencyCtx, ""); err == nil {
entry := buildTrainingEntry(
t.config.EndpointRoleLabel,
targetMetadata,
m,
predictedLatencyCtx.promptText,
0, // TTFT not recorded for TPOT
predictedLatencyCtx.avgTPOT,
now,
0, // not used for TPOT prediction
0, // TPOT does not use prefix cache score
)
if err := t.latencypredictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
}
}
}

id := request.Headers[reqcommon.RequestIdHeaderKey]
Expand Down