@@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler {
7272}
7373
7474func NewSchedulerWithConfig (datastore Datastore , config * SchedulerConfig ) * Scheduler {
75- scheduler := & Scheduler {
75+ return & Scheduler {
7676 datastore : datastore ,
7777 preSchedulePlugins : config .preSchedulePlugins ,
78- scorers : config .scorers ,
7978 filters : config .filters ,
80- postSchedulePlugins : config .postSchedulePlugins ,
79+ scorers : config .scorers ,
8180 picker : config .picker ,
81+ postSchedulePlugins : config .postSchedulePlugins ,
8282 }
83-
84- return scheduler
8583}
8684
8785type Scheduler struct {
8886 datastore Datastore
8987 preSchedulePlugins []plugins.PreSchedule
9088 filters []plugins.Filter
91- scorers []plugins.Scorer
92- postSchedulePlugins []plugins.PostSchedule
89+ scorers map [plugins.Scorer ]int // map from scorer to its weight
9390 picker plugins.Picker
91+ postSchedulePlugins []plugins.PostSchedule
9492}
9593
9694type Datastore interface {
@@ -106,25 +104,22 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
106104 // 1. Reduce concurrent access to the datastore.
107105 // 2. Ensure consistent data during the scheduling operation of a request.
108106 sCtx := types .NewSchedulingContext (ctx , req , types .ToSchedulerPodMetrics (s .datastore .PodGetAll ()))
109- loggerDebug .Info (fmt .Sprintf ("Scheduling a request. Metrics: %+v" , sCtx .PodsSnapshot ))
107+ loggerDebug .Info (fmt .Sprintf ("Scheduling a request, Metrics: %+v" , sCtx .PodsSnapshot ))
110108
111109 s .runPreSchedulePlugins (sCtx )
112110
113111 pods := s .runFilterPlugins (sCtx )
114112 if len (pods ) == 0 {
115113 return nil , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : "failed to find a target pod" }
116114 }
115+ // if we got here, there is at least one pod to score
116+ weightedScorePerPod := s .runScorerPlugins (sCtx , pods )
117117
118- s .runScorerPlugins (sCtx , pods )
119-
120- before := time .Now ()
121- res := s .picker .Pick (sCtx , pods )
122- metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
123- loggerDebug .Info ("After running picker plugins" , "result" , res )
118+ result := s .runPickerPlugin (sCtx , weightedScorePerPod )
124119
125- s .runPostSchedulePlugins (sCtx , res )
120+ s .runPostSchedulePlugins (sCtx , result )
126121
127- return res , nil
122+ return result , nil
128123}
129124
130125func (s * Scheduler ) runPreSchedulePlugins (ctx * types.SchedulingContext ) {
@@ -136,15 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
136131 }
137132}
138133
139- func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
140- for _ , plugin := range s .postSchedulePlugins {
141- ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
142- before := time .Now ()
143- plugin .PostSchedule (ctx , res )
144- metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
145- }
146- }
147-
148134func (s * Scheduler ) runFilterPlugins (ctx * types.SchedulingContext ) []types.Pod {
149135 loggerDebug := ctx .Logger .V (logutil .DEBUG )
150136 filteredPods := ctx .PodsSnapshot
@@ -160,32 +146,60 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
160146 break
161147 }
162148 }
149+ loggerDebug .Info ("After running filter plugins" )
150+
163151 return filteredPods
164152}
165153
166- func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) {
154+ func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) map [types. Pod ] float64 {
167155 loggerDebug := ctx .Logger .V (logutil .DEBUG )
168- loggerDebug .Info ("Before running score plugins" , "pods" , pods )
156+ loggerDebug .Info ("Before running scorer plugins" , "pods" , pods )
157+
158+ weightedScorePerPod := make (map [types.Pod ]float64 , len (pods ))
169159 for _ , pod := range pods {
170- score := s .runScorersForPod (ctx , pod )
171- pod .SetScore (score )
160+ weightedScorePerPod [pod ] = float64 (0 ) // initialize weighted score per pod with 0 value
161+ }
162+ // Iterate through each scorer in the chain and accumulate the weighted scores.
163+ for scorer , weight := range s .scorers {
164+ loggerDebug .Info ("Running scorer" , "scorer" , scorer .Name ())
165+ before := time .Now ()
166+ scores := scorer .Score (ctx , pods )
167+ metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
168+ for pod , score := range scores { // weight is relative to the sum of weights
169+ weightedScorePerPod [pod ] += score * float64 (weight ) // TODO normalize score before multiply with weight
170+ }
171+ loggerDebug .Info ("After running scorer" , "scorer" , scorer .Name ())
172+ }
173+ loggerDebug .Info ("After running scorer plugins" )
174+
175+ return weightedScorePerPod
176+ }
177+
178+ func (s * Scheduler ) runPickerPlugin (ctx * types.SchedulingContext , weightedScorePerPod map [types.Pod ]float64 ) * types.Result {
179+ loggerDebug := ctx .Logger .V (logutil .DEBUG )
180+ scoredPods := make ([]* types.ScoredPod , len (weightedScorePerPod ))
181+ i := 0
182+ for pod , score := range weightedScorePerPod {
183+ scoredPods [i ] = & types.ScoredPod {Pod : pod , Score : score }
184+ i ++
172185 }
173- loggerDebug .Info ("After running score plugins" , "pods" , pods )
186+
187+ loggerDebug .Info ("Before running picker plugin" , "pods" , weightedScorePerPod )
188+ before := time .Now ()
189+ result := s .picker .Pick (ctx , scoredPods )
190+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
191+ loggerDebug .Info ("After running picker plugin" , "result" , result )
192+
193+ return result
174194}
175195
176- // Iterate through each scorer in the chain and accumulate the scores.
177- func (s * Scheduler ) runScorersForPod (ctx * types.SchedulingContext , pod types.Pod ) float64 {
178- logger := ctx .Logger .WithValues ("pod" , pod .GetPod ().NamespacedName ).V (logutil .DEBUG )
179- score := float64 (0 )
180- for _ , scorer := range s .scorers {
181- logger .Info ("Running scorer" , "scorer" , scorer .Name ())
196+ func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
197+ for _ , plugin := range s .postSchedulePlugins {
198+ ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
182199 before := time .Now ()
183- oneScore := scorer .Score (ctx , pod )
184- metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
185- score += oneScore
186- logger .Info ("After scorer" , "scorer" , scorer .Name (), "score" , oneScore , "total score" , score )
200+ plugin .PostSchedule (ctx , res )
201+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
187202 }
188- return score
189203}
190204
191205type defaultPlugin struct {
0 commit comments