Skip to content

Commit e6c875c

Browse files
kaushikmitr and BizerNotNull
authored and committed
latencypredictor: improve TPOT training accuracy (kubernetes-sigs#2509)
1 parent b34a8ba commit e6c875c

3 files changed

Lines changed: 33 additions & 18 deletions

File tree

latencypredictor/training_server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class Settings:
114114
SAMPLE_WEIGHTING_FOR_PREFIX_CACHE: bool = os.getenv("LATENCY_SAMPLE_WEIGHTING_FOR_PREFIX_CACHE", "false").lower() == "true"
115115
ENSEMBLE_MODE: bool = os.getenv("LATENCY_ENSEMBLE_MODE", "true").lower() == "true"
116116
MIN_SAMPLES_FOR_ENSEMBLE_SPLIT: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_ENSEMBLE_SPLIT", "200"))
117+
TPOT_ZERO_TOKEN_COUNT: bool = os.getenv("LATENCY_TPOT_ZERO_TOKEN_COUNT", "true").lower() == "true"
117118

118119
# Gated ensemble model paths (each wraps noqueue + queued sub-models)
119120
TTFT_GATED_MODEL_PATH: str = os.getenv("LATENCY_TTFT_GATED_MODEL_PATH", "/tmp/models/ttft_gated.joblib")
@@ -851,6 +852,8 @@ def train(self):
851852
if tpot_snap:
852853
df_tpot = pd.DataFrame(tpot_snap).dropna()
853854
df_tpot = df_tpot[df_tpot['actual_tpot_ms'] > 0]
855+
if settings.TPOT_ZERO_TOKEN_COUNT:
856+
df_tpot['num_tokens_generated'] = 0
854857
if len(df_tpot) >= settings.MIN_SAMPLES_FOR_RETRAIN:
855858
# TPOT features - use feature preparation to add pod_type_cat
856859
X_tpot = self._prepare_features_with_interaction(df_tpot.copy(), model_type="tpot")

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,9 @@ func processTokenForLatencyPrediction(
277277
latencyMs := float64(now.Sub(predictedLatencyCtx.lastTokenTimestamp).Milliseconds())
278278
predictedLatencyCtx.generatedTokenCount++
279279

280-
// log the inter-token latency for predicted samples
281-
if predictedLatencyCtx.generatedTokenCount == 2 || predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) { // tricky logic, since next sample token is always +1 from current token
280+
// record sampled TPOT observations (avgTPOT is computed in ResponseComplete from e2e latency)
281+
if predictedLatencyCtx.generatedTokenCount == 2 || predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {
282282
predictedLatencyCtx.tpotObservations = append(predictedLatencyCtx.tpotObservations, latencyMs)
283-
predictedLatencyCtx.avgTPOT = calculateRunningAverage(predictedLatencyCtx.avgTPOT, latencyMs, len(predictedLatencyCtx.tpotObservations))
284283
}
285284
if predictedLatencyCtx.generatedTokenCount == 2 {
286285
// debug log actual and predicted tpot
@@ -289,26 +288,14 @@ func processTokenForLatencyPrediction(
289288
"predicted_tpot_ms", predictedLatencyCtx.avgPredictedTPOT)
290289
}
291290

291+
// TPOT training is now done once per request in ResponseComplete using avgTPOT
292+
292293
m, err := getLatestMetricsForProfile(predictedLatencyCtx, "")
293294
if err != nil {
294-
logger.V(logutil.DEBUG).Info("Skipping TPOT training due to missing metrics or schedulingResult",
295+
logger.V(logutil.DEBUG).Info("Skipping TPOT prediction due to missing metrics or schedulingResult",
295296
"error", err)
296297
return
297298
}
298-
entry := buildTrainingEntry(
299-
endpointRoleLabel,
300-
targetEndpointMetadata,
301-
m,
302-
predictedLatencyCtx.promptText,
303-
0, // TTFT not recorded for TPOT
304-
latencyMs,
305-
now,
306-
predictedLatencyCtx.generatedTokenCount-1,
307-
0, // TPOT does not use prefix cache score
308-
)
309-
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
310-
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
311-
}
312299

313300
// Sampled predict
314301
if predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
3434
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
3535
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
36+
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
3637
)
3738

3839
const (
@@ -247,13 +248,37 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
247248
}
248249
}
249250

251+
// Compute avgTPOT as (e2e - ttft) / (tokens - 1) for a more accurate overall average
252+
if predictedLatencyCtx.ttft > 0 && predictedLatencyCtx.generatedTokenCount > 1 {
253+
e2eMs := float64(now.Sub(predictedLatencyCtx.requestReceivedTimestamp).Milliseconds())
254+
predictedLatencyCtx.avgTPOT = (e2eMs - predictedLatencyCtx.ttft) / float64(predictedLatencyCtx.generatedTokenCount-1)
255+
}
256+
250257
if predictedLatencyCtx.avgTPOT > 0 {
251258
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", predictedLatencyCtx.avgTPOT, "avgPredictedTPOT", predictedLatencyCtx.avgPredictedTPOT)
252259
metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
253260
metrics.RecordRequestPredictedTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgPredictedTPOT/1000)
254261
if predictedLatencyCtx.avgTPOTSLO > 0 {
255262
metrics.RecordRequestTPOTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT, predictedLatencyCtx.avgTPOTSLO)
256263
}
264+
265+
// Record one TPOT training entry per request using avgTPOT and dispatch-time metrics
266+
if m, err := getLatestMetricsForProfile(predictedLatencyCtx, ""); err == nil {
267+
entry := buildTrainingEntry(
268+
t.config.EndpointRoleLabel,
269+
targetMetadata,
270+
m,
271+
predictedLatencyCtx.promptText,
272+
0, // TTFT not recorded for TPOT
273+
predictedLatencyCtx.avgTPOT,
274+
now,
275+
0, // not used for TPOT prediction
276+
0, // TPOT does not use prefix cache score
277+
)
278+
if err := t.latencypredictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
279+
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
280+
}
281+
}
257282
}
258283

259284
id := request.Headers[reqcommon.RequestIdHeaderKey]

0 commit comments

Comments (0)