Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conformance/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.24.9
replace sigs.k8s.io/gateway-api-inference-extension => ../

require (
golang.org/x/net v0.48.0
sigs.k8s.io/gateway-api v1.4.0
sigs.k8s.io/gateway-api-inference-extension v0.0.0-00010101000000-000000000000
)
Expand Down Expand Up @@ -58,7 +59,6 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.39.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/term v0.38.0 // indirect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package framework
import (
"context"

scheduling "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/scheduling"
)

type Endpoint struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/config/loader/configloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol"
fccontroller "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/registry"
framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector/framework/plugins/utilizationdetector"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
)

Expand Down
15 changes: 7 additions & 8 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol"
framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector/framework/plugins/utilizationdetector"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
)

Expand Down Expand Up @@ -421,7 +420,7 @@ func (m *mockScorer) Category() framework.ScorerCategory {
return framework.Distribution
}

func (m *mockScorer) Score(context.Context, *types.CycleState, *types.LLMRequest, []types.Endpoint) map[types.Endpoint]float64 {
func (m *mockScorer) Score(context.Context, *framework.CycleState, *framework.LLMRequest, []framework.Endpoint) map[framework.Endpoint]float64 {
return nil
}

Expand All @@ -431,7 +430,7 @@ type mockPicker struct{ mockPlugin }
// compile-time type assertion
var _ framework.Picker = &mockPicker{}

func (m *mockPicker) Pick(context.Context, *types.CycleState, []*types.ScoredEndpoint) *types.ProfileRunResult {
func (m *mockPicker) Pick(context.Context, *framework.CycleState, []*framework.ScoredEndpoint) *framework.ProfileRunResult {
return nil
}

Expand All @@ -441,12 +440,12 @@ type mockHandler struct{ mockPlugin }
// compile-time type assertion
var _ framework.ProfileHandler = &mockHandler{}

func (m *mockHandler) Pick(context.Context, *types.CycleState, *types.LLMRequest, map[string]*framework.SchedulerProfile,
map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
func (m *mockHandler) Pick(context.Context, *framework.CycleState, *framework.LLMRequest, map[string]*framework.SchedulerProfile,
map[string]*framework.ProfileRunResult) map[string]*framework.SchedulerProfile {
return nil
}
func (m *mockHandler) ProcessResults(context.Context, *types.CycleState, *types.LLMRequest,
map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
func (m *mockHandler) ProcessResults(context.Context, *framework.CycleState, *framework.LLMRequest,
map[string]*framework.ProfileRunResult) (*framework.SchedulingResult, error) {
return nil, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/config/loader/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (

configapi "sigs.k8s.io/gateway-api-inference-extension/apix/config/v1alpha1"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/interflow"
framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package types
package scheduling

import (
"fmt"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package framework
package scheduling

import (
"context"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
Expand Down Expand Up @@ -53,21 +52,21 @@ type ProfileHandler interface {
plugins.Plugin
// Pick selects the SchedulingProfiles to run from a list of candidate profiles, while taking into consideration the request properties
// and the previously executed SchedulerProfile cycles along with their results.
Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile,
profileResults map[string]*types.ProfileRunResult) map[string]*SchedulerProfile
Pick(ctx context.Context, cycleState *CycleState, request *LLMRequest, profiles map[string]*SchedulerProfile,
profileResults map[string]*ProfileRunResult) map[string]*SchedulerProfile

// ProcessResults handles the outcome of the profile runs after all profiles ran.
// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
// key of the primary profile that should be used to get the request selected destination.
// When a profile run fails, its result in the profileResults map is nil.
ProcessResults(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest,
profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error)
ProcessResults(ctx context.Context, cycleState *CycleState, request *LLMRequest,
profileResults map[string]*ProfileRunResult) (*SchedulingResult, error)
}

// Filter defines the interface for filtering a list of pods based on context.
type Filter interface {
plugins.Plugin
Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Endpoint) []types.Endpoint
Filter(ctx context.Context, cycleState *CycleState, request *LLMRequest, pods []Endpoint) []Endpoint
}

// Scorer defines the interface for scoring a list of pods based on context.
Expand All @@ -77,11 +76,11 @@ type Filter interface {
type Scorer interface {
plugins.Plugin
Category() ScorerCategory
Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Endpoint) map[types.Endpoint]float64
Score(ctx context.Context, cycleState *CycleState, request *LLMRequest, pods []Endpoint) map[Endpoint]float64
}

// Picker picks the final pod(s) to send the request to.
type Picker interface {
plugins.Plugin
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredEndpoint) *types.ProfileRunResult
Pick(ctx context.Context, cycleState *CycleState, scoredPods []*ScoredEndpoint) *ProfileRunResult
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package framework
package scheduling

import (
"context"
Expand All @@ -27,7 +27,6 @@ import (
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)

Expand Down Expand Up @@ -114,7 +113,7 @@ func (p *SchedulerProfile) String() string {

// Run runs a SchedulerProfile. It invokes all the SchedulerProfile plugins for the given request in this
// order - Filters, Scorers, Picker. After completing all, it returns the result.
func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidateEndpoints []types.Endpoint) (*types.ProfileRunResult, error) {
func (p *SchedulerProfile) Run(ctx context.Context, request *LLMRequest, cycleState *CycleState, candidateEndpoints []Endpoint) (*ProfileRunResult, error) {
endpoints := p.runFilterPlugins(ctx, request, cycleState, candidateEndpoints)
if len(endpoints) == 0 {
return nil, errutil.Error{Code: errutil.Internal, Msg: "no endpoints available for the given request"}
Expand All @@ -127,7 +126,7 @@ func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, c
return result, nil
}

func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, endpoints []types.Endpoint) []types.Endpoint {
func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *LLMRequest, cycleState *CycleState, endpoints []Endpoint) []Endpoint {
logger := log.FromContext(ctx)
filteredEndpoints := endpoints
logger.V(logutil.DEBUG).Info("Before running filter plugins", "endpoints", filteredEndpoints)
Expand All @@ -147,11 +146,11 @@ func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.
return filteredEndpoints
}

func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, endpoints []types.Endpoint) map[types.Endpoint]float64 {
func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *LLMRequest, cycleState *CycleState, endpoints []Endpoint) map[Endpoint]float64 {
logger := log.FromContext(ctx)
logger.V(logutil.DEBUG).Info("Before running scorer plugins", "endpoints", endpoints)

weightedScorePerEndpoint := make(map[types.Endpoint]float64, len(endpoints))
weightedScorePerEndpoint := make(map[Endpoint]float64, len(endpoints))
for _, endpoint := range endpoints {
weightedScorePerEndpoint[endpoint] = float64(0) // initialize weighted score per endpoint with 0 value
}
Expand All @@ -172,12 +171,12 @@ func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.
return weightedScorePerEndpoint
}

func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *types.CycleState, weightedScorePerEndpoint map[types.Endpoint]float64) *types.ProfileRunResult {
func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *CycleState, weightedScorePerEndpoint map[Endpoint]float64) *ProfileRunResult {
logger := log.FromContext(ctx)
scoredEndpoints := make([]*types.ScoredEndpoint, len(weightedScorePerEndpoint))
scoredEndpoints := make([]*ScoredEndpoint, len(weightedScorePerEndpoint))
i := 0
for endpoint, score := range weightedScorePerEndpoint {
scoredEndpoints[i] = &types.ScoredEndpoint{Endpoint: endpoint, Score: score}
scoredEndpoints[i] = &ScoredEndpoint{Endpoint: endpoint, Score: score}
i++
}
logger.V(logutil.VERBOSE).Info("Running picker plugin", "plugin", p.picker.TypedName())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package framework
package scheduling

import (
"context"
Expand All @@ -26,7 +26,6 @@ import (

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func TestSchedulePlugins(t *testing.T) {
Expand All @@ -52,7 +51,7 @@ func TestSchedulePlugins(t *testing.T) {
tests := []struct {
name string
profile *SchedulerProfile
input []types.Endpoint
input []Endpoint
wantTargetEndpoint k8stypes.NamespacedName
targetEndpointScore float64
// Number of expected endpoints to score (after filter)
Expand All @@ -65,10 +64,10 @@ func TestSchedulePlugins(t *testing.T) {
WithFilters(tp1, tp2).
WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)).
WithPicker(pickerPlugin),
input: []types.Endpoint{
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
input: []Endpoint{
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetEndpoint: k8stypes.NamespacedName{Name: "pod1"},
targetEndpointScore: 1.1,
Expand All @@ -81,10 +80,10 @@ func TestSchedulePlugins(t *testing.T) {
WithFilters(tp1, tp2).
WithScorers(NewWeightedScorer(tp1, 60), NewWeightedScorer(tp2, 40)).
WithPicker(pickerPlugin),
input: []types.Endpoint{
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
input: []Endpoint{
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetEndpoint: k8stypes.NamespacedName{Name: "pod1"},
targetEndpointScore: 50,
Expand All @@ -97,10 +96,10 @@ func TestSchedulePlugins(t *testing.T) {
WithFilters(tp1, tp_filterAll).
WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)).
WithPicker(pickerPlugin),
input: []types.Endpoint{
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
input: []Endpoint{
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
&PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
numEndpointsToScore: 0,
err: true, // no available endpoints to serve after filter all
Expand All @@ -119,12 +118,12 @@ func TestSchedulePlugins(t *testing.T) {
test.profile.picker.(*testPlugin).reset()

// Initialize the scheduling context
request := &types.LLMRequest{
request := &LLMRequest{
TargetModel: "test-model",
RequestId: uuid.NewString(),
}
// Run profile cycle
got, err := test.profile.Run(context.Background(), request, types.NewCycleState(), test.input)
got, err := test.profile.Run(context.Background(), request, NewCycleState(), test.input)

// Validate error state
if test.err != (err != nil) {
Expand All @@ -136,9 +135,9 @@ func TestSchedulePlugins(t *testing.T) {
}

// Validate output
wantRes := &types.ProfileRunResult{
TargetEndpoints: []types.Endpoint{
&types.PodMetrics{
wantRes := &ProfileRunResult{
TargetEndpoints: []Endpoint{
&PodMetrics{
EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: test.wantTargetEndpoint},
},
},
Expand Down Expand Up @@ -205,35 +204,35 @@ func (tp *testPlugin) Category() ScorerCategory {
return Distribution
}

func (tp *testPlugin) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) []types.Endpoint {
func (tp *testPlugin) Filter(_ context.Context, _ *CycleState, _ *LLMRequest, endpoints []Endpoint) []Endpoint {
tp.FilterCallCount++
return findEndpoints(endpoints, tp.FilterRes...)

}

func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 {
func (tp *testPlugin) Score(_ context.Context, _ *CycleState, _ *LLMRequest, endpoints []Endpoint) map[Endpoint]float64 {
tp.ScoreCallCount++
scoredEndpoints := make(map[types.Endpoint]float64, len(endpoints))
scoredEndpoints := make(map[Endpoint]float64, len(endpoints))
for _, endpoint := range endpoints {
scoredEndpoints[endpoint] += tp.ScoreRes
}
tp.NumOfScoredEndpoints = len(scoredEndpoints)
return scoredEndpoints
}

func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredEndpoints []*types.ScoredEndpoint) *types.ProfileRunResult {
func (tp *testPlugin) Pick(_ context.Context, _ *CycleState, scoredEndpoints []*ScoredEndpoint) *ProfileRunResult {
tp.PickCallCount++
tp.NumOfPickerCandidates = len(scoredEndpoints)

winnerEndpoints := []types.Endpoint{}
winnerEndpoints := []Endpoint{}
for _, scoredEndpoint := range scoredEndpoints {
if scoredEndpoint.GetMetadata().NamespacedName.String() == tp.PickRes.String() {
winnerEndpoints = append(winnerEndpoints, scoredEndpoint.Endpoint)
tp.WinnerEndpointScore = scoredEndpoint.Score
}
}

return &types.ProfileRunResult{TargetEndpoints: winnerEndpoints}
return &ProfileRunResult{TargetEndpoints: winnerEndpoints}
}

func (tp *testPlugin) reset() {
Expand All @@ -244,8 +243,8 @@ func (tp *testPlugin) reset() {
tp.NumOfPickerCandidates = 0
}

func findEndpoints(endpoints []types.Endpoint, names ...k8stypes.NamespacedName) []types.Endpoint {
res := []types.Endpoint{}
func findEndpoints(endpoints []Endpoint, names ...k8stypes.NamespacedName) []Endpoint {
res := []Endpoint{}
for _, endpoint := range endpoints {
for _, name := range names {
if endpoint.GetMetadata().NamespacedName.String() == name.String() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package types
package scheduling

import (
"encoding/json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package framework
package scheduling

// NewWeightedScorer initializes a new WeightedScorer and returns its pointer.
func NewWeightedScorer(scorer Scorer, weight int) *WeightedScorer {
Expand Down
Loading