diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 6c0f4be7b..4ed109af1 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -20,8 +20,22 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" type SchedulerConfig struct { preSchedulePlugins []plugins.PreSchedule - scorers []plugins.Scorer filters []plugins.Filter - postSchedulePlugins []plugins.PostSchedule + scorers map[plugins.Scorer]int // map from scorer to weight picker plugins.Picker + postSchedulePlugins []plugins.PostSchedule +} + +var defPlugin = &defaultPlugin{} + +// When the scheduler is initialized with NewScheduler function, this config will be used as default. +// it's possible to call NewSchedulerWithConfig to pass a different argument. + +// For build time plugins changes, it's recommended to change the defaultConfig variable in this file. +var defaultConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{defPlugin}, + scorers: map[plugins.Scorer]int{}, + picker: defPlugin, + postSchedulePlugins: []plugins.PostSchedule{}, } diff --git a/pkg/epp/scheduling/default_config.go b/pkg/epp/scheduling/default_config.go deleted file mode 100644 index e42f13179..000000000 --- a/pkg/epp/scheduling/default_config.go +++ /dev/null @@ -1,31 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scheduling - -import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" -) - -var defPlugin = &defaultPlugin{} - -var defaultConfig = &SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{}, - scorers: []plugins.Scorer{}, - filters: []plugins.Filter{defPlugin}, - postSchedulePlugins: []plugins.PostSchedule{}, - picker: defPlugin, -} diff --git a/pkg/epp/scheduling/plugins/filter/filter_test.go b/pkg/epp/scheduling/plugins/filter/filter_test.go index 56cccb3b8..a06ec3caa 100644 --- a/pkg/epp/scheduling/plugins/filter/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter/filter_test.go @@ -54,8 +54,7 @@ func TestFilter(t *testing.T) { ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) got := test.filter.Filter(ctx, test.input) - opt := cmp.AllowUnexported(types.PodMetrics{}) - if diff := cmp.Diff(test.output, got, opt); diff != "" { + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -190,8 +189,7 @@ func TestFilterFunc(t *testing.T) { ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) got := test.f(ctx, test.input) - opt := cmp.AllowUnexported(types.PodMetrics{}) - if diff := cmp.Diff(test.output, got, opt); diff != "" { + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) diff --git a/pkg/epp/scheduling/plugins/noop.go b/pkg/epp/scheduling/plugins/noop.go deleted file mode 100644 index 8f50ff36e..000000000 --- a/pkg/epp/scheduling/plugins/noop.go +++ /dev/null @@ -1,42 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package plugins - -import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" -) - -// NoopPlugin provides a default, no-operation implementation of the Plugin interface. -// It can be embedded in other plugin implementations to avoid boilerplate code for -// unused methods. -type NoopPlugin struct{} - -func (p *NoopPlugin) Name() string { return "NoopPlugin" } - -func (p *NoopPlugin) PreSchedule(ctx *types.SchedulingContext) {} - -func (p *NoopPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) ([]types.Pod, error) { - return pods, nil -} - -func (p *NoopPlugin) Score(ctx *types.SchedulingContext, pod types.Pod) (float64, error) { - return 0.0, nil -} - -func (p *NoopPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {} - -func (p *NoopPlugin) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {} diff --git a/pkg/epp/scheduling/plugins/picker/random_picker.go b/pkg/epp/scheduling/plugins/picker/random_picker.go index 850108e7e..6eecbb0da 100644 --- a/pkg/epp/scheduling/plugins/picker/random_picker.go +++ b/pkg/epp/scheduling/plugins/picker/random_picker.go @@ -20,18 +20,22 @@ import ( "fmt" "math/rand" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +var _ plugins.Picker = &RandomPicker{} + +// RandomPicker picks a random pod from the list of candidates. type RandomPicker struct{} func (rp *RandomPicker) Name() string { return "random" } -func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result { - ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) - i := rand.Intn(len(pods)) - return &types.Result{TargetPod: pods[i]} +func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result { + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods)) + i := rand.Intn(len(scoredPods)) + return &types.Result{TargetPod: scoredPods[i].Pod} } diff --git a/pkg/epp/scheduling/plugins/plugins.go b/pkg/epp/scheduling/plugins/plugins.go index 4b334803b..f3412ab72 100644 --- a/pkg/epp/scheduling/plugins/plugins.go +++ b/pkg/epp/scheduling/plugins/plugins.go @@ -49,22 +49,23 @@ type Filter interface { Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod } -// Scorer defines the interface for scoring pods based on context. +// Scorer defines the interface for scoring a list of pods based on context. +// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score. type Scorer interface { Plugin - Score(ctx *types.SchedulingContext, pod types.Pod) float64 + Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 } -// PostSchedule is called by the scheduler after it selects a targetPod for the request. -type PostSchedule interface { +// Picker picks the final pod(s) to send the request to. +type Picker interface { Plugin - PostSchedule(ctx *types.SchedulingContext, res *types.Result) + Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result } -// Picker picks the final pod(s) to send the request to. -type Picker interface { +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { Plugin - Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result + PostSchedule(ctx *types.SchedulingContext, res *types.Result) } // PostResponse is called by the scheduler after a successful response was sent. diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 322f714f4..04d24ea24 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler { } func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler { - scheduler := &Scheduler{ + return &Scheduler{ datastore: datastore, preSchedulePlugins: config.preSchedulePlugins, - scorers: config.scorers, filters: config.filters, - postSchedulePlugins: config.postSchedulePlugins, + scorers: config.scorers, picker: config.picker, + postSchedulePlugins: config.postSchedulePlugins, } - - return scheduler } type Scheduler struct { datastore Datastore preSchedulePlugins []plugins.PreSchedule filters []plugins.Filter - scorers []plugins.Scorer - postSchedulePlugins []plugins.PostSchedule + scorers map[plugins.Scorer]int // map from scorer to its weight picker plugins.Picker + postSchedulePlugins []plugins.PostSchedule } type Datastore interface { @@ -106,7 +104,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) - loggerDebug.Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) + loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) @@ -114,17 +112,14 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types if len(pods) == 0 { return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "failed to find a target pod"} } + // if we got here, there is at least one pod to score + weightedScorePerPod := s.runScorerPlugins(sCtx, pods) - s.runScorerPlugins(sCtx, pods) - - before := time.Now() - res := s.picker.Pick(sCtx, pods) - metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before)) - loggerDebug.Info("After running picker plugins", "result", res) + result := s.runPickerPlugin(sCtx, weightedScorePerPod) - s.runPostSchedulePlugins(sCtx, res) + s.runPostSchedulePlugins(sCtx, result) - return res, nil + return result, nil } func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) { @@ -136,15 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) { } } -func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) { - for _, plugin := range s.postSchedulePlugins { - ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name()) - before := time.Now() - plugin.PostSchedule(ctx, res) - metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before)) - } -} - func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod { loggerDebug := ctx.Logger.V(logutil.DEBUG) filteredPods := ctx.PodsSnapshot @@ -160,32 +146,60 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod { break } } + loggerDebug.Info("After running filter plugins") + return filteredPods } -func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) { +func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { loggerDebug := ctx.Logger.V(logutil.DEBUG) - loggerDebug.Info("Before running score plugins", "pods", pods) + loggerDebug.Info("Before running scorer plugins", "pods", pods) + + weightedScorePerPod := make(map[types.Pod]float64, len(pods)) for _, pod := range pods { - score := s.runScorersForPod(ctx, pod) - pod.SetScore(score) + weightedScorePerPod[pod] = float64(0) // initialize weighted score per pod with 0 value + } + // Iterate through each scorer in the chain and accumulate the weighted scores. + for scorer, weight := range s.scorers { + loggerDebug.Info("Running scorer", "scorer", scorer.Name()) + before := time.Now() + scores := scorer.Score(ctx, pods) + metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before)) + for pod, score := range scores { // weight is relative to the sum of weights + weightedScorePerPod[pod] += score * float64(weight) // TODO normalize score before multiply with weight + } + loggerDebug.Info("After running scorer", "scorer", scorer.Name()) + } + loggerDebug.Info("After running scorer plugins") + + return weightedScorePerPod +} + +func (s *Scheduler) runPickerPlugin(ctx *types.SchedulingContext, weightedScorePerPod map[types.Pod]float64) *types.Result { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod)) + i := 0 + for pod, score := range weightedScorePerPod { + scoredPods[i] = &types.ScoredPod{Pod: pod, Score: score} + i++ } - loggerDebug.Info("After running score plugins", "pods", pods) + + loggerDebug.Info("Before running picker plugin", "pods", weightedScorePerPod) + before := time.Now() + result := s.picker.Pick(ctx, scoredPods) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before)) + loggerDebug.Info("After running picker plugin", "result", result) + + return result } -// Iterate through each scorer in the chain and accumulate the scores. -func (s *Scheduler) runScorersForPod(ctx *types.SchedulingContext, pod types.Pod) float64 { - logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG) - score := float64(0) - for _, scorer := range s.scorers { - logger.Info("Running scorer", "scorer", scorer.Name()) +func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) { + for _, plugin := range s.postSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name()) before := time.Now() - oneScore := scorer.Score(ctx, pod) - metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before)) - score += oneScore - logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score) + plugin.PostSchedule(ctx, res) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before)) } - return score } type defaultPlugin struct { diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 2fb26a865..559f53f8b 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -220,24 +220,15 @@ func TestSchedule(t *testing.T) { }, } - schedConfig := &SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{}, - scorers: []plugins.Scorer{}, - filters: []plugins.Filter{defPlugin}, - postSchedulePlugins: []plugins.PostSchedule{}, - picker: defPlugin, - } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, schedConfig) + scheduler := NewScheduler(&fakeDataStore{pods: test.input}) got, err := scheduler.Schedule(context.Background(), test.req) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - opt := cmp.AllowUnexported(types.PodMetrics{}) - if diff := cmp.Diff(test.wantRes, got, opt); diff != "" { + if diff := cmp.Diff(test.wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -275,13 +266,16 @@ func TestSchedulePlugins(t *testing.T) { err bool }{ { - name: "all plugins executed successfully", + name: "all plugins executed successfully, all scorers with same weight", config: SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, - filters: []plugins.Filter{tp1, tp2}, - scorers: []plugins.Scorer{tp1, tp2}, - postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, }, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, @@ -294,13 +288,38 @@ func TestSchedulePlugins(t *testing.T) { err: false, }, { - name: "filter all", + name: "all plugins executed successfully, different scorers weights", config: SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, - filters: []plugins.Filter{tp1, tp_filterAll}, - scorers: []plugins.Scorer{tp1, tp2}, + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: map[plugins.Scorer]int{ + tp1: 60, + tp2: 40, + }, + picker: pickerPlugin, postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 50, + numPodsToScore: 2, + err: false, + }, + { + name: "filter all", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp_filterAll}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, }, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, @@ -318,16 +337,16 @@ func TestSchedulePlugins(t *testing.T) { for _, plugin := range test.config.preSchedulePlugins { plugin.(*TestPlugin).reset() } - for _, plugin := range test.config.postSchedulePlugins { - plugin.(*TestPlugin).reset() - } for _, plugin := range test.config.filters { plugin.(*TestPlugin).reset() } - for _, plugin := range test.config.scorers { + for plugin := range test.config.scorers { plugin.(*TestPlugin).reset() } test.config.picker.(*TestPlugin).reset() + for _, plugin := range test.config.postSchedulePlugins { + plugin.(*TestPlugin).reset() + } // Initialize the scheduler scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) @@ -345,13 +364,11 @@ func TestSchedulePlugins(t *testing.T) { } // Validate output - opt := cmp.AllowUnexported(types.PodMetrics{}) wantPod := &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: test.wantTargetPod}, } - wantPod.SetScore(test.targetPodScore) wantRes := &types.Result{TargetPod: wantPod} - if diff := cmp.Diff(wantRes, got, opt); diff != "" { + if diff := cmp.Diff(wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -359,36 +376,44 @@ func TestSchedulePlugins(t *testing.T) { for _, plugin := range test.config.preSchedulePlugins { tp, _ := plugin.(*TestPlugin) if tp.PreScheduleCallCount != 1 { - t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", tp.NameRes, tp.PreScheduleCallCount) + t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", plugin.Name(), tp.PreScheduleCallCount) } } for _, plugin := range test.config.filters { tp, _ := plugin.(*TestPlugin) if tp.FilterCallCount != 1 { - t.Errorf("Plugin %s Filter() called %d times, expected 1", tp.NameRes, tp.FilterCallCount) + t.Errorf("Plugin %s Filter() called %d times, expected 1", plugin.Name(), tp.FilterCallCount) } } - for _, plugin := range test.config.scorers { + for plugin := range test.config.scorers { tp, _ := plugin.(*TestPlugin) - if tp.ScoreCallCount != test.numPodsToScore { - t.Errorf("Plugin %s Score() called %d times, expected 1", tp.NameRes, tp.ScoreCallCount) + if tp.ScoreCallCount != 1 { + t.Errorf("Plugin %s Score() called %d times, expected 1", plugin.Name(), tp.ScoreCallCount) } - } - - for _, plugin := range test.config.postSchedulePlugins { - tp, _ := plugin.(*TestPlugin) - if tp.PostScheduleCallCount != 1 { - t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) + if test.numPodsToScore != tp.NumOfScoredPods { + t.Errorf("Plugin %s Score() called with %d pods, expected %d", plugin.Name(), tp.NumOfScoredPods, test.numPodsToScore) } } tp, _ := test.config.picker.(*TestPlugin) + if tp.NumOfPickerCandidates != test.numPodsToScore { + t.Errorf("Picker plugin %s Pick() called with %d candidates, expected %d", tp.Name(), tp.NumOfPickerCandidates, tp.NumOfScoredPods) + } if tp.PickCallCount != 1 { - t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.NameRes, tp.PickCallCount) + t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.Name(), tp.PickCallCount) + } + if tp.WinnderPodScore != test.targetPodScore { + t.Errorf("winnder pod score %v, expected %v", tp.WinnderPodScore, test.targetPodScore) } + for _, plugin := range test.config.postSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PostScheduleCallCount != 1 { + t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", plugin.Name(), tp.PostScheduleCallCount) + } + } }) } } @@ -409,13 +434,16 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { type TestPlugin struct { NameRes string ScoreCallCount int + NumOfScoredPods int ScoreRes float64 FilterCallCount int FilterRes []k8stypes.NamespacedName PreScheduleCallCount int PostScheduleCallCount int PickCallCount int + NumOfPickerCandidates int PickRes k8stypes.NamespacedName + WinnderPodScore float64 } func (tp *TestPlugin) Name() string { return tp.NameRes } @@ -427,29 +455,39 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ return findPods(ctx, tp.FilterRes...) -} -func (tp *TestPlugin) Score(ctx *types.SchedulingContext, pod types.Pod) float64 { - tp.ScoreCallCount++ - return tp.ScoreRes } -func (tp *TestPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { - tp.PostScheduleCallCount++ +func (tp *TestPlugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { + tp.ScoreCallCount++ + scoredPods := make(map[types.Pod]float64, len(pods)) + for _, pod := range pods { + scoredPods[pod] += tp.ScoreRes + } + tp.NumOfScoredPods = len(scoredPods) + return scoredPods } -func (tp *TestPlugin) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result { +func (tp *TestPlugin) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result { tp.PickCallCount++ + tp.NumOfPickerCandidates = len(scoredPods) pod := findPods(ctx, tp.PickRes)[0] + tp.WinnderPodScore = getPodScore(scoredPods, pod) return &types.Result{TargetPod: pod} } +func (tp *TestPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { + tp.PostScheduleCallCount++ +} + func (tp *TestPlugin) reset() { tp.PreScheduleCallCount = 0 tp.FilterCallCount = 0 tp.ScoreCallCount = 0 + tp.NumOfScoredPods = 0 tp.PostScheduleCallCount = 0 tp.PickCallCount = 0 + tp.NumOfPickerCandidates = 0 } func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { @@ -463,3 +501,14 @@ func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) [] } return res } + +func getPodScore(scoredPods []*types.ScoredPod, selectedPod types.Pod) float64 { + finalScore := 0.0 + for _, scoredPod := range scoredPods { + if scoredPod.Pod.GetPod().NamespacedName.String() == selectedPod.GetPod().NamespacedName.String() { + finalScore = scoredPod.Score + break + } + } + return finalScore +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index e66b5fb5d..5ccfbdcef 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -43,11 +43,14 @@ func (r *LLMRequest) String() string { type Pod interface { GetPod() *backendmetrics.Pod GetMetrics() *backendmetrics.Metrics - SetScore(float64) - Score() float64 String() string } +type ScoredPod struct { + Pod Pod + Score float64 +} + // SchedulingContext holds contextual information during a scheduling operation. type SchedulingContext struct { context.Context @@ -71,16 +74,7 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { return pm.Metrics } -func (pm *PodMetrics) SetScore(score float64) { - pm.score = score -} - -func (pm *PodMetrics) Score() float64 { - return pm.score -} - type PodMetrics struct { - score float64 *backendmetrics.Pod *backendmetrics.Metrics }