Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ import (
testresponsereceived "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/test/responsereceived"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector/framework/plugins/utilizationdetector"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/predicted_latency"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
Expand Down Expand Up @@ -414,7 +414,7 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(scorer.RunningRequestsSizeScorerType, scorer.RunningRequestsSizeScorerFactory)
plugins.Register(scorer.LoraAffinityScorerType, scorer.LoraAffinityScorerFactory)
// Latency predictor plugins
plugins.Register(slo_aware_router.SLOAwareRouterPluginType, slo_aware_router.SLOAwareRouterFactory)
plugins.Register(predicted_latency.PredictedLatencyPluginType, predicted_latency.PredictedLatencyFactory)
// register filter for test purpose only (used in conformance tests)
plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory)
// register response received plugin for test purpose only (used in conformance tests)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.
*/

// Package predicted_latency contains helpers to decouple latency-predictor logic.
package slo_aware_router
package predicted_latency

import (
"strconv"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package slo_aware_router
package predicted_latency

import (
"context"
Expand All @@ -26,7 +26,7 @@ import (
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []endpointPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Endpoint {
func (s *PredictedLatency) selectFromCompositeScores(ctx context.Context, allPreds []endpointPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Endpoint {
total := 0
choices := s.buildCompositeChoices(
ctx, allPreds, s.config.CompositeKVWeight, s.config.CompositeQueueWeight, s.config.CompositePrefixWeight, &total,
Expand All @@ -40,7 +40,7 @@ func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds
selectedEndpoint := s.performWeightedRandomSelection(choices, total, allPreds, r)
return selectedEndpoint
}
func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint {
func (s *PredictedLatency) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint {
if total == 0 {
return nil
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice

return selectedEndpoint
}
func (s *SLOAwareRouter) buildCompositeChoices(
func (s *PredictedLatency) buildCompositeChoices(
ctx context.Context,
candidates []endpointPredictionResult,
wkv, wq, wpref float64,
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package slo_aware_router
package predicted_latency

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.
*/

// Package predicted_latency contains helpers to decouple latency-predictor logic.
package slo_aware_router
package predicted_latency

import (
"context"
Expand All @@ -41,7 +41,7 @@ type endpointPredictionResult struct {
}

// generatePredictions creates prediction results for all candidate endpoints
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidateEndpoints []schedulingtypes.Endpoint) ([]endpointPredictionResult, error) {
func (s *PredictedLatency) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, predictedLatencyCtx *predictedLatencyCtx, candidateEndpoints []schedulingtypes.Endpoint) ([]endpointPredictionResult, error) {
logger := log.FromContext(ctx)
predictions := make([]endpointPredictionResult, 0, len(candidateEndpoints))

Expand All @@ -55,7 +55,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *sched
logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "endpoint", endpoint.GetMetadata().String(), "metrics", endpoint.GetMetrics().String())

// Get prefix cache score for the pod
prefixCacheScore := sloCtx.prefixCacheScoresForEndpoints[endpoint.GetMetadata().NamespacedName.Name]
prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[endpoint.GetMetadata().NamespacedName.Name]

logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", endpoint.GetMetadata().String(), "prefixCacheScore", prefixCacheScore)

Expand All @@ -82,7 +82,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *sched
predResult.TPOT = prediction.TPOT

podMinTPOTSLO := s.getEndpointMinTPOTSLO(endpoint)
predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO)
predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, predictedLatencyCtx, podMinTPOTSLO)

logger.V(logutil.DEBUG).Info("Prediction for scheduling",
"endpoint", endpoint.GetMetadata().String(),
Expand All @@ -91,8 +91,8 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *sched
"TPOT", prediction.TPOT,
"buffer", s.config.SLOBufferFactor,
"podMinTPOTSLO", podMinTPOTSLO,
"ttftSLO", sloCtx.ttftSLO,
"requestTPOTSLO", sloCtx.avgTPOTSLO,
"ttftSLO", predictedLatencyCtx.ttftSLO,
"requestTPOTSLO", predictedLatencyCtx.avgTPOTSLO,
"tpotHeadroom", predResult.Headroom,
"ttftHeadroom", predResult.TTFTHeadroom,
"tpotValid", predResult.TPOTValid,
Expand All @@ -106,28 +106,28 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *sched
}

// updateRequestContextWithPredictions updates the request context with prediction data
func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []endpointPredictionResult) {
sloCtx.predictionsForScheduling = predictions
func (s *PredictedLatency) updateRequestContextWithPredictions(predictedLatencyCtx *predictedLatencyCtx, predictions []endpointPredictionResult) {
predictedLatencyCtx.predictionsForScheduling = predictions
}

func (s *SLOAwareRouter) validatePrediction(
func (s *PredictedLatency) validatePrediction(
pred *latencypredictor.PredictionResponse,
sloCtx *sloRequestContext,
predictedLatencyCtx *predictedLatencyCtx,
podMinTPOTSLO float64,
) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) {

ttftOk = pred.TTFT < sloCtx.ttftSLO
ttftHeadroom = sloCtx.ttftSLO - pred.TTFT
ttftOk = pred.TTFT < predictedLatencyCtx.ttftSLO
ttftHeadroom = predictedLatencyCtx.ttftSLO - pred.TTFT

tpotOk = true
headroom = 0.0

if s.config.StreamingMode {
bufferedTPOT := sloCtx.avgTPOTSLO * s.config.SLOBufferFactor
bufferedTPOT := predictedLatencyCtx.avgTPOTSLO * s.config.SLOBufferFactor
// a podMinTPOTSLO of 0 means either no requests, or no TPOT SLOs specified on running requests
if podMinTPOTSLO > 0 {
if podMinTPOTSLO < sloCtx.avgTPOTSLO {
log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.avgTPOTSLO)
if podMinTPOTSLO < predictedLatencyCtx.avgTPOTSLO {
log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", predictedLatencyCtx.avgTPOTSLO)
}
bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*s.config.SLOBufferFactor)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package slo_aware_router
package predicted_latency

import (
"context"
Expand All @@ -28,9 +28,9 @@ import (
)

// PrepareRequestData prepares the SLO context for the request, including parsing SLO headers and gathering prefix cache scores and generating predictions.
func (s *SLOAwareRouter) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error {
func (s *PredictedLatency) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error {
logger := log.FromContext(ctx)
sloCtx := s.getOrMakeSLORequestContext(request)
sloCtx := s.getOrMakePredictedLatencyContextForRequest(request)

s.parseSLOHeaders(ctx, request, sloCtx)
var prefixCacheScore float64
Expand All @@ -51,14 +51,14 @@ func (s *SLOAwareRouter) PrepareRequestData(ctx context.Context, request *schedu
}
sloCtx.prefixCacheScoresForEndpoints[endpoint.GetMetadata().NamespacedName.Name] = prefixCacheScore
}
s.setSLOContextForRequest(request, sloCtx)
s.setPredictedLatencyContextForRequest(request, sloCtx)
return nil
}

func (p *SLOAwareRouter) Produces() map[string]any {
func (p *PredictedLatency) Produces() map[string]any {
return map[string]any{}
}

func (p *SLOAwareRouter) Consumes() map[string]any {
func (p *PredictedLatency) Consumes() map[string]any {
return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: approximateprefix.PrefixCacheMatchInfo{}}
}
Loading