diff --git a/cmd/epp/main.go b/cmd/epp/main.go index ed799afb0..07dd94d79 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -32,6 +32,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "github.com/llm-d/llm-d-inference-scheduler/pkg/config" + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request" "github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/pd" ) @@ -40,6 +41,12 @@ func main() { setupLog := ctrl.Log.WithName("setup") ctx := ctrl.SetupSignalHandler() + // Register GIE plugins + runner.RegisterAllPlugins() + + // Register llm-d-inference-scheduler plugins + plugins.RegisterAllPlugins() + pdConfig := config.LoadConfig(setupLog) requestControlConfig := requestcontrol.NewConfig() diff --git a/go.mod b/go.mod index 9f7818ec4..41242ded2 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,8 @@ go 1.24.1 toolchain go1.24.2 require ( - github.com/cespare/xxhash/v2 v2.3.0 github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 - github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/llm-d/llm-d-kv-cache-manager v0.1.1 github.com/redis/go-redis/v9 v9.11.0 github.com/stretchr/testify v1.10.0 @@ -16,7 +14,7 @@ require ( k8s.io/client-go v0.33.2 sigs.k8s.io/controller-runtime v0.21.0 sigs.k8s.io/gateway-api v1.3.0 - sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0 + sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f ) require ( @@ -25,6 +23,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect github.com/daulet/tokenizers v1.20.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -47,6 +46,7 @@ require ( github.com/google/gnostic-models 
v0.6.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect diff --git a/go.sum b/go.sum index de7572a2d..923bd928e 100644 --- a/go.sum +++ b/go.sum @@ -273,8 +273,8 @@ sigs.k8s.io/controller-runtime v0.21.0 h1:CYfjpEuicjUecRk+KAeyYh+ouUBn4llGyDYytI sigs.k8s.io/controller-runtime v0.21.0/go.mod h1:OSg14+F65eWqIu4DceX7k/+QRAbTTvxeQSNSOQpukWM= sigs.k8s.io/gateway-api v1.3.0 h1:q6okN+/UKDATola4JY7zXzx40WO4VISk7i9DIfOvr9M= sigs.k8s.io/gateway-api v1.3.0/go.mod h1:d8NV8nJbaRbEKem+5IuxkL8gJGOZ+FJ+NvOIltV8gDk= -sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0 h1:rtnnZ3TNEV+SQO/FXxrd/lqbKw/D3RjeBqayCvyOlOA= -sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250628171228-9c9abd51a6d0/go.mod h1:xgeYdEPZf/+87+Dp5zcz2vhbezBHjTg8lAfpPU2Xgp8= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f h1:ByLjkC8b3tq1DFMN/pqoM2oVMcOHxavL+KQd80137CQ= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250629153429-5c851eb1ff8f/go.mod h1:xgeYdEPZf/+87+Dp5zcz2vhbezBHjTg8lAfpPU2Xgp8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= diff --git a/pkg/plugins/doc.go b/pkg/plugins/doc.go new file mode 100644 index 000000000..0c7a473be --- /dev/null +++ b/pkg/plugins/doc.go @@ -0,0 +1,2 @@ +// Package plugins provides plugins for the scheduler. 
+package plugins diff --git a/pkg/plugins/filter/by_label.go b/pkg/plugins/filter/by_label.go index bbcd2bff0..09167a3de 100644 --- a/pkg/plugins/filter/by_label.go +++ b/pkg/plugins/filter/by_label.go @@ -2,7 +2,10 @@ package filter import ( "context" + "encoding/json" + "fmt" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -12,6 +15,12 @@ const ( ByLabelFilterType = "by-label" ) +type byLabelFilterParameters struct { + Label string `json:"label"` + ValidValues []string `json:"validValues"` + AllowsNoLabel bool `json:"allowsNoLabel"` +} + // ByLabel - filters out pods based on the values defined by the given label type ByLabel struct { // name defines the filter name @@ -26,19 +35,30 @@ type ByLabel struct { var _ framework.Filter = &ByLabel{} // validate interface conformance +// ByLabelFilterFactory defines the factory function for the ByLabelFilter +func ByLabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := byLabelFilterParameters{} + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &parameters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelFilterType, err) + } + } + return NewByLabel(name, parameters.Label, parameters.AllowsNoLabel, parameters.ValidValues...), nil +} + // NewByLabel creates and returns an instance of the RoleBasedFilter based on the input parameters // name - the filter name // labelName - the name of the label to use // allowsNoLabel - if true pods without given label will be considered as valid (not filtered out) // validValuesApp - list of valid values -func NewByLabel(name string, labelName string, allowsNoLabel bool, validValuesApp ...string) *ByLabel { - validValues := map[string]struct{}{} +func NewByLabel(name string, labelName 
string, allowsNoLabel bool, validValues ...string) *ByLabel { + validValuesMap := map[string]struct{}{} - for _, v := range validValuesApp { - validValues[v] = struct{}{} + for _, v := range validValues { + validValuesMap[v] = struct{}{} } - return &ByLabel{name: name, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValues} + return &ByLabel{name: name, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValuesMap} } // Type returns the type of the filter @@ -51,6 +71,12 @@ func (f *ByLabel) Name() string { return f.name } +// WithName sets the name of the filter. +func (f *ByLabel) WithName(name string) *ByLabel { + f.name = name + return f +} + // Filter filters out all pods that are not marked with one of roles from the validRoles collection // or has no role label in case allowsNoRolesLabel is true func (f *ByLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { diff --git a/pkg/plugins/filter/by_labels.go b/pkg/plugins/filter/by_labels.go index cf4aeb2e8..c9db9093c 100644 --- a/pkg/plugins/filter/by_labels.go +++ b/pkg/plugins/filter/by_labels.go @@ -2,63 +2,77 @@ package filter import ( "context" + "encoding/json" "errors" + "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) const ( - // ByLabelsFilterType is the type of the ByLabelsFilter - ByLabelsFilterType = "by-labels" + // ByLabelSelectorFilterType is the type of the ByLabelsFilter + ByLabelSelectorFilterType = "by-label-selector" ) // compile-time type assertion -var _ framework.Filter = &ByLabels{} +var _ framework.Filter = &ByLabelSelector{} -// NewByLabels returns a new filter instance, configured with the provided +// ByLabelSelectorFactory defines the factory 
function for the ByLabelSelector filter +func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := metav1.LabelSelector{} + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &parameters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorFilterType, err) + } + } + return NewByLabelSelector(name, &parameters) +} + +// NewByLabelSelector returns a new filter instance, configured with the provided // name and label selector. -func NewByLabels(name string, selector *metav1.LabelSelector) (framework.Filter, error) { +func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSelector, error) { if name == "" { - return nil, errors.New("ByLabels: missing filter name") + return nil, errors.New("ByLabelSelector: missing filter name") } labelSelector, err := metav1.LabelSelectorAsSelector(selector) if err != nil { return nil, err } - return &ByLabels{ + return &ByLabelSelector{ name: name, selector: labelSelector, }, nil } -// ByLabels filters out pods that do not match its label selector criteria -type ByLabels struct { +// ByLabelSelector filters out pods that do not match its label selector criteria +type ByLabelSelector struct { name string selector labels.Selector } // Type returns the type of the filter -func (blf *ByLabels) Type() string { - return ByLabelsFilterType +func (blf *ByLabelSelector) Type() string { + return ByLabelSelectorFilterType } // Name returns the name of the instance of the filter. -func (blf *ByLabels) Name() string { +func (blf *ByLabelSelector) Name() string { return blf.name } // WithName sets the name of the filter. 
-func (blf *ByLabels) WithName(name string) *ByLabels { +func (blf *ByLabelSelector) WithName(name string) *ByLabelSelector { blf.name = name return blf } // Filter filters out all pods that do not satisfy the label selector -func (blf *ByLabels) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (blf *ByLabelSelector) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { filtered := []types.Pod{} for _, pod := range pods { diff --git a/pkg/plugins/filter/pd_role_filter.go b/pkg/plugins/filter/pd_role_filter.go index 2a67070d0..3ac50737e 100644 --- a/pkg/plugins/filter/pd_role_filter.go +++ b/pkg/plugins/filter/pd_role_filter.go @@ -1,7 +1,9 @@ package filter import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "encoding/json" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) const ( @@ -13,14 +15,29 @@ const ( RoleDecode = "decode" // RoleBoth set for workers that can act as both prefill and decode RoleBoth = "both" + + // DecodeFilterType is the type of the DecodeFilter + DecodeFilterType = "decode-filter" + // PrefillFilterType is the type of the PrefillFilter + PrefillFilterType = "prefill-filter" ) +// PrefillFilterFactory defines the factory function for the PrefillFilter +func PrefillFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewPrefillFilter().WithName(name), nil +} + // NewPrefillFilter creates and returns an instance of the Filter configured for prefill role -func NewPrefillFilter() framework.Filter { - return NewByLabel("prefill-filter", RoleLabel, false, RolePrefill) +func NewPrefillFilter() *ByLabel { + return NewByLabel(PrefillFilterType, RoleLabel, false, RolePrefill) +} + +// DecodeFilterFactory defines the factory function for the DecodeFilter +func DecodeFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) 
{ + return NewDecodeFilter().WithName(name), nil } // NewDecodeFilter creates and returns an instance of the Filter configured for decode role -func NewDecodeFilter() framework.Filter { - return NewByLabel("decode-filter", RoleLabel, true, RoleDecode, RoleBoth) +func NewDecodeFilter() *ByLabel { + return NewByLabel(DecodeFilterType, RoleLabel, true, RoleDecode, RoleBoth) } diff --git a/pkg/plugins/pre-request/pd_prerequest.go b/pkg/plugins/pre-request/pd_prerequest.go index 397412064..e1a01d826 100644 --- a/pkg/plugins/pre-request/pd_prerequest.go +++ b/pkg/plugins/pre-request/pd_prerequest.go @@ -3,9 +3,11 @@ package prerequest import ( "context" + "encoding/json" "net" "strconv" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -20,6 +22,11 @@ const ( // compile-time type assertion var _ requestcontrol.PreRequest = &PrefillHeaderHandler{} +// PrefillHeaderHandlerFactory defines the factory function for the PrefillHeaderHandler +func PrefillHeaderHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewPrefillHeaderHandler().WithName(name), nil +} + // NewPrefillHeaderHandler initializes a new PrefillHeaderHandler and returns its pointer. 
func NewPrefillHeaderHandler() *PrefillHeaderHandler { return &PrefillHeaderHandler{ diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index a85caf3e0..34fa67051 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -3,14 +3,18 @@ package profile import ( "context" + "encoding/json" "errors" + "fmt" - "github.com/llm-d/llm-d-inference-scheduler/pkg/config" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/config" ) const ( @@ -21,9 +25,37 @@ const ( prefill = "prefill" ) +type pdProfileHandlerParameters struct { + prefix.Config + Threshold int `json:"threshold"` +} + // compile-time type assertion var _ framework.ProfileHandler = &PdProfileHandler{} +// PdProfileHandlerFactory defines the factory function for the PdProfileHandler +func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := pdProfileHandlerParameters{ + Config: prefix.Config{ + HashBlockSize: prefix.DefaultHashBlockSize, + MaxPrefixBlocksToMatch: prefix.DefaultMaxPrefixBlocks, + LRUCapacityPerServer: prefix.DefaultLRUCapacityPerServer, + }, + Threshold: 100, + } + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &parameters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' profile handler - %w", PdProfileHandlerType, err) + } + } + + cfg := &config.Config{ + PDThreshold: parameters.Threshold, + GIEPrefixConfig: &parameters.Config, + } + return 
NewPdProfileHandler(cfg).WithName(name), nil +} + // NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer. func NewPdProfileHandler(cfg *config.Config) *PdProfileHandler { return &PdProfileHandler{ diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go new file mode 100644 index 000000000..7dd28e1de --- /dev/null +++ b/pkg/plugins/register.go @@ -0,0 +1,23 @@ +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" + prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request" + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile" + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" +) + +// RegisterAllPlugins registers the factory functions of all plugins in this repository. +func RegisterAllPlugins() { + plugins.Register(filter.ByLabelFilterType, filter.ByLabelFilterFactory) + plugins.Register(filter.ByLabelSelectorFilterType, filter.ByLabelSelectorFactory) + plugins.Register(filter.DecodeFilterType, filter.DecodeFilterFactory) + plugins.Register(filter.PrefillFilterType, filter.PrefillFilterFactory) + plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) + plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) + plugins.Register(scorer.KvCacheAwareScorerType, scorer.KvCacheAwareScorerFactory) + plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory) + plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory) +} diff --git a/pkg/plugins/scorer/kvcache_aware.go b/pkg/plugins/scorer/kvcache_aware.go index 6df9398a6..cb70f530b 100644 --- a/pkg/plugins/scorer/kvcache_aware.go +++ b/pkg/plugins/scorer/kvcache_aware.go @@ -2,6 +2,7 @@ package scorer import ( "context" + "encoding/json" "fmt" "os" "strings" @@ -10,6 +11,7 @@ import ( "github.com/redis/go-redis/v9" 
"sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -26,12 +28,21 @@ const ( // compile-time type assertion var _ framework.Scorer = &KVCacheAwareScorer{} +// KvCacheAwareScorerFactory defines the factory function for the KVCacheAwareScorer +func KvCacheAwareScorerFactory(name string, _ json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + plugin, err := NewKVCacheAwareScorer(handle.Context()) + if err != nil { + return nil, err + } + return plugin.WithName(name), nil +} + // NewKVCacheAwareScorer creates a new KVCacheAwareScorer instance. // It initializes the KVCacheIndexer from environment variables. // // If the environment variables are not set, or if the indexer // fails to initialize, an error is returned. 
-func NewKVCacheAwareScorer(ctx context.Context) (framework.Scorer, error) { +func NewKVCacheAwareScorer(ctx context.Context) (*KVCacheAwareScorer, error) { config := kvcache.NewDefaultConfig() redisAddr := os.Getenv(kvCacheRedisEnvVar) @@ -40,23 +51,23 @@ func NewKVCacheAwareScorer(ctx context.Context) (framework.Scorer, error) { if !strings.HasPrefix(redisAddr, "redis://") && !strings.HasPrefix(redisAddr, "rediss://") && !strings.HasPrefix(redisAddr, "unix://") { redisAddr = "redis://" + redisAddr } - redisOpt, err := redis.ParseURL(redisAddr) - if err != nil { - return nil, fmt.Errorf("failed to parse redisURL: %w", err) - } - - config.KVBlockIndexerConfig.RedisOpt = redisOpt } else { return nil, fmt.Errorf("environment variable %s is not set", kvCacheRedisEnvVar) } + redisOpt, err := redis.ParseURL(redisAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse redisURL: %w", err) + } + config.KVBlockIndexerConfig.RedisOpt = redisOpt + hfToken := os.Getenv(huggingFaceTokenEnvVar) - if hfToken != "" { - config.TokenizersPoolConfig.HuggingFaceToken = hfToken - } else { + if hfToken == "" { return nil, fmt.Errorf("environment variable %s is not set", huggingFaceTokenEnvVar) } + config.TokenizersPoolConfig.HuggingFaceToken = hfToken + kvCacheIndexer, err := kvcache.NewKVCacheIndexer(config) if err != nil { return nil, fmt.Errorf("failed to create KVCacheIndexer: %w", err) diff --git a/pkg/plugins/scorer/load_aware_scorer.go b/pkg/plugins/scorer/load_aware_scorer.go index ae9bc7709..329276f19 100644 --- a/pkg/plugins/scorer/load_aware_scorer.go +++ b/pkg/plugins/scorer/load_aware_scorer.go @@ -2,29 +2,46 @@ package scorer import ( "context" + "encoding/json" + "fmt" - "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" ) const ( // LoadAwareScorerType is the type of the LoadAwareScorer LoadAwareScorerType = "load-aware-scorer" - queueThresholdEnvName = "LOAD_AWARE_SCORER_QUEUE_THRESHOLD" - queueThresholdDefault = 128 + // QueueThresholdDefault defines the default queue threshold value + QueueThresholdDefault = 128 ) +type loadAwareScorerParameters struct { + Threshold int `json:"threshold"` +} + // compile-time type assertion var _ framework.Scorer = &LoadAwareScorer{} +// LoadAwareScorerFactory defines the factory function for the LoadAwareScorer +func LoadAwareScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := loadAwareScorerParameters{Threshold: QueueThresholdDefault} + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &parameters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", LoadAwareScorerType, err) + } + } + + return NewLoadAwareScorer(parameters.Threshold).WithName(name), nil +} + // NewLoadAwareScorer creates a new load based scorer -func NewLoadAwareScorer(ctx context.Context) framework.Scorer { +func NewLoadAwareScorer(queueThreshold int) *LoadAwareScorer { return &LoadAwareScorer{ name: LoadAwareScorerType, - queueThreshold: float64(env.GetEnvInt(queueThresholdEnvName, queueThresholdDefault, log.FromContext(ctx))), + queueThreshold: float64(queueThreshold), } } diff --git a/pkg/plugins/scorer/load_aware_scorer_test.go b/pkg/plugins/scorer/load_aware_scorer_test.go index 23c1cce92..38d7df102 100644 --- a/pkg/plugins/scorer/load_aware_scorer_test.go +++ b/pkg/plugins/scorer/load_aware_scorer_test.go @@ -27,7 +27,7 @@ func TestLoadBasedScorer(t *testing.T) { }{ { name: "load based scorer", - scorer: scorer.NewLoadAwareScorer(context.Background()), + scorer: scorer.NewLoadAwareScorer(0), req: &types.LLMRequest{ TargetModel: "critical", }, diff --git 
a/pkg/plugins/scorer/session_affinity.go b/pkg/plugins/scorer/session_affinity.go index 55f566d9f..0670b14c0 100644 --- a/pkg/plugins/scorer/session_affinity.go +++ b/pkg/plugins/scorer/session_affinity.go @@ -3,10 +3,11 @@ package scorer import ( "context" "encoding/base64" - "time" + "encoding/json" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -17,15 +18,18 @@ const ( // SessionAffinityScorerType is the type of the SessionAffinityScorer SessionAffinityScorerType = "session-affinity-scorer" - sessionKeepAliveTime = 60 * time.Minute // How long should an idle session be kept alive - sessionKeepAliveCheckFrequency = 15 * time.Minute // How often to check for overly idle sessions - sessionTokenHeader = "x-session-token" // name of the session header in request + sessionTokenHeader = "x-session-token" // name of the session header in request ) // compile-time type assertion var _ framework.Scorer = &SessionAffinity{} var _ requestcontrol.PostResponse = &SessionAffinity{} +// SessionAffinityScorerFactory defines the factory function for SessionAffinityScorer. 
+func SessionAffinityScorerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewSessionAffinity().WithName(name), nil +} + // NewSessionAffinity returns a scorer func NewSessionAffinity() *SessionAffinity { return &SessionAffinity{ diff --git a/pkg/scheduling/pd/scheduler.go b/pkg/scheduling/pd/scheduler.go index b90757d7c..7a49f03ec 100644 --- a/pkg/scheduling/pd/scheduler.go +++ b/pkg/scheduling/pd/scheduler.go @@ -14,6 +14,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" gieprofile "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile" giescorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" "github.com/llm-d/llm-d-inference-scheduler/pkg/config" @@ -22,6 +23,10 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" ) +const ( + queueThresholdEnvName = "LOAD_AWARE_SCORER_QUEUE_THRESHOLD" +) + // CreatePDSchedulerConfig returns a new disaggregated Prefill/Decode SchedulerConfig, using the provided configuration. 
func CreatePDSchedulerConfig(ctx context.Context, pdConfig *config.Config) (*scheduling.SchedulerConfig, error) { if !pdConfig.PDEnabled { // if PD is disabled, create scheduler with SingleProfileHandler (handling only decode profile) @@ -92,7 +97,8 @@ func pluginsFromConfig(ctx context.Context, pluginsConfig map[string]int, pdConf logger.Error(err, "KVCache scorer creation failed") } case config.LoadAwareScorerName: - plugins = append(plugins, framework.NewWeightedScorer(scorer.NewLoadAwareScorer(ctx), pluginWeight)) + queueThreshold := envutil.GetEnvInt(queueThresholdEnvName, scorer.QueueThresholdDefault, log.FromContext(ctx)) + plugins = append(plugins, framework.NewWeightedScorer(scorer.NewLoadAwareScorer(queueThreshold), pluginWeight)) case config.SessionAwareScorerName: plugins = append(plugins, framework.NewWeightedScorer(scorer.NewSessionAffinity(), pluginWeight))