Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/epp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"

"github.com/llm-d/llm-d-inference-scheduler/pkg/config"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins"
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
"github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/pd"
)
Expand All @@ -40,6 +41,12 @@ func main() {
setupLog := ctrl.Log.WithName("setup")
ctx := ctrl.SetupSignalHandler()

// Register GIE plugins
runner.RegisterAllPlugins()

// Register llm-d-inference-scheduler plugins
plugins.RegisterAllPlugins()

pdConfig := config.LoadConfig(setupLog)

requestControlConfig := requestcontrol.NewConfig()
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ go 1.24.1
toolchain go1.24.2

require (
github.com/cespare/xxhash/v2 v2.3.0
github.com/go-logr/logr v1.4.3
github.com/google/go-cmp v0.7.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/llm-d/llm-d-kv-cache-manager v0.1.1
github.com/redis/go-redis/v9 v9.11.0
github.com/stretchr/testify v1.10.0
k8s.io/apimachinery v0.33.2
k8s.io/client-go v0.33.2
sigs.k8s.io/controller-runtime v0.21.0
sigs.k8s.io/gateway-api v1.3.0
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f
)

require (
Expand All @@ -25,6 +23,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect
github.com/daulet/tokenizers v1.20.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
Expand All @@ -47,6 +46,7 @@ require (
github.com/google/gnostic-models v0.6.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ sigs.k8s.io/controller-runtime v0.21.0 h1:CYfjpEuicjUecRk+KAeyYh+ouUBn4llGyDYytI
sigs.k8s.io/controller-runtime v0.21.0/go.mod h1:OSg14+F65eWqIu4DceX7k/+QRAbTTvxeQSNSOQpukWM=
sigs.k8s.io/gateway-api v1.3.0 h1:q6okN+/UKDATola4JY7zXzx40WO4VISk7i9DIfOvr9M=
sigs.k8s.io/gateway-api v1.3.0/go.mod h1:d8NV8nJbaRbEKem+5IuxkL8gJGOZ+FJ+NvOIltV8gDk=
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0 h1:rtnnZ3TNEV+SQO/FXxrd/lqbKw/D3RjeBqayCvyOlOA=
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0/go.mod h1:xgeYdEPZf/+87+Dp5zcz2vhbezBHjTg8lAfpPU2Xgp8=
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f h1:ByLjkC8b3tq1DFMN/pqoM2oVMcOHxavL+KQd80137CQ=
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f/go.mod h1:xgeYdEPZf/+87+Dp5zcz2vhbezBHjTg8lAfpPU2Xgp8=
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8=
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo=
sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
Expand Down
2 changes: 2 additions & 0 deletions pkg/plugins/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package plugins provides plugins for the scheduler.
package plugins
36 changes: 31 additions & 5 deletions pkg/plugins/filter/by_label.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package filter

import (
"context"
"encoding/json"
"fmt"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
Expand All @@ -12,6 +15,12 @@ const (
ByLabelFilterType = "by-label"
)

// byLabelFilterParameters holds the JSON-configurable parameters that
// ByLabelFilterFactory decodes and passes on to NewByLabel.
type byLabelFilterParameters struct {
// Label is the name of the label to use (NewByLabel's labelName).
Label string `json:"label"`
// ValidValues is the list of valid values for that label.
ValidValues []string `json:"validValues"`
// AllowsNoLabel, if true, makes pods without the label count as valid
// (not filtered out).
AllowsNoLabel bool `json:"allowsNoLabel"`
}

// ByLabel - filters out pods based on the values defined by the given label
type ByLabel struct {
// name defines the filter name
Expand All @@ -26,19 +35,30 @@ type ByLabel struct {

var _ framework.Filter = &ByLabel{} // validate interface conformance

// ByLabelFilterFactory is the plugin factory for the ByLabel filter.
//
// name becomes the name of the created filter instance. rawParameters is the
// optional JSON configuration (label, validValues, allowsNoLabel); when it is
// nil the zero-value parameters are used. The handle argument is unused.
// If the parameters cannot be decoded, a wrapped error and no plugin are returned.
func ByLabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	var cfg byLabelFilterParameters

	if rawParameters != nil {
		err := json.Unmarshal(rawParameters, &cfg)
		if err != nil {
			return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelFilterType, err)
		}
	}

	byLabel := NewByLabel(name, cfg.Label, cfg.AllowsNoLabel, cfg.ValidValues...)
	return byLabel, nil
}

// NewByLabel creates and returns an instance of the ByLabel filter based on the input parameters
// name - the filter name
// labelName - the name of the label to use
// allowsNoLabel - if true, pods without the given label are considered valid (not filtered out)
// validValues - list of valid values
func NewByLabel(name string, labelName string, allowsNoLabel bool, validValuesApp ...string) *ByLabel {
validValues := map[string]struct{}{}
func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues ...string) *ByLabel {
validValuesMap := map[string]struct{}{}

for _, v := range validValuesApp {
validValues[v] = struct{}{}
for _, v := range validValues {
validValuesMap[v] = struct{}{}
}

return &ByLabel{name: name, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValues}
return &ByLabel{name: name, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValuesMap}
}

// Type returns the type of the filter
Expand All @@ -51,6 +71,12 @@ func (f *ByLabel) Name() string {
return f.name
}

// WithName sets the name of the filter and returns the same filter,
// so the call can be chained onto a constructor.
func (f *ByLabel) WithName(name string) *ByLabel {
f.name = name
return f
}

// Filter filters out all pods that are not marked with one of roles from the validRoles collection
// or has no role label in case allowsNoRolesLabel is true
func (f *ByLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
Expand Down
42 changes: 28 additions & 14 deletions pkg/plugins/filter/by_labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,77 @@ package filter

import (
"context"
"encoding/json"
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
// ByLabelsFilterType is the type of the ByLabelsFilter
ByLabelsFilterType = "by-labels"
// ByLabelSelectorFilterType is the type of the ByLabelSelector filter
ByLabelSelectorFilterType = "by-label-selector"
)

// compile-time type assertion
var _ framework.Filter = &ByLabels{}
var _ framework.Filter = &ByLabelSelector{}

// NewByLabels returns a new filter instance, configured with the provided
// ByLabelSelectorFactory is the plugin factory for the ByLabelSelector filter.
//
// name becomes the name of the created filter instance. rawParameters is the
// optional JSON encoding of a metav1.LabelSelector; when it is nil an empty
// selector is used. The handle argument is unused.
//
// An error is returned when the parameters cannot be decoded or when
// NewByLabelSelector rejects the name or selector. On every error path the
// returned plugin is a literal nil.
func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	var selector metav1.LabelSelector

	if rawParameters != nil {
		if err := json.Unmarshal(rawParameters, &selector); err != nil {
			return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorFilterType, err)
		}
	}

	// Do not forward NewByLabelSelector's results directly: its nil
	// *ByLabelSelector would be wrapped into a non-nil plugins.Plugin
	// interface value on failure.
	filter, err := NewByLabelSelector(name, &selector)
	if err != nil {
		return nil, err
	}
	return filter, nil
}

// NewByLabelSelector returns a new filter instance, configured with the provided
// name and label selector.
func NewByLabels(name string, selector *metav1.LabelSelector) (framework.Filter, error) {
func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSelector, error) {
if name == "" {
return nil, errors.New("ByLabels: missing filter name")
return nil, errors.New("ByLabelSelector: missing filter name")
Comment on lines -25 to +39
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it mandatory to specify a name?
It's optional in all other plugins.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
labelSelector, err := metav1.LabelSelectorAsSelector(selector)
if err != nil {
return nil, err
}

return &ByLabels{
return &ByLabelSelector{
name: name,
selector: labelSelector,
}, nil
}

// ByLabels filters out pods that do not match its label selector criteria
type ByLabels struct {
// ByLabelSelector filters out pods that do not match its label selector criteria
type ByLabelSelector struct {
name string
selector labels.Selector
}

// Type returns the type of the filter
func (blf *ByLabels) Type() string {
return ByLabelsFilterType
func (blf *ByLabelSelector) Type() string {
return ByLabelSelectorFilterType
}

// Name returns the name of the instance of the filter.
func (blf *ByLabels) Name() string {
func (blf *ByLabelSelector) Name() string {
return blf.name
}

// WithName sets the name of the filter.
func (blf *ByLabels) WithName(name string) *ByLabels {
func (blf *ByLabelSelector) WithName(name string) *ByLabelSelector {
blf.name = name
return blf
}

// Filter filters out all pods that do not satisfy the label selector
func (blf *ByLabels) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (blf *ByLabelSelector) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filtered := []types.Pod{}

for _, pod := range pods {
Expand Down
27 changes: 22 additions & 5 deletions pkg/plugins/filter/pd_role_filter.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package filter

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"encoding/json"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
)

const (
Expand All @@ -13,14 +15,29 @@ const (
RoleDecode = "decode"
// RoleBoth set for workers that can act as both prefill and decode
RoleBoth = "both"

// DecodeFilterType is the type of the DecodeFilter
DecodeFilterType = "decode-filter"
// PrefillFilterType is the type of the PrefillFilter
PrefillFilterType = "prefill-filter"
)

// PrefillFilterFactory is the plugin factory for the prefill filter. It builds
// the filter with NewPrefillFilter and renames it to the requested instance
// name. The raw parameters and the handle are ignored, and the returned error
// is always nil.
func PrefillFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	prefillFilter := NewPrefillFilter().WithName(name)
	return prefillFilter, nil
}

// NewPrefillFilter creates and returns an instance of the Filter configured for prefill role
func NewPrefillFilter() framework.Filter {
return NewByLabel("prefill-filter", RoleLabel, false, RolePrefill)
func NewPrefillFilter() *ByLabel {
return NewByLabel(PrefillFilterType, RoleLabel, false, RolePrefill)
}

// DecodeFilterFactory is the plugin factory for the decode filter. It builds
// the filter with NewDecodeFilter and renames it to the requested instance
// name. The raw parameters and the handle are ignored, and the returned error
// is always nil.
func DecodeFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	decodeFilter := NewDecodeFilter().WithName(name)
	return decodeFilter, nil
}

// NewDecodeFilter creates and returns an instance of the Filter configured for decode role
func NewDecodeFilter() framework.Filter {
return NewByLabel("decode-filter", RoleLabel, true, RoleDecode, RoleBoth)
func NewDecodeFilter() *ByLabel {
return NewByLabel(DecodeFilterType, RoleLabel, true, RoleDecode, RoleBoth)
}
7 changes: 7 additions & 0 deletions pkg/plugins/pre-request/pd_prerequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package prerequest

import (
"context"
"encoding/json"
"net"
"strconv"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
Expand All @@ -20,6 +22,11 @@ const (
// compile-time type assertion
var _ requestcontrol.PreRequest = &PrefillHeaderHandler{}

// PrefillHeaderHandlerFactory is the plugin factory for the PrefillHeaderHandler.
// It creates a handler with NewPrefillHeaderHandler and applies the requested
// instance name through WithName. The raw parameters and the handle are
// ignored, and the returned error is always nil.
func PrefillHeaderHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	handler := NewPrefillHeaderHandler().WithName(name)
	return handler, nil
}

// NewPrefillHeaderHandler initializes a new PrefillHeaderHandler and returns its pointer.
func NewPrefillHeaderHandler() *PrefillHeaderHandler {
return &PrefillHeaderHandler{
Expand Down
34 changes: 33 additions & 1 deletion pkg/plugins/profile/pd_profile_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ package profile

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/llm-d/llm-d-inference-scheduler/pkg/config"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"

"github.com/llm-d/llm-d-inference-scheduler/pkg/config"
)

const (
Expand All @@ -21,9 +25,37 @@ const (
prefill = "prefill"
)

// pdProfileHandlerParameters holds the JSON-configurable parameters that
// PdProfileHandlerFactory decodes.
type pdProfileHandlerParameters struct {
// Config is embedded without a JSON tag, so encoding/json reads its fields
// from the same JSON object as threshold. The factory passes it on as
// config.Config.GIEPrefixConfig.
prefix.Config
// Threshold is passed on as config.Config.PDThreshold; the factory
// pre-populates it with 100 before decoding.
Threshold int `json:"threshold"`
}

// compile-time type assertion
var _ framework.ProfileHandler = &PdProfileHandler{}

// PdProfileHandlerFactory defines the factory function for the PdProfileHandler
func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := pdProfileHandlerParameters{
Config: prefix.Config{
HashBlockSize: prefix.DefaultHashBlockSize,
MaxPrefixBlocksToMatch: prefix.DefaultMaxPrefixBlocks,
LRUCapacityPerServer: prefix.DefaultLRUCapacityPerServer,
Comment on lines +41 to +42
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is profile handler using these values?

},
Comment on lines +38 to +43
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense that the prefix config appears on the profile handler plugin?
Shouldn't it appear on the prefix plugin only?
Something about this configuration doesn't look natural.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This configuration is from the previous change to eliminate our own Prefix Scorer

Threshold: 100,
}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' profile handler - %w", PdProfileHandlerType, err)
}
}

cfg := &config.Config{
PDThreshold: parameters.Threshold,
GIEPrefixConfig: &parameters.Config,
}
return NewPdProfileHandler(cfg).WithName(name), nil
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on line 114 (I can't put the comment there because it wasn't changed in this PR):

prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, prefix.PrefixCachePluginType)
	if err != nil {
		log.FromContext(ctx).Error(err, "unable to read prefix state")
		return map[string]*framework.SchedulerProfile{}
	}

Is this the expected behavior? If the prefix scorer failed to write the prefix, we don't do PD?
cc @kfirtoledo

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is from the previous change to eliminate our own Prefix Scorer

// NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer.
func NewPdProfileHandler(cfg *config.Config) *PdProfileHandler {
return &PdProfileHandler{
Expand Down
23 changes: 23 additions & 0 deletions pkg/plugins/register.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package plugins

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
)

// RegisterAllPlugins registers the factory functions of all plugins in this repository.
// Each factory is registered with plugins.Register under its plugin type constant.
func RegisterAllPlugins() {
// Filters.
plugins.Register(filter.ByLabelFilterType, filter.ByLabelFilterFactory)
plugins.Register(filter.ByLabelSelectorFilterType, filter.ByLabelSelectorFactory)
plugins.Register(filter.DecodeFilterType, filter.DecodeFilterFactory)
plugins.Register(filter.PrefillFilterType, filter.PrefillFilterFactory)
// Pre-request plugin.
plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
// Profile handler.
plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
// Scorers.
plugins.Register(scorer.KvCacheAwareScorerType, scorer.KvCacheAwareScorerFactory)
plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory)
plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory)
}
Loading