From f2b6457d2e46941b233e0eef80d967a51aa2ee35 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Tue, 29 Apr 2025 20:00:41 +0000 Subject: [PATCH 1/9] Add prefix cache aware scheduling --- cmd/epp/main.go | 23 ++ pkg/epp/metrics/metrics.go | 75 +++++ pkg/epp/metrics/metrics_test.go | 103 +++++++ .../testdata/prefix_indexer_hit_bytes_metric | 19 ++ .../testdata/prefix_indexer_hit_ratio_metric | 16 + .../testdata/prefix_indexer_size_metric | 3 + pkg/epp/scheduling/plugins/filter/filter.go | 286 ++++++++++++++++++ pkg/epp/scheduling/plugins/prefix/indexer.go | 163 ++++++++++ .../scheduling/plugins/prefix/indexer_test.go | 46 +++ .../scheduling/plugins/prefix/linked_list.go | 85 ++++++ pkg/epp/scheduling/plugins/prefix/plugin.go | 178 +++++++++++ .../scheduling/plugins/prefix/plugin_test.go | 132 ++++++++ pkg/epp/scheduling/scheduler_v2.go | 62 ++++ pkg/epp/scheduling/types/types.go | 14 + 14 files changed, 1205 insertions(+) create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_size_metric create mode 100644 pkg/epp/scheduling/plugins/filter/filter.go create mode 100644 pkg/epp/scheduling/plugins/prefix/indexer.go create mode 100644 pkg/epp/scheduling/plugins/prefix/indexer_test.go create mode 100644 pkg/epp/scheduling/plugins/prefix/linked_list.go create mode 100644 pkg/epp/scheduling/plugins/prefix/plugin.go create mode 100644 pkg/epp/scheduling/plugins/prefix/plugin_test.go create mode 100644 pkg/epp/scheduling/scheduler_v2.go diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 9fd401d4e..728363047 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -34,6 +34,7 @@ import ( "k8s.io/client-go/rest" "k8s.io/component-base/metrics/legacyregistry" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" @@ -43,7 +44,9 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -107,8 +110,24 @@ var ( "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") setupLog = ctrl.Log.WithName("setup") + + // Environment variables + schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULE_V2", "false", setupLog) + prefixCacheConfig = loadPrefixCacheConfig() ) +func loadPrefixCacheConfig() prefix.Config { + // logger := zap.New(zap.RawZapOpts(uberzap.AddCaller())) + // log.SetLogger(logger) + baseLogger := log.Log.WithName("env-config") + + return prefix.Config{ + HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultCacheBlockSize, baseLogger), + MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), + LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_MAX_CACHE_SIZE_MB", prefix.DefaultLRUIndexerCapacity, baseLogger), + } +} + func main() { if err := run(); err != nil { os.Exit(1) @@ -172,6 +191,10 @@ func run() error { datastore := datastore.NewDatastore(ctx, pmf) scheduler := scheduling.NewScheduler(datastore) + if schedulerV2 == "true" { + setupLog.Info("Creating scheduler with prefixCache plugin", "prefix cache config", prefixCacheConfig) + scheduler = scheduling.NewSchedulerV2(datastore, prefixCacheConfig) + } serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 6cc0cdb83..1baa3099f 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -18,6 +18,7 @@ package metrics import ( "context" + "runtime/debug" "sync" "time" @@ -219,6 +220,40 @@ var ( }, []string{"commit"}, ) + + // Prefix indexer Metrics + PrefixCacheSize = compbasemetrics.NewGaugeVec( + &compbasemetrics.GaugeOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_size", + Help: "Size of the prefix indexer.", + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) + + PrefixCacheHitRatio = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_hit_ratio", + Help: "Ratio of prefix length matched to total prefix length in the cache lookup.", + // Buckets from 0.0 to 1.0 in increments + Buckets: []float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) + + PrefixCacheHitLength = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_hit_bytes", + Help: "Length of the prefix match in number of bytes in the cache lookup.", + Buckets: []float64{0, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) ) var registerMetrics sync.Once @@ -244,6 +279,10 @@ func Register() { legacyregistry.MustRegister(SchedulerE2ELatency) legacyregistry.MustRegister(InferenceExtensionInfo) + + legacyregistry.MustRegister(PrefixCacheSize) + legacyregistry.MustRegister(PrefixCacheHitRatio) + legacyregistry.MustRegister(PrefixCacheHitLength) }) } @@ -352,8 +391,44 @@ func RecordSchedulerE2ELatency(duration time.Duration) { SchedulerE2ELatency.WithLabelValues().Observe(duration.Seconds()) } +// RecordPrefixCacheSize records the size of the prefix indexer in megabytes. +func RecordPrefixCacheSize(size int64) { + PrefixCacheSize.WithLabelValues().Set(float64(size)) +} + +// RecordPrefixCacheMatch records both the hit ratio and hit length for a prefix indexer match. +// matchedLength is the number of characters that matched, and totalLength is the total prefix length. +func RecordPrefixCacheMatch(matchedLength, totalLength int) { + // Record the hit length metric + PrefixCacheHitLength.WithLabelValues().Observe(float64(matchedLength)) + + // Record the hit ratio metric if totalLength is positive + if totalLength > 0 { + ratio := float64(matchedLength) / float64(totalLength) + PrefixCacheHitRatio.WithLabelValues().Observe(ratio) + } +} + func RecordInferenceExtensionInfo() { if CommitSHA != "" { InferenceExtensionInfo.WithLabelValues(CommitSHA).Set(1) } } + +func init() { + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + + var Commit = func(i *debug.BuildInfo) string { + for _, setting := range i.Settings { + if setting.Key == "vcs.revision" { + return setting.Value + } + } + return "" + }(info) + + CommitSHA = Commit +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index 3a8136a08..4ad6f96e1 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -664,3 +664,106 @@ func TestSchedulerE2ELatency(t *testing.T) { }) } } + +func TestPrefixCacheMetrics(t *testing.T) { + const ( + PrefixCacheSizeMetric = InferenceExtension + "_prefix_indexer_size" + PrefixCacheHitRatioMetric = InferenceExtension + "_prefix_indexer_hit_ratio" + PrefixCacheHitLengthMetric = InferenceExtension + "_prefix_indexer_hit_bytes" + ) + + type cacheMatchRecord struct { + matchedLength int + totalLength int + } + + scenario := struct { + name string + cacheSizes []int64 + cacheMatches []cacheMatchRecord + }{ + name: "multiple cache metrics", + cacheSizes: []int64{1024, 2048, 4096}, + cacheMatches: []cacheMatchRecord{ + { + matchedLength: 5, + totalLength: 10, + }, + { + matchedLength: 0, + totalLength: 10, + }, + { + matchedLength: 10, + totalLength: 10, + }, + { + matchedLength: 7, + totalLength: 10, + }, + { + matchedLength: 64, + totalLength: 128, + }, + { + matchedLength: 0, + totalLength: 128, + }, + }, + } + + Register() + t.Run(scenario.name, func(t *testing.T) { + // Record cache size metrics + for _, size := range scenario.cacheSizes { + RecordPrefixCacheSize(size) + } + + // Record cache match metrics (both hit ratio and hit length) + for _, match := range scenario.cacheMatches { + RecordPrefixCacheMatch(match.matchedLength, match.totalLength) + } + + // Verify cache size metrics + wantCacheSizeMetrics, err := os.Open("testdata/prefix_indexer_size_metric") + defer func() { + if err := wantCacheSizeMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantCacheSizeMetrics, PrefixCacheSizeMetric); err != nil { + t.Error(err) + } + + // Verify hit ratio metrics + wantHitRatioMetrics, err := os.Open("testdata/prefix_indexer_hit_ratio_metric") + defer func() { + if err := wantHitRatioMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitRatioMetrics, PrefixCacheHitRatioMetric); err != nil { + t.Error(err) + } + + // Verify hit length metrics + wantHitLengthMetrics, err := os.Open("testdata/prefix_indexer_hit_bytes_metric") + defer func() { + if err := wantHitLengthMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitLengthMetrics, PrefixCacheHitLengthMetric); err != nil { + t.Error(err) + } + }) +} diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric new file mode 100644 index 000000000..86b48724e --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric @@ -0,0 +1,19 @@ +# HELP inference_extension_prefix_indexer_hit_bytes [ALPHA] Length of the prefix match in number of bytes in the cache lookup. +# TYPE inference_extension_prefix_indexer_hit_bytes histogram +inference_extension_prefix_indexer_hit_bytes_bucket{le="0"} 2 +inference_extension_prefix_indexer_hit_bytes_bucket{le="16"} 5 +inference_extension_prefix_indexer_hit_bytes_bucket{le="32"} 5 +inference_extension_prefix_indexer_hit_bytes_bucket{le="64"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="128"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="256"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="512"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="1024"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="2048"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="4096"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="8192"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="16384"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="32768"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="65536"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="+Inf"} 6 +inference_extension_prefix_indexer_hit_bytes_sum 86 +inference_extension_prefix_indexer_hit_bytes_count 6 diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric new file mode 100644 index 000000000..e94827cb6 --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric @@ -0,0 +1,16 @@ +# HELP inference_extension_prefix_indexer_hit_ratio [ALPHA] Ratio of prefix length matched to total prefix length in the cache lookup. +# TYPE inference_extension_prefix_indexer_hit_ratio histogram +inference_extension_prefix_indexer_hit_ratio_bucket{le="0"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.1"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.2"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.3"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.4"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.5"} 4 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.6"} 4 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.7"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.8"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.9"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="1"} 6 +inference_extension_prefix_indexer_hit_ratio_bucket{le="+Inf"} 6 +inference_extension_prefix_indexer_hit_ratio_sum 2.7 +inference_extension_prefix_indexer_hit_ratio_count 6 diff --git a/pkg/epp/metrics/testdata/prefix_indexer_size_metric b/pkg/epp/metrics/testdata/prefix_indexer_size_metric new file mode 100644 index 000000000..9799b1729 --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_size_metric @@ -0,0 +1,3 @@ +# HELP inference_extension_prefix_indexer_size [ALPHA] Size of the prefix indexer. +# TYPE inference_extension_prefix_indexer_size gauge +inference_extension_prefix_indexer_size{} 4096 diff --git a/pkg/epp/scheduling/plugins/filter/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go new file mode 100644 index 000000000..67ce764dd --- /dev/null +++ b/pkg/epp/scheduling/plugins/filter/filter.go @@ -0,0 +1,286 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filter + +import ( + "math" + "math/rand" + "time" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type baseFilter struct { + name string + filter filterFunc +} + +func (f *baseFilter) Name() string { + if f == nil { + return "nil" + } + return f.name +} + +func (f *baseFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + loggerTrace := ctx.Logger.V(logutil.TRACE) + loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods)) + + return f.filter(ctx, pods) +} + +// DecisionTreeFilter applies current filterFunc, and then recursively applies next filters +// depending success or failure of the current filter. +// It can be used to construct a flow chart algorithm. +type DecisionTreeFilter struct { + Current plugins.Filter + // NextOnSuccess filter will be applied after successfully applying the current filter. + // The filtered results will be passed to the next filter. + NextOnSuccess plugins.Filter + // NextOnFailure filter will be applied if current filter fails. + // The original input will be passed to the next filter. + NextOnFailure plugins.Filter + // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the + // success or failure of the current filter. + // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. + // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of + // NextOnSuccessOrFailure, in the success and failure scenarios, respectively. + NextOnSuccessOrFailure plugins.Filter +} + +func (f *DecisionTreeFilter) Name() string { + if f == nil { + return "nil" + } + return f.Current.Name() +} + +func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + loggerTrace := ctx.Logger.V(logutil.TRACE) + filtered := f.Current.Filter(ctx, pods) + + next := f.NextOnSuccessOrFailure + if len(filtered) > 0 { + if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { + // No succeeding filters to run, return. + return filtered + } + if f.NextOnSuccess != nil { + next = f.NextOnSuccess + } + loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered)) + // On success, pass the filtered result to the next filter. + return next.Filter(ctx, filtered) + } else { + if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { + // No succeeding filters to run, return. + return filtered + } + if f.NextOnFailure != nil { + next = f.NextOnFailure + } + loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) + // On failure, pass the initial set of pods to the next filter. + return next.Filter(ctx, pods) + } +} + +// filterFunc filters a set of input pods to a subset. +type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod + +// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. +func toFilterFunc(pp podPredicate) filterFunc { + return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filtered := []types.Pod{} + for _, pod := range pods { + pass := pp(ctx.Req, pod) + if pass { + filtered = append(filtered, pod) + } + } + + return filtered + } +} + +var LeastQueueFilter = &baseFilter{ + name: "least queuing", + filter: leastQueuingFilterFunc, +} + +// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range +// (max-min) by the number of pods, and finds the pods that fall into the first range. +// The intuition is that if there are multiple pods that share similar queue size in the low range, +// we should consider them all instead of the absolute minimum one. This worked better than picking +// the least one as it gives more choices for the next filter, which on aggregate gave better +// results. +// TODO: Compare this strategy with other strategies such as top K. +func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + min := math.MaxInt + max := 0 + filtered := []types.Pod{} + + for _, pod := range pods { + if pod.GetMetrics().WaitingQueueSize <= min { + min = pod.GetMetrics().WaitingQueueSize + } + if pod.GetMetrics().WaitingQueueSize >= max { + max = pod.GetMetrics().WaitingQueueSize + } + } + + for _, pod := range pods { + if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { + filtered = append(filtered, pod) + } + } + return filtered +} + +var LowQueueFilter = &baseFilter{ + name: "low queueing filter", + filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), +} + +var LeastKVCacheFilter = &baseFilter{ + name: "least KV cache percent", + filter: leastKVCacheFilterFunc, +} + +// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range +// (max-min) by the number of pods, and finds the pods that fall into the first range. +// The intuition is that if there are multiple pods that share similar KV cache in the low range, we +// should consider them all instead of the absolute minimum one. This worked better than picking the +// least one as it gives more choices for the next filter, which on aggregate gave better results. +// TODO: Compare this strategy with other strategies such as top K. +func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + min := math.MaxFloat64 + var max float64 = 0 + filtered := []types.Pod{} + + for _, pod := range pods { + if pod.GetMetrics().KVCacheUsagePercent <= min { + min = pod.GetMetrics().KVCacheUsagePercent + } + if pod.GetMetrics().KVCacheUsagePercent >= max { + max = pod.GetMetrics().KVCacheUsagePercent + } + } + + for _, pod := range pods { + if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + filtered = append(filtered, pod) + } + } + return filtered +} + +var LoRAAffinityFilter = &baseFilter{ + name: "affinity LoRA", + filter: loRASoftAffinityFilterFunc, +} + +// loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods +// with existing LoRA model affinity while allowing for load balancing through randomization. +// +// The function works by: +// 1. Separating pods into two groups: those with target model affinity and those with available capacity +// 2. Using a probability threshold to sometimes select from non-affinity pods to enable load balancing +// 3. Falling back to whatever group has pods if one group is empty +// +// Parameters: +// - logger: Logger interface for diagnostic output +// - req: LLM request containing the resolved target model +// - pods: Slice of pod metrics to filter +// +// Returns: +// - Filtered slice of pod metrics based on affinity and availability +// - Error if any issues occur during filtering +func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + + // Pre-allocate slices with estimated capacity + filtered_affinity := make([]types.Pod, 0, len(pods)) + filtered_available := make([]types.Pod, 0, len(pods)) + + // Categorize pods based on affinity and availability + for _, pod := range pods { + _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] + + if active || waiting { + filtered_affinity = append(filtered_affinity, pod) + } else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { + filtered_available = append(filtered_available, pod) + } + } + + // Use crypto/rand for better randomization in production environments + randSource := rand.NewSource(time.Now().UnixNano()) + randGen := rand.New(randSource) + + // If both groups have pods, use probability to select which group to return + if len(filtered_affinity) > 0 && len(filtered_available) > 0 { + if randGen.Float64() < config.Conf.LoraAffinityThreshold { + return filtered_affinity + } + return filtered_available + } + + // Return whichever group has pods + if len(filtered_affinity) > 0 { + return filtered_affinity + } + + return filtered_available +} + +var HasCapacityFilter = &baseFilter{ + name: "has capacity for sheddable requests", + filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), +} + +// NoopFilter is a filter that does not filter out any pods. +var NoopFilter = &baseFilter{ + name: "noop", + filter: toFilterFunc(func(req *types.LLMRequest, pod types.Pod) bool { + return true + }), +} + +// podPredicate is a filter function to check whether a pod is desired. +type podPredicate func(req *types.LLMRequest, pod types.Pod) bool + +func queueThresholdPredicate(queueThreshold int) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold + } +} + +func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold + } +} + +func (pp podPredicate) and(another podPredicate) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pp(req, pod) && another(req, pod) + } +} diff --git a/pkg/epp/scheduling/plugins/prefix/indexer.go b/pkg/epp/scheduling/plugins/prefix/indexer.go new file mode 100644 index 000000000..cae7739bd --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/indexer.go @@ -0,0 +1,163 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prefix + +import ( + "context" + "sync" + "time" + "unsafe" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func newIndexer(maxCacheSize int) *indexer { + t := &indexer{ + maxCacheSize: maxCacheSize, + table: make(map[types.BlockHash]map[types.ServerID]*node), + list: newLinkedList(), + } + go t.ReportCacheSize(time.Second) + return t +} + +// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that +// prefix cached . +type indexer struct { + mu sync.RWMutex + maxCacheSize int + table map[types.BlockHash]map[types.ServerID]*node // from any prefix cache to the cache entry to find the server + list *linkedList // LRU list to keep track of the order of entries +} + +// Get returns the set of servers that have the given prefix hash cached. +func (i *indexer) Get(hash types.BlockHash) map[types.ServerID]bool { + i.mu.RLock() + defer i.mu.RUnlock() + res := map[types.ServerID]bool{} + for server := range i.table[hash] { + res[server] = true + } + return res +} + +// Add adds a list of prefix hashes of a single request to the server the request was sent to. +// The intuition is that this server is likely to have the prefix cached, so next time a request +// sharing the longest prefix should be sent to the same server to take advantage of the cache hit. +func (i *indexer) Add(hashes []types.BlockHash, server types.ServerID) { + i.mu.Lock() + defer i.mu.Unlock() + for _, hash := range hashes { + i.add(hash, server) + } +} + +func (i *indexer) check(hash types.BlockHash, server types.ServerID) (*node, bool) { + servers, ok := i.table[hash] + if !ok { + return nil, false + } + n, ok := servers[server] + return n, ok +} + +func (i *indexer) add(hash types.BlockHash, server types.ServerID) { + node, exists := i.check(hash, server) + if exists { + i.list.moveToTail(node) + } else { + i.create(hash, server) + } +} + +func (i *indexer) create(hash types.BlockHash, server types.ServerID) { + n := &node{ + hash: hash, + server: server, + } + + for i.list.size >= i.maxCacheSize { + // Evict the least recently used entry if we've exceeded the max cache size + i.evict() + } + + if _, ok := i.table[hash]; !ok { + i.table[hash] = make(map[types.ServerID]*node) + } + i.table[hash][server] = n + i.list.add(n) +} + +// evict removes the least recently used entry from the cache +func (i *indexer) evict() { + oldestNode := i.list.dummyHead.next + i.list.delete(oldestNode) + + hash := oldestNode.hash + server := oldestNode.server + // Remove from the hash map + serverMap := i.table[hash] + delete(serverMap, server) + + // If this was the last server for this hash, remove the hash entry entirely + if len(serverMap) == 0 { + delete(i.table, hash) + } + + log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server) +} + +// ReportCacheSize starts a goroutine that periodically reports the cache size metric +func (i *indexer) ReportCacheSize(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + i.mu.RLock() + metrics.RecordPrefixCacheSize(int64(i.list.size)) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.list.size, "estimated size MB", i.list.size*i.estimateEntrySize()/1000000) + i.mu.RUnlock() + } +} + +// estimateEntrySize estimates the memory size of a cache entry in bytes. +func (i *indexer) estimateEntrySize() int { + size := 0 + + // Estimate the size of a node in the linked list. + // First get the size of the node struct via unsafe.Sizeof. + // The prev and next pointers are 8 bytes each on a 64-bit system. + // The BlockHash is a uint64, which is 8 bytes. + // The ServerID is a NamespacedName, which contains two strings (Name and Namespace). + // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length). + // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes. + size += int(unsafe.Sizeof(node{})) + // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName). + size += 2 * 63 + + // Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored. + size += 8 // Size of the BlockHash (uint64). + size += 2 * 16 // Size of the ServerID string headers (NamespacedName). + size += 2 * 63 // Size of the Name and Namespace strings in ServerID. + size += 8 // Size of the pointer to the node in the hash map. + + // Based on the above estimates, the estimated size of an entry is: + // (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes. + return size +} diff --git a/pkg/epp/scheduling/plugins/prefix/indexer_test.go b/pkg/epp/scheduling/plugins/prefix/indexer_test.go new file mode 100644 index 000000000..592b7c3e3 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/indexer_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package prefix + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestIndexer_AddAndGet(t *testing.T) { + cache := newIndexer(2) + + hash1 := types.BlockHash(1) + server := types.ServerID{Namespace: "default", Name: "server1"} + + // Add an entry to the cache + cache.Add([]types.BlockHash{hash1}, server) + + // Retrieve the entry + assert.Equal(t, 1, cache.list.size, "Cache size should be 1 after adding an entry") + servers := cache.Get(hash1) + assert.Contains(t, servers, server, "Cache should contain the added server") + + // Add another entry to the cache, the cache size should be incremented to 2. + cache.Add([]types.BlockHash{types.BlockHash(2)}, server) + assert.Equal(t, 2, cache.list.size, "Cache size should be 2 after adding an entry") + + // Add another entry to the cache, which should evict the first one due to max size. + cache.Add([]types.BlockHash{types.BlockHash(3)}, server) + assert.Equal(t, 2, cache.list.size, "Cache size should still be 2 after adding an entry") +} diff --git a/pkg/epp/scheduling/plugins/prefix/linked_list.go b/pkg/epp/scheduling/plugins/prefix/linked_list.go new file mode 100644 index 000000000..9c9b82103 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/linked_list.go @@ -0,0 +1,85 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prefix + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +type linkedList struct { + dummyHead *node // The head of the linked list (dummy node). + tail *node // The tail of the linked list. + size int // The size of the linked list (excluding dummy head). +} + +// newLinkedList initializes a new linked list with a dummy head node. +// Using a dummy head simplifies the implementation by eliminating nil checks. +func newLinkedList() *linkedList { + dummy := &node{} // Create dummy head node + return &linkedList{ + dummyHead: dummy, + tail: dummy, + size: 0, + } +} + +type node struct { + prev *node + next *node + server types.ServerID + hash types.BlockHash +} + +// add adds a node to the end of the linked list. +func (ll *linkedList) add(n *node) { + ll.size++ + + n.prev = ll.tail + ll.tail.next = n + ll.tail = n +} + +// delete removes a node from the linked list. +// Note the method assumes the input node exists in the list. +func (ll *linkedList) delete(n *node) { + ll.size-- + n.prev.next = n.next + + // If it's the tail node + if n.next == nil { + ll.tail = n.prev + } else { + n.next.prev = n.prev + } +} + +// moveToTail moves an existing node to the end of the linked list (most recent). +func (ll *linkedList) moveToTail(n *node) { + if n.next == nil { + // Already the tail, no need to move. + return + } + + n.prev.next = n.next + n.next.prev = n.prev + + // Move it to the tail position + n.prev = ll.tail + n.next = nil + ll.tail.next = n + ll.tail = n +} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go new file mode 100644 index 000000000..2e748af82 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -0,0 +1,178 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prefix + +import ( + "encoding/binary" + "fmt" + + "github.com/cespare/xxhash/v2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + // Attempt to return DefaultNumServersToMatch servers with their longest prefix match length. + // Why not just return the server with longest prefix match? + // It may not be the optimal choice, e.g., it may have a high queue depth. + // We optimistically search more than one to give more candidates for the scheduler to choose. + DefaultNumServersToMatch = 2 + // vLLM default token block size is 16, and a good guess of average characters per token is 4. + DefaultCacheBlockSize = 64 + DefaultMaxPrefixBlocks = 128 + // Assume each request reaches DefaultMaxPrefixBlocks = 128, and each BlockHash is cached onto 2 + // servers due to load balancing, then it requires 256 entries per request. + // According to the estimates in indexer.estimateEntrySize(), the size of each entry is 348 bytes. + // So each request will cost 89,088 bytes ~ 90KB. + // Therefore, to cache 50k requests, we need 50K * 90KB = 4.5GB. Assuming 500 requests per + // second, a 4.5 GB cache can hold at least last 100 seconds of requests. + // Note in practice, the size of each entry will be much smaller (shorter NamespacedNames, + // shorter prompt). And due to the prefix cache hit, the number of unique cache entries will be + // much smaller per request. Therefore the actual cache size will be much smaller. + // TODO: Add some guidance for choosing the right size. + DefaultLRUIndexerCapacity = 50000 +) + +type Config struct { + // The input prompt is broken into sizes of HashBlockSize to calculate block hashes . Requests + // with length shorter than the block size will be ignored. + HashBlockSize int + // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will + // be ignored. + MaxPrefixBlocksToMatch int + // Max (approximate) size of the LRU indexer in number of entries. + LRUIndexerCapacity int +} + +var DefaultConfig = Config{ + HashBlockSize: DefaultCacheBlockSize, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUIndexerCapacity: DefaultLRUIndexerCapacity, +} + +type plugin struct { + Config + indexer Indexer +} + +type Indexer interface { + Get(hash types.BlockHash) map[types.ServerID]bool + Add(hashes []types.BlockHash, server types.ServerID) +} + +func New(config Config) *plugin { + m := &plugin{ + Config: config, + indexer: newIndexer(config.LRUIndexerCapacity), + } + return m +} + +func (m *plugin) Name() string { + return "prefixCache" +} + +func (m *plugin) PreSchedule(ctx *types.SchedulingContext) { + ctx.PrefixHashes = hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + ctx.PrefixCacheServers = m.matchLongestPrefix(ctx, DefaultNumServersToMatch) + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", ctx.PrefixCacheServers), "hashes", ctx.PrefixHashes) +} + +// If a request was routed to a server, record it in the cache: +func (m *plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { + targetPod := res.TargetPod.GetPod() + m.indexer.Add(ctx.PrefixHashes, types.ServerID(targetPod.NamespacedName)) + total := len(ctx.PrefixHashes) + matchLen := ctx.PrefixCacheServers[types.ServerID(targetPod.NamespacedName)] + metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) +} + +func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { + total := len(ctx.PrefixHashes) + podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 { + if total == 0 { + return 0 + } + matchLen := ctx.PrefixCacheServers[types.ServerID(pod.GetPod().NamespacedName)] + return float64(matchLen) / float64(total) + } + + scores := make(map[types.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = podScoreFunc(ctx, pod) + } + return scores +} + +// matchLongestPrefix returns a map of servers and length of prefix that each server caches. +func (m *plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int { + if numServers > len(ctx.PodsSnapshot) { + numServers = len(ctx.PodsSnapshot) + } + res := make(map[types.ServerID]int) + // Use a greedy strategy to search from the longest prefix. + // NOTE: It's possible to further optimize this with a binary search. + for i := len(ctx.PrefixHashes) - 1; i >= 0 && len(res) < numServers; i-- { + hash := ctx.PrefixHashes[i] + cachedServers := m.indexer.Get(hash) + if len(cachedServers) > 0 { + ctx.Logger.V(logutil.VERBOSE).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(ctx.PrefixHashes), "longest prefix", i) + for server := range cachedServers { + // Update servers with their longest prefix match. + // If we already found this server with longer prefix match, don't update it. + if _, ok := res[server]; !ok { + res[server] = i + 1 + } + } + } + } + return res +} + +// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block. +// hash(0) is the hash of the model name, since different models generally don't share prefix cache. +// For block i, hash(i) = hash(block i content, hash(i-1)). +func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []types.BlockHash { + prompt := []byte(ctx.Req.Prompt) + if len(prompt) < cacheBlockSize { + ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize) + return nil + } + if len(prompt) > cacheBlockSize*maxPrefixBlocks { + ctx.Logger.V(logutil.DEBUG).Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize) + prompt = prompt[:maxPrefixBlocks*cacheBlockSize] + } + // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model. + // If the last block is smaller than cacheBlockSize, it will be ignored. + res := make([]types.BlockHash, 0, 1+len(prompt)/cacheBlockSize) + // Add the model to the first block hash so that different models have different hashes even with the same body. + res = append(res, types.BlockHash(xxhash.Sum64String(ctx.Req.ResolvedTargetModel))) + for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { + block := prompt[i : i+cacheBlockSize] + prevBlockHash := res[len(res)-1] + toHash := append(block, toBytes(prevBlockHash)...) + res = append(res, types.BlockHash(xxhash.Sum64(toHash))) + } + return res +} + +func toBytes(i types.BlockHash) []byte { + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, uint64(i)) + return bytes +} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin_test.go b/pkg/epp/scheduling/plugins/prefix/plugin_test.go new file mode 100644 index 000000000..47c6c7f18 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/plugin_test.go @@ -0,0 +1,132 @@ +package prefix + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestPrefixPlugin(t *testing.T) { + config := Config{ + HashBlockSize: 4, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUIndexerCapacity: DefaultLRUIndexerCapacity, + } + plugin := New(config) + + pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}} + pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}} + pods := []types.Pod{pod1, pod2} + + // First request. + req1 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaaaa", + } + ctx := types.NewSchedulingContext(context.Background(), req1, pods) + plugin.PreSchedule(ctx) + t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + // Input size is 6, hash block size is 4, the last 2 characters are ignored. + // Total hashes = 2 (the first one is for the model) + assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers") + + // Updated to use the new Score method signature + scores := plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + // Simulate pod1 was picked. + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // Second request doesn't share any prefix with first one. It should be added to the cache but + // the pod score should be 0. + req2 := &types.LLMRequest{ + Model: "test-model2", + ResolvedTargetModel: "test-model2", + Prompt: "bbbbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req2, pods) + plugin.PreSchedule(ctx) + t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + // Input size is 6, hash block size is 4, the last 2 characters are ignored. + // Total hashes = 2 (the first one is for the model) + assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + // Simulate pod2 was picked. + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod2}) + + // Third request shares partial prefix with first one. + req3 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaabbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req3, pods) + plugin.PreSchedule(ctx) + t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + // Input size is 8, hash block size is 4, so 2 hashes will be calculated. + // Total hashes = 3 (the first one is for the model) + assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // 4th request is same as req3 except the model is different, still no match. + req4 := &types.LLMRequest{ + Model: "test-model-new", + ResolvedTargetModel: "test-model-new", + Prompt: "aaaabbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req4, pods) + plugin.PreSchedule(ctx) + t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + // Input size is 8, hash block size is 4, so 2 hashes will be calculated. + // Total hashes = 3 (the first one is for the model) + assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // 5th request shares partial prefix with 3rd one. + req5 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaabbbbcccc", + } + ctx = types.NewSchedulingContext(context.Background(), req5, pods) + plugin.PreSchedule(ctx) + t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + // Input size is 12, hash block size is 4, so 3 hashes will be calculated. + // Total hashes = 4 (the first one is for the model) + assert.Equal(t, 4, len(ctx.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) +} diff --git a/pkg/epp/scheduling/scheduler_v2.go b/pkg/epp/scheduling/scheduler_v2.go new file mode 100644 index 000000000..7a3da3b3a --- /dev/null +++ b/pkg/epp/scheduling/scheduler_v2.go @@ -0,0 +1,62 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package scheduling implements request scheduling algorithms. +package scheduling + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func NewSchedulerV2(datastore Datastore, prefixConfig prefix.Config) *Scheduler { + prefixPlugin := prefix.New(prefixConfig) + queuePlugin := &scorer.QueueScorer{} + kvCachePlugin := &scorer.KVCacheScorer{} + configV2 := &SchedulerConfig{ + PreSchedulePlugins: []plugins.PreSchedule{prefixPlugin}, + PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin}, + Scorers: map[plugins.Scorer]int{ + prefixPlugin: 3, + queuePlugin: 1, + kvCachePlugin: 1, + }, + Filters: []plugins.Filter{&sheddableRequestFilterV2{}}, + Picker: &picker.MaxScorePicker{}, + } + return NewSchedulerWithConfig(datastore, configV2) +} + +type sheddableRequestFilterV2 struct { +} + +func (p *sheddableRequestFilterV2) Name() string { + return "sheddableRequestFilterV2" +} + +func (p *sheddableRequestFilterV2) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + if ctx.Req.Critical { + // Allow all pods to pass through if the request is critical, even if all pods reach their capacity. + return pods + } + + // Only allow pods that have enough capacity to handle the request. + return filter.HasCapacityFilter.Filter(ctx, pods) +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 795ef65d2..6cf399938 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -62,6 +63,10 @@ type SchedulingContext struct { Logger logr.Logger Req *LLMRequest PodsSnapshot []Pod + // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. + PrefixHashes []BlockHash + // A map of server to its longest prefix cache match length. + PrefixCacheServers map[ServerID]int } func (pm *PodMetrics) String() string { @@ -106,3 +111,12 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { type Result struct { TargetPod Pod } + +// BlockHash is a hash of the block of request body. +type BlockHash uint64 + +type ServerID types.NamespacedName + +func (s ServerID) String() string { + return types.NamespacedName(s).String() +} From 2bfc15ae12348ae861ad03c80759be3623fa1f27 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Fri, 2 May 2025 21:57:50 +0000 Subject: [PATCH 2/9] Replace scheduler v2 with config v2 --- cmd/epp/main.go | 24 ++++++++++++------- .../{scheduler_v2.go => config_v2.go} | 19 +++++++++------ pkg/epp/scheduling/plugins/prefix/plugin.go | 4 ++-- 3 files changed, 30 insertions(+), 17 deletions(-) rename pkg/epp/scheduling/{scheduler_v2.go => config_v2.go} (85%) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 728363047..c7a9ca6e3 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -112,19 +112,26 @@ var ( setupLog = ctrl.Log.WithName("setup") // Environment variables - schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULE_V2", "false", setupLog) - prefixCacheConfig = loadPrefixCacheConfig() + schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog) ) func loadPrefixCacheConfig() prefix.Config { - // logger := zap.New(zap.RawZapOpts(uberzap.AddCaller())) - // log.SetLogger(logger) baseLogger := log.Log.WithName("env-config") return prefix.Config{ - HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultCacheBlockSize, baseLogger), + HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), - LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_MAX_CACHE_SIZE_MB", prefix.DefaultLRUIndexerCapacity, baseLogger), + LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), + } +} + +func loadSchedulingScorerWeights() scheduling.ScorerWeights { + baseLogger := log.Log.WithName("env-config") + + return scheduling.ScorerWeights{ + Prefix: envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", 3, baseLogger), + Queue: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", 2, baseLogger), + KVCache: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", 1, baseLogger), } } @@ -192,8 +199,9 @@ func run() error { scheduler := scheduling.NewScheduler(datastore) if schedulerV2 == "true" { - setupLog.Info("Creating scheduler with prefixCache plugin", "prefix cache config", prefixCacheConfig) - scheduler = scheduling.NewSchedulerV2(datastore, prefixCacheConfig) + schedConfig := scheduling.CreateConfig(loadSchedulingScorerWeights(), loadPrefixCacheConfig()) + setupLog.Info("Creating scheduler", "config", *schedConfig) + scheduler = scheduling.NewSchedulerWithConfig(datastore, schedConfig) } serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, diff --git a/pkg/epp/scheduling/scheduler_v2.go b/pkg/epp/scheduling/config_v2.go similarity index 85% rename from pkg/epp/scheduling/scheduler_v2.go rename to pkg/epp/scheduling/config_v2.go index 7a3da3b3a..4992de637 100644 --- a/pkg/epp/scheduling/scheduler_v2.go +++ b/pkg/epp/scheduling/config_v2.go @@ -26,27 +26,32 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) -func NewSchedulerV2(datastore Datastore, prefixConfig prefix.Config) *Scheduler { +func CreateConfig(weights ScorerWeights, prefixConfig prefix.Config) *SchedulerConfig { prefixPlugin := prefix.New(prefixConfig) queuePlugin := &scorer.QueueScorer{} kvCachePlugin := &scorer.KVCacheScorer{} - configV2 := &SchedulerConfig{ + config := &SchedulerConfig{ PreSchedulePlugins: []plugins.PreSchedule{prefixPlugin}, PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin}, Scorers: map[plugins.Scorer]int{ - prefixPlugin: 3, - queuePlugin: 1, - kvCachePlugin: 1, + prefixPlugin: weights.Prefix, + queuePlugin: weights.Queue, + kvCachePlugin: weights.KVCache, }, Filters: []plugins.Filter{&sheddableRequestFilterV2{}}, Picker: &picker.MaxScorePicker{}, } - return NewSchedulerWithConfig(datastore, configV2) + return config } -type sheddableRequestFilterV2 struct { +type ScorerWeights struct { + Prefix int + Queue int + KVCache int } +type sheddableRequestFilterV2 struct{} + func (p *sheddableRequestFilterV2) Name() string { return "sheddableRequestFilterV2" } diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index 2e748af82..39ccf886e 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -33,7 +33,7 @@ const ( // We optimistically search more than one to give more candidates for the scheduler to choose. DefaultNumServersToMatch = 2 // vLLM default token block size is 16, and a good guess of average characters per token is 4. - DefaultCacheBlockSize = 64 + DefaultHashBlockSize = 64 DefaultMaxPrefixBlocks = 128 // Assume each request reaches DefaultMaxPrefixBlocks = 128, and each BlockHash is cached onto 2 // servers due to load balancing, then it requires 256 entries per request. @@ -60,7 +60,7 @@ type Config struct { } var DefaultConfig = Config{ - HashBlockSize: DefaultCacheBlockSize, + HashBlockSize: DefaultHashBlockSize, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, LRUIndexerCapacity: DefaultLRUIndexerCapacity, } From 8a1f89f5a53e859598472de449ec0d3a46b37ac6 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Thu, 8 May 2025 04:07:33 +0000 Subject: [PATCH 3/9] Add score weight to XXScorerConfig --- cmd/epp/main.go | 32 +++++++------ pkg/epp/scheduling/config_v2.go | 49 +++++++++++++------- pkg/epp/scheduling/plugins/prefix/plugin.go | 18 +++---- pkg/epp/scheduling/plugins/scorer/kvcache.go | 8 ++++ pkg/epp/scheduling/plugins/scorer/queue.go | 8 ++++ 5 files changed, 76 insertions(+), 39 deletions(-) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index c7a9ca6e3..c67909ffd 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -45,6 +45,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -112,7 +113,8 @@ var ( setupLog = ctrl.Log.WithName("setup") // Environment variables - schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog) + schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog) + prefixCacheScheduling = envutil.GetEnvString("ENABLE_PREFIX_CACHE_SCHEDULING", "false", setupLog) ) func loadPrefixCacheConfig() prefix.Config { @@ -125,16 +127,6 @@ func loadPrefixCacheConfig() prefix.Config { } } -func loadSchedulingScorerWeights() scheduling.ScorerWeights { - baseLogger := log.Log.WithName("env-config") - - return scheduling.ScorerWeights{ - Prefix: envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", 3, baseLogger), - Queue: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", 2, baseLogger), - KVCache: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", 1, baseLogger), - } -} - func main() { if err := run(); err != nil { os.Exit(1) @@ -199,9 +191,21 @@ func run() error { scheduler := scheduling.NewScheduler(datastore) if schedulerV2 == "true" { - schedConfig := scheduling.CreateConfig(loadSchedulingScorerWeights(), loadPrefixCacheConfig()) - setupLog.Info("Creating scheduler", "config", *schedConfig) - scheduler = scheduling.NewSchedulerWithConfig(datastore, schedConfig) + queueConfig := scorer.QueueScorerConfig{ + Weight: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog), + } + kvCacheConfig := scorer.KVCacheScorerConfig{ + Weight: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog), + } + schedConfigOpts := []scheduling.ConfigOption{ + scheduling.WithQueuePlugin(queueConfig), + scheduling.WithKVCachePlugin(kvCacheConfig), + } + if prefixCacheScheduling == "true" { + schedConfigOpts = append(schedConfigOpts, scheduling.WithPrefixPlugin(loadPrefixCacheConfig())) + } + schedulerConfig := scheduling.CreateConfig(schedConfigOpts...) + scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig) } serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, diff --git a/pkg/epp/scheduling/config_v2.go b/pkg/epp/scheduling/config_v2.go index 4992de637..0ad96ee9e 100644 --- a/pkg/epp/scheduling/config_v2.go +++ b/pkg/epp/scheduling/config_v2.go @@ -26,28 +26,43 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) -func CreateConfig(weights ScorerWeights, prefixConfig prefix.Config) *SchedulerConfig { - prefixPlugin := prefix.New(prefixConfig) - queuePlugin := &scorer.QueueScorer{} - kvCachePlugin := &scorer.KVCacheScorer{} +func CreateConfig(opts ...ConfigOption) *SchedulerConfig { config := &SchedulerConfig{ - PreSchedulePlugins: []plugins.PreSchedule{prefixPlugin}, - PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin}, - Scorers: map[plugins.Scorer]int{ - prefixPlugin: weights.Prefix, - queuePlugin: weights.Queue, - kvCachePlugin: weights.KVCache, - }, - Filters: []plugins.Filter{&sheddableRequestFilterV2{}}, - Picker: &picker.MaxScorePicker{}, + PreSchedulePlugins: []plugins.PreSchedule{}, + PostSchedulePlugins: []plugins.PostSchedule{}, + Scorers: map[plugins.Scorer]int{}, + Filters: []plugins.Filter{&sheddableRequestFilterV2{}}, + Picker: &picker.MaxScorePicker{}, + } + for _, opt := range opts { + opt(config) } return config } -type ScorerWeights struct { - Prefix int - Queue int - KVCache int +type ConfigOption func(*SchedulerConfig) + +func WithPrefixPlugin(prefixConfig prefix.Config) ConfigOption { + return func(cfg *SchedulerConfig) { + prefixPlugin := prefix.New(prefixConfig) + cfg.PreSchedulePlugins = append(cfg.PreSchedulePlugins, prefixPlugin) + cfg.PostSchedulePlugins = append(cfg.PostSchedulePlugins, prefixPlugin) + cfg.Scorers[prefixPlugin] = prefixConfig.Weight + } +} + +func WithQueuePlugin(queueConfig scorer.QueueScorerConfig) ConfigOption { + return func(cfg *SchedulerConfig) { + queuePlugin := &scorer.QueueScorer{} + cfg.Scorers[queuePlugin] = queueConfig.Weight + } +} + +func WithKVCachePlugin(kvCacheConfig scorer.KVCacheScorerConfig) ConfigOption { + return func(cfg *SchedulerConfig) { + kvCachePlugin := &scorer.KVCacheScorer{} + cfg.Scorers[kvCachePlugin] = kvCacheConfig.Weight + } } type sheddableRequestFilterV2 struct{} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index 39ccf886e..a80cf2254 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -27,6 +27,7 @@ import ( ) const ( + DefaultScorerWeight = 1 // Attempt to return DefaultNumServersToMatch servers with their longest prefix match length. // Why not just return the server with longest prefix match? // It may not be the optimal choice, e.g., it may have a high queue depth. @@ -49,6 +50,7 @@ const ( ) type Config struct { + Weight int // The input prompt is broken into sizes of HashBlockSize to calculate block hashes . Requests // with length shorter than the block size will be ignored. HashBlockSize int @@ -65,7 +67,7 @@ var DefaultConfig = Config{ LRUIndexerCapacity: DefaultLRUIndexerCapacity, } -type plugin struct { +type Plugin struct { Config indexer Indexer } @@ -75,26 +77,26 @@ type Indexer interface { Add(hashes []types.BlockHash, server types.ServerID) } -func New(config Config) *plugin { - m := &plugin{ +func New(config Config) *Plugin { + m := &Plugin{ Config: config, indexer: newIndexer(config.LRUIndexerCapacity), } return m } -func (m *plugin) Name() string { +func (m *Plugin) Name() string { return "prefixCache" } -func (m *plugin) PreSchedule(ctx *types.SchedulingContext) { +func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) { ctx.PrefixHashes = hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch) ctx.PrefixCacheServers = m.matchLongestPrefix(ctx, DefaultNumServersToMatch) ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", ctx.PrefixCacheServers), "hashes", ctx.PrefixHashes) } // If a request was routed to a server, record it in the cache: -func (m *plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { +func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { targetPod := res.TargetPod.GetPod() m.indexer.Add(ctx.PrefixHashes, types.ServerID(targetPod.NamespacedName)) total := len(ctx.PrefixHashes) @@ -102,7 +104,7 @@ func (m *plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) } -func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { +func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { total := len(ctx.PrefixHashes) podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 { if total == 0 { @@ -120,7 +122,7 @@ func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. -func (m *plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int { +func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int { if numServers > len(ctx.PodsSnapshot) { numServers = len(ctx.PodsSnapshot) } diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache.go b/pkg/epp/scheduling/plugins/scorer/kvcache.go index 0877691d1..762f421bf 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache.go @@ -20,6 +20,14 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +const ( + DefaultKVCacheScorerWeight = 1 +) + +type KVCacheScorerConfig struct { + Weight int +} + type KVCacheScorer struct{} func (ss *KVCacheScorer) Name() string { diff --git a/pkg/epp/scheduling/plugins/scorer/queue.go b/pkg/epp/scheduling/plugins/scorer/queue.go index 3df9d4140..0aa8ffd09 100644 --- a/pkg/epp/scheduling/plugins/scorer/queue.go +++ b/pkg/epp/scheduling/plugins/scorer/queue.go @@ -22,6 +22,14 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +const ( + DefaultQueueScorerWeight = 1 +) + +type QueueScorerConfig struct { + Weight int +} + type QueueScorer struct{} func (q *QueueScorer) Name() string { From 9447ad2650db492d097600c9274696e65b88a2dc Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Thu, 8 May 2025 05:33:40 +0000 Subject: [PATCH 4/9] Address comments --- cmd/epp/main.go | 1 + pkg/epp/metrics/metrics.go | 41 +++----- pkg/epp/scheduling/config.go | 56 +++++++++++ pkg/epp/scheduling/config_v2.go | 82 ---------------- pkg/epp/scheduling/plugins/filter/filter.go | 8 -- pkg/epp/scheduling/plugins/prefix/indexer.go | 21 ++--- .../scheduling/plugins/prefix/indexer_test.go | 11 +-- .../scheduling/plugins/prefix/linked_list.go | 8 +- pkg/epp/scheduling/plugins/prefix/plugin.go | 93 ++++++++++++------- .../scheduling/plugins/prefix/plugin_test.go | 35 ++++--- pkg/epp/scheduling/plugins/scorer/kvcache.go | 2 +- pkg/epp/scheduling/types/types.go | 43 +++++---- 12 files changed, 193 insertions(+), 208 deletions(-) delete mode 100644 pkg/epp/scheduling/config_v2.go diff --git a/cmd/epp/main.go b/cmd/epp/main.go index c67909ffd..11fa12e97 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -121,6 +121,7 @@ func loadPrefixCacheConfig() prefix.Config { baseLogger := log.Log.WithName("env-config") return prefix.Config{ + Weight: envutil.GetEnvInt("PREFIX_CACHE_WEIGHT", prefix.DefaultScorerWeight, baseLogger), HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 1baa3099f..84f0f1f9a 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -18,7 +18,6 @@ package metrics import ( "context" - "runtime/debug" "sync" "time" @@ -210,17 +209,6 @@ var ( []string{"plugin_type", "plugin_name"}, ) - // Info Metrics - InferenceExtensionInfo = compbasemetrics.NewGaugeVec( - &compbasemetrics.GaugeOpts{ - Subsystem: InferenceExtension, - Name: "info", - Help: "General information of the current build of Inference Extension.", - StabilityLevel: compbasemetrics.ALPHA, - }, - []string{"commit"}, - ) - // Prefix indexer Metrics PrefixCacheSize = compbasemetrics.NewGaugeVec( &compbasemetrics.GaugeOpts{ @@ -254,6 +242,17 @@ var ( }, []string{}, ) + + // Info Metrics + InferenceExtensionInfo = compbasemetrics.NewGaugeVec( + &compbasemetrics.GaugeOpts{ + Subsystem: InferenceExtension, + Name: "info", + Help: "General information of the current build of Inference Extension.", + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"commit"}, + ) ) var registerMetrics sync.Once @@ -414,21 +413,3 @@ func RecordInferenceExtensionInfo() { InferenceExtensionInfo.WithLabelValues(CommitSHA).Set(1) } } - -func init() { - info, ok := debug.ReadBuildInfo() - if !ok { - return - } - - var Commit = func(i *debug.BuildInfo) string { - for _, setting := range i.Settings { - if setting.Key == "vcs.revision" { - return setting.Value - } - } - return "" - }(info) - - CommitSHA = Commit -} diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index a4f4c2950..5a1fb1658 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -18,6 +18,9 @@ package scheduling import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" ) // NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. @@ -40,3 +43,56 @@ type SchedulerConfig struct { picker plugins.Picker postSchedulePlugins []plugins.PostSchedule } + +var defPlugin = &defaultPlugin{} + +// When the scheduler is initialized with NewScheduler function, this config will be used as default. +// it's possible to call NewSchedulerWithConfig to pass a different argument. + +// For build time plugins changes, it's recommended to change the defaultConfig variable in this file. +var defaultConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{defPlugin}, + scorers: map[plugins.Scorer]int{}, + picker: defPlugin, + postSchedulePlugins: []plugins.PostSchedule{}, +} + +func CreateConfig(opts ...ConfigOption) *SchedulerConfig { + config := &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + postSchedulePlugins: []plugins.PostSchedule{}, + scorers: map[plugins.Scorer]int{}, + filters: []plugins.Filter{&sheddableRequestFilterV2{}}, + picker: &picker.MaxScorePicker{}, + } + for _, opt := range opts { + opt(config) + } + return config +} + +type ConfigOption func(*SchedulerConfig) + +func WithPrefixPlugin(prefixConfig prefix.Config) ConfigOption { + return func(cfg *SchedulerConfig) { + prefixPlugin := prefix.New(prefixConfig) + cfg.preSchedulePlugins = append(cfg.preSchedulePlugins, prefixPlugin) + cfg.postSchedulePlugins = append(cfg.postSchedulePlugins, prefixPlugin) + cfg.scorers[prefixPlugin] = prefixConfig.Weight + } +} + +func WithQueuePlugin(queueConfig scorer.QueueScorerConfig) ConfigOption { + return func(cfg *SchedulerConfig) { + queuePlugin := &scorer.QueueScorer{} + cfg.scorers[queuePlugin] = queueConfig.Weight + } +} + +func WithKVCachePlugin(kvCacheConfig scorer.KVCacheScorerConfig) ConfigOption { + return func(cfg *SchedulerConfig) { + kvCachePlugin := &scorer.KVCacheScorer{} + cfg.scorers[kvCachePlugin] = kvCacheConfig.Weight + } +} diff --git a/pkg/epp/scheduling/config_v2.go b/pkg/epp/scheduling/config_v2.go deleted file mode 100644 index 0ad96ee9e..000000000 --- a/pkg/epp/scheduling/config_v2.go +++ /dev/null @@ -1,82 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package scheduling implements request scheduling algorithms. -package scheduling - -import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" -) - -func CreateConfig(opts ...ConfigOption) *SchedulerConfig { - config := &SchedulerConfig{ - PreSchedulePlugins: []plugins.PreSchedule{}, - PostSchedulePlugins: []plugins.PostSchedule{}, - Scorers: map[plugins.Scorer]int{}, - Filters: []plugins.Filter{&sheddableRequestFilterV2{}}, - Picker: &picker.MaxScorePicker{}, - } - for _, opt := range opts { - opt(config) - } - return config -} - -type ConfigOption func(*SchedulerConfig) - -func WithPrefixPlugin(prefixConfig prefix.Config) ConfigOption { - return func(cfg *SchedulerConfig) { - prefixPlugin := prefix.New(prefixConfig) - cfg.PreSchedulePlugins = append(cfg.PreSchedulePlugins, prefixPlugin) - cfg.PostSchedulePlugins = append(cfg.PostSchedulePlugins, prefixPlugin) - cfg.Scorers[prefixPlugin] = prefixConfig.Weight - } -} - -func WithQueuePlugin(queueConfig scorer.QueueScorerConfig) ConfigOption { - return func(cfg *SchedulerConfig) { - queuePlugin := &scorer.QueueScorer{} - cfg.Scorers[queuePlugin] = queueConfig.Weight - } -} - -func WithKVCachePlugin(kvCacheConfig scorer.KVCacheScorerConfig) ConfigOption { - return func(cfg *SchedulerConfig) { - kvCachePlugin := &scorer.KVCacheScorer{} - cfg.Scorers[kvCachePlugin] = kvCacheConfig.Weight - } -} - -type sheddableRequestFilterV2 struct{} - -func (p *sheddableRequestFilterV2) Name() string { - return "sheddableRequestFilterV2" -} - -func (p *sheddableRequestFilterV2) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - if ctx.Req.Critical { - // Allow all pods to pass through if the request is critical, even if all pods reach their capacity. - return pods - } - - // Only allow pods that have enough capacity to handle the request. - return filter.HasCapacityFilter.Filter(ctx, pods) -} diff --git a/pkg/epp/scheduling/plugins/filter/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go index 67ce764dd..86620aa9f 100644 --- a/pkg/epp/scheduling/plugins/filter/filter.go +++ b/pkg/epp/scheduling/plugins/filter/filter.go @@ -256,14 +256,6 @@ var HasCapacityFilter = &baseFilter{ filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), } -// NoopFilter is a filter that does not filter out any pods. -var NoopFilter = &baseFilter{ - name: "noop", - filter: toFilterFunc(func(req *types.LLMRequest, pod types.Pod) bool { - return true - }), -} - // podPredicate is a filter function to check whether a pod is desired. type podPredicate func(req *types.LLMRequest, pod types.Pod) bool diff --git a/pkg/epp/scheduling/plugins/prefix/indexer.go b/pkg/epp/scheduling/plugins/prefix/indexer.go index cae7739bd..5f42fa37a 100644 --- a/pkg/epp/scheduling/plugins/prefix/indexer.go +++ b/pkg/epp/scheduling/plugins/prefix/indexer.go @@ -24,14 +24,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) func newIndexer(maxCacheSize int) *indexer { t := &indexer{ maxCacheSize: maxCacheSize, - table: make(map[types.BlockHash]map[types.ServerID]*node), + table: make(map[BlockHash]map[ServerID]*node), list: newLinkedList(), } go t.ReportCacheSize(time.Second) @@ -43,15 +42,15 @@ func newIndexer(maxCacheSize int) *indexer { type indexer struct { mu sync.RWMutex maxCacheSize int - table map[types.BlockHash]map[types.ServerID]*node // from any prefix cache to the cache entry to find the server - list *linkedList // LRU list to keep track of the order of entries + table map[BlockHash]map[ServerID]*node // from any prefix cache to the cache entry to find the server + list *linkedList // LRU list to keep track of the order of entries } // Get returns the set of servers that have the given prefix hash cached. -func (i *indexer) Get(hash types.BlockHash) map[types.ServerID]bool { +func (i *indexer) Get(hash BlockHash) map[ServerID]bool { i.mu.RLock() defer i.mu.RUnlock() - res := map[types.ServerID]bool{} + res := map[ServerID]bool{} for server := range i.table[hash] { res[server] = true } @@ -61,7 +60,7 @@ func (i *indexer) Get(hash types.BlockHash) map[types.ServerID]bool { // Add adds a list of prefix hashes of a single request to the server the request was sent to. // The intuition is that this server is likely to have the prefix cached, so next time a request // sharing the longest prefix should be sent to the same server to take advantage of the cache hit. -func (i *indexer) Add(hashes []types.BlockHash, server types.ServerID) { +func (i *indexer) Add(hashes []BlockHash, server ServerID) { i.mu.Lock() defer i.mu.Unlock() for _, hash := range hashes { @@ -69,7 +68,7 @@ func (i *indexer) Add(hashes []types.BlockHash, server types.ServerID) { } } -func (i *indexer) check(hash types.BlockHash, server types.ServerID) (*node, bool) { +func (i *indexer) check(hash BlockHash, server ServerID) (*node, bool) { servers, ok := i.table[hash] if !ok { return nil, false @@ -78,7 +77,7 @@ func (i *indexer) check(hash types.BlockHash, server types.ServerID) (*node, boo return n, ok } -func (i *indexer) add(hash types.BlockHash, server types.ServerID) { +func (i *indexer) add(hash BlockHash, server ServerID) { node, exists := i.check(hash, server) if exists { i.list.moveToTail(node) @@ -87,7 +86,7 @@ func (i *indexer) add(hash types.BlockHash, server types.ServerID) { } } -func (i *indexer) create(hash types.BlockHash, server types.ServerID) { +func (i *indexer) create(hash BlockHash, server ServerID) { n := &node{ hash: hash, server: server, @@ -99,7 +98,7 @@ func (i *indexer) create(hash types.BlockHash, server types.ServerID) { } if _, ok := i.table[hash]; !ok { - i.table[hash] = make(map[types.ServerID]*node) + i.table[hash] = make(map[ServerID]*node) } i.table[hash][server] = n i.list.add(n) diff --git a/pkg/epp/scheduling/plugins/prefix/indexer_test.go b/pkg/epp/scheduling/plugins/prefix/indexer_test.go index 592b7c3e3..0531351e1 100644 --- a/pkg/epp/scheduling/plugins/prefix/indexer_test.go +++ b/pkg/epp/scheduling/plugins/prefix/indexer_test.go @@ -19,17 +19,16 @@ import ( "testing" "github.com/stretchr/testify/assert" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestIndexer_AddAndGet(t *testing.T) { cache := newIndexer(2) - hash1 := types.BlockHash(1) - server := types.ServerID{Namespace: "default", Name: "server1"} + hash1 := BlockHash(1) + server := ServerID{Namespace: "default", Name: "server1"} // Add an entry to the cache - cache.Add([]types.BlockHash{hash1}, server) + cache.Add([]BlockHash{hash1}, server) // Retrieve the entry assert.Equal(t, 1, cache.list.size, "Cache size should be 1 after adding an entry") @@ -37,10 +36,10 @@ func TestIndexer_AddAndGet(t *testing.T) { assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. - cache.Add([]types.BlockHash{types.BlockHash(2)}, server) + cache.Add([]BlockHash{BlockHash(2)}, server) assert.Equal(t, 2, cache.list.size, "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. - cache.Add([]types.BlockHash{types.BlockHash(3)}, server) + cache.Add([]BlockHash{BlockHash(3)}, server) assert.Equal(t, 2, cache.list.size, "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/plugins/prefix/linked_list.go b/pkg/epp/scheduling/plugins/prefix/linked_list.go index 9c9b82103..de6b6a163 100644 --- a/pkg/epp/scheduling/plugins/prefix/linked_list.go +++ b/pkg/epp/scheduling/plugins/prefix/linked_list.go @@ -16,10 +16,6 @@ limitations under the License. package prefix -import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" -) - type linkedList struct { dummyHead *node // The head of the linked list (dummy node). tail *node // The tail of the linked list. @@ -40,8 +36,8 @@ func newLinkedList() *linkedList { type node struct { prev *node next *node - server types.ServerID - hash types.BlockHash + server ServerID + hash BlockHash } // add adds a node to the end of the linked list. diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index a80cf2254..2cd5f1d4f 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/cespare/xxhash/v2" + k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -34,18 +35,25 @@ const ( // We optimistically search more than one to give more candidates for the scheduler to choose. DefaultNumServersToMatch = 2 // vLLM default token block size is 16, and a good guess of average characters per token is 4. - DefaultHashBlockSize = 64 + DefaultHashBlockSize = 64 + // The maximum number of blocks to match. Two long requests with the same prefix up to this + // limit will be indistinguishable. + // This parameter provides a trade-off between cache size, prefix matching speed and matching + // accuracy. Use a small value if most requests are short to reduce cache size and speed up the + // matching process. Use a large value if most requests are long to increase the matching accuracy. DefaultMaxPrefixBlocks = 128 - // Assume each request reaches DefaultMaxPrefixBlocks = 128, and each BlockHash is cached onto 2 - // servers due to load balancing, then it requires 256 entries per request. - // According to the estimates in indexer.estimateEntrySize(), the size of each entry is 348 bytes. - // So each request will cost 89,088 bytes ~ 90KB. - // Therefore, to cache 50k requests, we need 50K * 90KB = 4.5GB. Assuming 500 requests per - // second, a 4.5 GB cache can hold at least last 100 seconds of requests. - // Note in practice, the size of each entry will be much smaller (shorter NamespacedNames, - // shorter prompt). And due to the prefix cache hit, the number of unique cache entries will be - // much smaller per request. Therefore the actual cache size will be much smaller. - // TODO: Add some guidance for choosing the right size. + // The indexer is an approximation to the actual prefix cache state on the model servers. + // A small capacity ensures a high accuracy of cache hit on the model server, but it will + // increase the chance of false negatives. A high capacity does the opposite. + // To properly size this, consider the sum of the total number of cache entries on all model + // servers. Consider the llama3 8B model on 3 H100 80GB GPUs. The size of the model weight is + // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each + // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 + // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 3 = 46.8K blocks, or + // roughly 50K. + // How much memory space does it require to hold the 50K block hashes? + // According to the estimates in indexer.estimateEntrySize(), the size of each entry is + // approximately 348 bytes. So in total we have 50K * 348 = 17.4MB. DefaultLRUIndexerCapacity = 50000 ) @@ -73,8 +81,25 @@ type Plugin struct { } type Indexer interface { - Get(hash types.BlockHash) map[types.ServerID]bool - Add(hashes []types.BlockHash, server types.ServerID) + Get(hash BlockHash) map[ServerID]bool + Add(hashes []BlockHash, server ServerID) +} + +// This is the state of this plugin to be used during a scheduling cycle. +type SchedulingContextState struct { + // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. + PrefixHashes []BlockHash + // A map of server to its longest prefix cache match length. + PrefixCacheServers map[ServerID]int +} + +// BlockHash is a hash of the block of request body. +type BlockHash uint64 + +type ServerID k8stypes.NamespacedName + +func (s ServerID) String() string { + return k8stypes.NamespacedName(s).String() } func New(config Config) *Plugin { @@ -90,27 +115,33 @@ func (m *Plugin) Name() string { } func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) { - ctx.PrefixHashes = hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch) - ctx.PrefixCacheServers = m.matchLongestPrefix(ctx, DefaultNumServersToMatch) - ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", ctx.PrefixCacheServers), "hashes", ctx.PrefixHashes) + hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + state := SchedulingContextState{ + PrefixHashes: hashes, + PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, DefaultNumServersToMatch), + } + ctx.SetPluginState(types.PluginName(m.Name()), state) + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) } // If a request was routed to a server, record it in the cache: func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { targetPod := res.TargetPod.GetPod() - m.indexer.Add(ctx.PrefixHashes, types.ServerID(targetPod.NamespacedName)) - total := len(ctx.PrefixHashes) - matchLen := ctx.PrefixCacheServers[types.ServerID(targetPod.NamespacedName)] + state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState) + m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + total := len(state.PrefixHashes) + matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) } func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { - total := len(ctx.PrefixHashes) + state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState) + total := len(state.PrefixHashes) podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 { if total == 0 { return 0 } - matchLen := ctx.PrefixCacheServers[types.ServerID(pod.GetPod().NamespacedName)] + matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] return float64(matchLen) / float64(total) } @@ -122,18 +153,18 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. -func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int { +func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, hashes []BlockHash, numServers int) map[ServerID]int { if numServers > len(ctx.PodsSnapshot) { numServers = len(ctx.PodsSnapshot) } - res := make(map[types.ServerID]int) + res := make(map[ServerID]int) // Use a greedy strategy to search from the longest prefix. // NOTE: It's possible to further optimize this with a binary search. - for i := len(ctx.PrefixHashes) - 1; i >= 0 && len(res) < numServers; i-- { - hash := ctx.PrefixHashes[i] + for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- { + hash := hashes[i] cachedServers := m.indexer.Get(hash) if len(cachedServers) > 0 { - ctx.Logger.V(logutil.VERBOSE).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(ctx.PrefixHashes), "longest prefix", i) + ctx.Logger.V(logutil.DEBUG).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i) for server := range cachedServers { // Update servers with their longest prefix match. // If we already found this server with longer prefix match, don't update it. @@ -149,7 +180,7 @@ func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block. // hash(0) is the hash of the model name, since different models generally don't share prefix cache. // For block i, hash(i) = hash(block i content, hash(i-1)). -func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []types.BlockHash { +func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []BlockHash { prompt := []byte(ctx.Req.Prompt) if len(prompt) < cacheBlockSize { ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize) @@ -161,19 +192,19 @@ func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlock } // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model. // If the last block is smaller than cacheBlockSize, it will be ignored. - res := make([]types.BlockHash, 0, 1+len(prompt)/cacheBlockSize) + res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize) // Add the model to the first block hash so that different models have different hashes even with the same body. - res = append(res, types.BlockHash(xxhash.Sum64String(ctx.Req.ResolvedTargetModel))) + res = append(res, BlockHash(xxhash.Sum64String(ctx.Req.ResolvedTargetModel))) for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { block := prompt[i : i+cacheBlockSize] prevBlockHash := res[len(res)-1] toHash := append(block, toBytes(prevBlockHash)...) - res = append(res, types.BlockHash(xxhash.Sum64(toHash))) + res = append(res, BlockHash(xxhash.Sum64(toHash))) } return res } -func toBytes(i types.BlockHash) []byte { +func toBytes(i BlockHash) []byte { bytes := make([]byte, 8) binary.LittleEndian.PutUint64(bytes, uint64(i)) return bytes diff --git a/pkg/epp/scheduling/plugins/prefix/plugin_test.go b/pkg/epp/scheduling/plugins/prefix/plugin_test.go index 47c6c7f18..9aa1dbf1c 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin_test.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin_test.go @@ -30,11 +30,12 @@ func TestPrefixPlugin(t *testing.T) { } ctx := types.NewSchedulingContext(context.Background(), req1, pods) plugin.PreSchedule(ctx) - t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + state := ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. // Total hashes = 2 (the first one is for the model) - assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers") + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") // Updated to use the new Score method signature scores := plugin.Score(ctx, pods) @@ -53,11 +54,12 @@ func TestPrefixPlugin(t *testing.T) { } ctx = types.NewSchedulingContext(context.Background(), req2, pods) plugin.PreSchedule(ctx) - t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. // Total hashes = 2 (the first one is for the model) - assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers") + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") // Updated to use the new Score method signature scores = plugin.Score(ctx, pods) @@ -75,11 +77,12 @@ func TestPrefixPlugin(t *testing.T) { } ctx = types.NewSchedulingContext(context.Background(), req3, pods) plugin.PreSchedule(ctx) - t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. // Total hashes = 3 (the first one is for the model) - assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") // Updated to use the new Score method signature scores = plugin.Score(ctx, pods) @@ -96,11 +99,12 @@ func TestPrefixPlugin(t *testing.T) { } ctx = types.NewSchedulingContext(context.Background(), req4, pods) plugin.PreSchedule(ctx) - t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. // Total hashes = 3 (the first one is for the model) - assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 0, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") // Updated to use the new Score method signature scores = plugin.Score(ctx, pods) @@ -117,11 +121,12 @@ func TestPrefixPlugin(t *testing.T) { } ctx = types.NewSchedulingContext(context.Background(), req5, pods) plugin.PreSchedule(ctx) - t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 12, hash block size is 4, so 3 hashes will be calculated. // Total hashes = 4 (the first one is for the model) - assert.Equal(t, 4, len(ctx.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + assert.Equal(t, 4, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") // Updated to use the new Score method signature scores = plugin.Score(ctx, pods) diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache.go b/pkg/epp/scheduling/plugins/scorer/kvcache.go index 762f421bf..0bb59ebe9 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache.go @@ -21,7 +21,7 @@ import ( ) const ( - DefaultKVCacheScorerWeight = 1 + DefaultKVCacheScorerWeight = 2 ) type KVCacheScorerConfig struct { diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 6cf399938..daf27bf83 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -19,9 +19,9 @@ package types import ( "context" "fmt" + "sync" "github.com/go-logr/logr" - "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -63,12 +63,26 @@ type SchedulingContext struct { Logger logr.Logger Req *LLMRequest PodsSnapshot []Pod - // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. - PrefixHashes []BlockHash - // A map of server to its longest prefix cache match length. - PrefixCacheServers map[ServerID]int + // PluginState can be used by plugins to store state during a scheduling cycle, to communicate + // between different extension points. + PluginState map[PluginName]any + pluginStateMu *sync.RWMutex } +func (sc *SchedulingContext) GetPluginState(pluginName PluginName) any { + sc.pluginStateMu.RLock() + defer sc.pluginStateMu.RUnlock() + return sc.PluginState[pluginName] +} + +func (sc *SchedulingContext) SetPluginState(pluginName PluginName, state any) { + sc.pluginStateMu.Lock() + defer sc.pluginStateMu.Unlock() + sc.PluginState[pluginName] = state +} + +type PluginName string + func (pm *PodMetrics) String() string { if pm == nil { return "" @@ -92,10 +106,12 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + PluginState: make(map[PluginName]any), + pluginStateMu: &sync.RWMutex{}, } } @@ -111,12 +127,3 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { type Result struct { TargetPod Pod } - -// BlockHash is a hash of the block of request body. -type BlockHash uint64 - -type ServerID types.NamespacedName - -func (s ServerID) String() string { - return types.NamespacedName(s).String() -} From 4b99fdd565423754c46a40334b35afbbbb4eb0de Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Thu, 8 May 2025 17:23:44 +0000 Subject: [PATCH 5/9] Clean up --- cmd/epp/main.go | 16 +++++--------- pkg/epp/scheduling/config.go | 23 ++++++-------------- pkg/epp/scheduling/plugins/prefix/plugin.go | 1 - pkg/epp/scheduling/plugins/scorer/kvcache.go | 4 ---- pkg/epp/scheduling/plugins/scorer/queue.go | 4 ---- 5 files changed, 13 insertions(+), 35 deletions(-) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 11fa12e97..89c39fe02 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -121,7 +121,6 @@ func loadPrefixCacheConfig() prefix.Config { baseLogger := log.Log.WithName("env-config") return prefix.Config{ - Weight: envutil.GetEnvInt("PREFIX_CACHE_WEIGHT", prefix.DefaultScorerWeight, baseLogger), HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), @@ -192,18 +191,15 @@ func run() error { scheduler := scheduling.NewScheduler(datastore) if schedulerV2 == "true" { - queueConfig := scorer.QueueScorerConfig{ - Weight: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog), - } - kvCacheConfig := scorer.KVCacheScorerConfig{ - Weight: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog), - } + queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog) + kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog) schedConfigOpts := []scheduling.ConfigOption{ - scheduling.WithQueuePlugin(queueConfig), - scheduling.WithKVCachePlugin(kvCacheConfig), + scheduling.AddScorer(&scorer.QueueScorer{}, queueScorerWeight), + scheduling.AddScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight), } if prefixCacheScheduling == "true" { - schedConfigOpts = append(schedConfigOpts, scheduling.WithPrefixPlugin(loadPrefixCacheConfig())) + prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_WEIGHT", prefix.DefaultScorerWeight, setupLog) + schedConfigOpts = append(schedConfigOpts, scheduling.AddPrefixPlugin(loadPrefixCacheConfig(), prefixScorerWeight)) } schedulerConfig := scheduling.CreateConfig(schedConfigOpts...) scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig) diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 5a1fb1658..4af3c7b39 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -20,7 +20,6 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" ) // NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. @@ -74,25 +73,17 @@ func CreateConfig(opts ...ConfigOption) *SchedulerConfig { type ConfigOption func(*SchedulerConfig) -func WithPrefixPlugin(prefixConfig prefix.Config) ConfigOption { +func AddScorer(scorer plugins.Scorer, weight int) ConfigOption { return func(cfg *SchedulerConfig) { - prefixPlugin := prefix.New(prefixConfig) - cfg.preSchedulePlugins = append(cfg.preSchedulePlugins, prefixPlugin) - cfg.postSchedulePlugins = append(cfg.postSchedulePlugins, prefixPlugin) - cfg.scorers[prefixPlugin] = prefixConfig.Weight - } -} - -func WithQueuePlugin(queueConfig scorer.QueueScorerConfig) ConfigOption { - return func(cfg *SchedulerConfig) { - queuePlugin := &scorer.QueueScorer{} - cfg.scorers[queuePlugin] = queueConfig.Weight + cfg.scorers[scorer] = weight } } -func WithKVCachePlugin(kvCacheConfig scorer.KVCacheScorerConfig) ConfigOption { +func AddPrefixPlugin(prefixConfig prefix.Config, weight int) ConfigOption { return func(cfg *SchedulerConfig) { - kvCachePlugin := &scorer.KVCacheScorer{} - cfg.scorers[kvCachePlugin] = kvCacheConfig.Weight + prefixPlugin := prefix.New(prefixConfig) + cfg.preSchedulePlugins = append(cfg.preSchedulePlugins, prefixPlugin) + cfg.postSchedulePlugins = append(cfg.postSchedulePlugins, prefixPlugin) + cfg.scorers[prefixPlugin] = weight } } diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index 2cd5f1d4f..c7494b336 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -58,7 +58,6 @@ const ( ) type Config struct { - Weight int // The input prompt is broken into sizes of HashBlockSize to calculate block hashes . Requests // with length shorter than the block size will be ignored. HashBlockSize int diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache.go b/pkg/epp/scheduling/plugins/scorer/kvcache.go index 0bb59ebe9..78ea1da5e 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache.go @@ -24,10 +24,6 @@ const ( DefaultKVCacheScorerWeight = 2 ) -type KVCacheScorerConfig struct { - Weight int -} - type KVCacheScorer struct{} func (ss *KVCacheScorer) Name() string { diff --git a/pkg/epp/scheduling/plugins/scorer/queue.go b/pkg/epp/scheduling/plugins/scorer/queue.go index 0aa8ffd09..bbe6b6961 100644 --- a/pkg/epp/scheduling/plugins/scorer/queue.go +++ b/pkg/epp/scheduling/plugins/scorer/queue.go @@ -26,10 +26,6 @@ const ( DefaultQueueScorerWeight = 1 ) -type QueueScorerConfig struct { - Weight int -} - type QueueScorer struct{} func (q *QueueScorer) Name() string { From 3274a735021060243cb7b982ea54deed928bedbc Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Thu, 8 May 2025 20:41:44 +0000 Subject: [PATCH 6/9] Change to use container/list lib --- cmd/epp/main.go | 2 +- pkg/epp/scheduling/plugins/prefix/indexer.go | 61 +++++++++++-------- .../scheduling/plugins/prefix/indexer_test.go | 6 +- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 89c39fe02..568b06434 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -198,7 +198,7 @@ func run() error { scheduling.AddScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight), } if prefixCacheScheduling == "true" { - prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_WEIGHT", prefix.DefaultScorerWeight, setupLog) + prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog) schedConfigOpts = append(schedConfigOpts, scheduling.AddPrefixPlugin(loadPrefixCacheConfig(), prefixScorerWeight)) } schedulerConfig := scheduling.CreateConfig(schedConfigOpts...) diff --git a/pkg/epp/scheduling/plugins/prefix/indexer.go b/pkg/epp/scheduling/plugins/prefix/indexer.go index 5f42fa37a..2017ba175 100644 --- a/pkg/epp/scheduling/plugins/prefix/indexer.go +++ b/pkg/epp/scheduling/plugins/prefix/indexer.go @@ -22,6 +22,8 @@ import ( "time" "unsafe" + "container/list" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -30,8 +32,8 @@ import ( func newIndexer(maxCacheSize int) *indexer { t := &indexer{ maxCacheSize: maxCacheSize, - table: make(map[BlockHash]map[ServerID]*node), - list: newLinkedList(), + table: make(map[BlockHash]map[ServerID]*list.Element), + ll: list.New(), } go t.ReportCacheSize(time.Second) return t @@ -42,8 +44,14 @@ func newIndexer(maxCacheSize int) *indexer { type indexer struct { mu sync.RWMutex maxCacheSize int - table map[BlockHash]map[ServerID]*node // from any prefix cache to the cache entry to find the server - list *linkedList // LRU list to keep track of the order of entries + table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server + ll *list.List // LinkedList to keep track of the order of entries +} + +// value is the value stored in the linked list. +type value struct { + server ServerID + hash BlockHash } // Get returns the set of servers that have the given prefix hash cached. @@ -68,49 +76,52 @@ func (i *indexer) Add(hashes []BlockHash, server ServerID) { } } -func (i *indexer) check(hash BlockHash, server ServerID) (*node, bool) { +func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) { servers, ok := i.table[hash] if !ok { return nil, false } - n, ok := servers[server] - return n, ok + e, ok := servers[server] + return e, ok } func (i *indexer) add(hash BlockHash, server ServerID) { - node, exists := i.check(hash, server) + e, exists := i.check(hash, server) if exists { - i.list.moveToTail(node) + i.ll.MoveToBack(e) } else { i.create(hash, server) } } func (i *indexer) create(hash BlockHash, server ServerID) { - n := &node{ - hash: hash, - server: server, - } - - for i.list.size >= i.maxCacheSize { + for i.ll.Len() >= i.maxCacheSize { // Evict the least recently used entry if we've exceeded the max cache size i.evict() } if _, ok := i.table[hash]; !ok { - i.table[hash] = make(map[ServerID]*node) + i.table[hash] = make(map[ServerID]*list.Element) } - i.table[hash][server] = n - i.list.add(n) + v := &value{ + server: server, + hash: hash, + } + e := i.ll.PushBack(v) + i.table[hash][server] = e } // evict removes the least recently used entry from the cache func (i *indexer) evict() { - oldestNode := i.list.dummyHead.next - i.list.delete(oldestNode) + oldestNode := i.ll.Front() + if oldestNode == nil { + return + } + i.ll.Remove(oldestNode) - hash := oldestNode.hash - server := oldestNode.server + v := oldestNode.Value.(*value) + hash := v.hash + server := v.server // Remove from the hash map serverMap := i.table[hash] delete(serverMap, server) @@ -129,8 +140,8 @@ func (i *indexer) ReportCacheSize(interval time.Duration) { defer ticker.Stop() for range ticker.C { i.mu.RLock() - metrics.RecordPrefixCacheSize(int64(i.list.size)) - log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.list.size, "estimated size MB", i.list.size*i.estimateEntrySize()/1000000) + metrics.RecordPrefixCacheSize(int64(i.ll.Len())) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000) i.mu.RUnlock() } } @@ -146,7 +157,7 @@ func (i *indexer) estimateEntrySize() int { // The ServerID is a NamespacedName, which contains two strings (Name and Namespace). // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length). // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes. - size += int(unsafe.Sizeof(node{})) + size += int(unsafe.Sizeof(value{})) // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName). size += 2 * 63 diff --git a/pkg/epp/scheduling/plugins/prefix/indexer_test.go b/pkg/epp/scheduling/plugins/prefix/indexer_test.go index 0531351e1..596625d10 100644 --- a/pkg/epp/scheduling/plugins/prefix/indexer_test.go +++ b/pkg/epp/scheduling/plugins/prefix/indexer_test.go @@ -31,15 +31,15 @@ func TestIndexer_AddAndGet(t *testing.T) { cache.Add([]BlockHash{hash1}, server) // Retrieve the entry - assert.Equal(t, 1, cache.list.size, "Cache size should be 1 after adding an entry") + assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry") servers := cache.Get(hash1) assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. cache.Add([]BlockHash{BlockHash(2)}, server) - assert.Equal(t, 2, cache.list.size, "Cache size should be 2 after adding an entry") + assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. cache.Add([]BlockHash{BlockHash(3)}, server) - assert.Equal(t, 2, cache.list.size, "Cache size should still be 2 after adding an entry") + assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry") } From 81b24a555fe98401158e00e2eecf9c4595ca61d7 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Thu, 8 May 2025 21:18:20 +0000 Subject: [PATCH 7/9] cleanup --- cmd/epp/main.go | 18 +- pkg/epp/scheduling/config.go | 43 +-- pkg/epp/scheduling/plugins/filter/filter.go | 278 ------------------ .../scheduling/plugins/prefix/linked_list.go | 81 ----- pkg/epp/scheduling/plugins/prefix/plugin.go | 6 - pkg/epp/scheduling/plugins/scorer/kvcache.go | 2 +- 6 files changed, 21 insertions(+), 407 deletions(-) delete mode 100644 pkg/epp/scheduling/plugins/filter/filter.go delete mode 100644 pkg/epp/scheduling/plugins/prefix/linked_list.go diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 568b06434..81d902dc3 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -44,6 +44,9 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" @@ -193,15 +196,22 @@ func run() error { if schedulerV2 == "true" { queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog) kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog) - schedConfigOpts := []scheduling.ConfigOption{ - scheduling.AddScorer(&scorer.QueueScorer{}, queueScorerWeight), - scheduling.AddScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight), + scorers := map[plugins.Scorer]int{ + &scorer.QueueScorer{}: queueScorerWeight, + &scorer.KVCacheScorer{}: kvCacheScorerWeight, } + schedConfigOpts := []scheduling.ConfigOption{} if prefixCacheScheduling == "true" { prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog) schedConfigOpts = append(schedConfigOpts, scheduling.AddPrefixPlugin(loadPrefixCacheConfig(), prefixScorerWeight)) } - schedulerConfig := scheduling.CreateConfig(schedConfigOpts...) + schedulerConfig := scheduling.NewSchedulerConfig( + []plugins.PreSchedule{}, + []plugins.Filter{filter.NewSheddableRequestFilter()}, + scorers, + picker.NewMaxScorePicker(), + []plugins.PostSchedule{}, + schedConfigOpts...) scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig) } serverRunner := &runserver.ExtProcServerRunner{ diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 4af3c7b39..63d150d2c 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -18,20 +18,23 @@ package scheduling import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" ) // NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int, - picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule) *SchedulerConfig { - return &SchedulerConfig{ + picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, opts ...ConfigOption) *SchedulerConfig { + config := &SchedulerConfig{ preSchedulePlugins: preSchedulePlugins, filters: filters, scorers: scorers, picker: picker, postSchedulePlugins: postSchedulePlugins, } + for _, opt := range opts { + opt(config) + } + return config } // SchedulerConfig provides a configuration for the scheduler which influence routing decisions. @@ -43,42 +46,8 @@ type SchedulerConfig struct { postSchedulePlugins []plugins.PostSchedule } -var defPlugin = &defaultPlugin{} - -// When the scheduler is initialized with NewScheduler function, this config will be used as default. -// it's possible to call NewSchedulerWithConfig to pass a different argument. - -// For build time plugins changes, it's recommended to change the defaultConfig variable in this file. -var defaultConfig = &SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{}, - filters: []plugins.Filter{defPlugin}, - scorers: map[plugins.Scorer]int{}, - picker: defPlugin, - postSchedulePlugins: []plugins.PostSchedule{}, -} - -func CreateConfig(opts ...ConfigOption) *SchedulerConfig { - config := &SchedulerConfig{ - preSchedulePlugins: []plugins.PreSchedule{}, - postSchedulePlugins: []plugins.PostSchedule{}, - scorers: map[plugins.Scorer]int{}, - filters: []plugins.Filter{&sheddableRequestFilterV2{}}, - picker: &picker.MaxScorePicker{}, - } - for _, opt := range opts { - opt(config) - } - return config -} - type ConfigOption func(*SchedulerConfig) -func AddScorer(scorer plugins.Scorer, weight int) ConfigOption { - return func(cfg *SchedulerConfig) { - cfg.scorers[scorer] = weight - } -} - func AddPrefixPlugin(prefixConfig prefix.Config, weight int) ConfigOption { return func(cfg *SchedulerConfig) { prefixPlugin := prefix.New(prefixConfig) diff --git a/pkg/epp/scheduling/plugins/filter/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go deleted file mode 100644 index 86620aa9f..000000000 --- a/pkg/epp/scheduling/plugins/filter/filter.go +++ /dev/null @@ -1,278 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package filter - -import ( - "math" - "math/rand" - "time" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -type baseFilter struct { - name string - filter filterFunc -} - -func (f *baseFilter) Name() string { - if f == nil { - return "nil" - } - return f.name -} - -func (f *baseFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - loggerTrace := ctx.Logger.V(logutil.TRACE) - loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods)) - - return f.filter(ctx, pods) -} - -// DecisionTreeFilter applies current filterFunc, and then recursively applies next filters -// depending success or failure of the current filter. -// It can be used to construct a flow chart algorithm. -type DecisionTreeFilter struct { - Current plugins.Filter - // NextOnSuccess filter will be applied after successfully applying the current filter. - // The filtered results will be passed to the next filter. - NextOnSuccess plugins.Filter - // NextOnFailure filter will be applied if current filter fails. - // The original input will be passed to the next filter. - NextOnFailure plugins.Filter - // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the - // success or failure of the current filter. - // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. - // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of - // NextOnSuccessOrFailure, in the success and failure scenarios, respectively. - NextOnSuccessOrFailure plugins.Filter -} - -func (f *DecisionTreeFilter) Name() string { - if f == nil { - return "nil" - } - return f.Current.Name() -} - -func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - loggerTrace := ctx.Logger.V(logutil.TRACE) - filtered := f.Current.Filter(ctx, pods) - - next := f.NextOnSuccessOrFailure - if len(filtered) > 0 { - if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { - // No succeeding filters to run, return. - return filtered - } - if f.NextOnSuccess != nil { - next = f.NextOnSuccess - } - loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered)) - // On success, pass the filtered result to the next filter. - return next.Filter(ctx, filtered) - } else { - if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { - // No succeeding filters to run, return. - return filtered - } - if f.NextOnFailure != nil { - next = f.NextOnFailure - } - loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) - // On failure, pass the initial set of pods to the next filter. - return next.Filter(ctx, pods) - } -} - -// filterFunc filters a set of input pods to a subset. -type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod - -// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. -func toFilterFunc(pp podPredicate) filterFunc { - return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - filtered := []types.Pod{} - for _, pod := range pods { - pass := pp(ctx.Req, pod) - if pass { - filtered = append(filtered, pod) - } - } - - return filtered - } -} - -var LeastQueueFilter = &baseFilter{ - name: "least queuing", - filter: leastQueuingFilterFunc, -} - -// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range -// (max-min) by the number of pods, and finds the pods that fall into the first range. -// The intuition is that if there are multiple pods that share similar queue size in the low range, -// we should consider them all instead of the absolute minimum one. This worked better than picking -// the least one as it gives more choices for the next filter, which on aggregate gave better -// results. -// TODO: Compare this strategy with other strategies such as top K. -func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - min := math.MaxInt - max := 0 - filtered := []types.Pod{} - - for _, pod := range pods { - if pod.GetMetrics().WaitingQueueSize <= min { - min = pod.GetMetrics().WaitingQueueSize - } - if pod.GetMetrics().WaitingQueueSize >= max { - max = pod.GetMetrics().WaitingQueueSize - } - } - - for _, pod := range pods { - if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { - filtered = append(filtered, pod) - } - } - return filtered -} - -var LowQueueFilter = &baseFilter{ - name: "low queueing filter", - filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), -} - -var LeastKVCacheFilter = &baseFilter{ - name: "least KV cache percent", - filter: leastKVCacheFilterFunc, -} - -// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range -// (max-min) by the number of pods, and finds the pods that fall into the first range. -// The intuition is that if there are multiple pods that share similar KV cache in the low range, we -// should consider them all instead of the absolute minimum one. This worked better than picking the -// least one as it gives more choices for the next filter, which on aggregate gave better results. -// TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - min := math.MaxFloat64 - var max float64 = 0 - filtered := []types.Pod{} - - for _, pod := range pods { - if pod.GetMetrics().KVCacheUsagePercent <= min { - min = pod.GetMetrics().KVCacheUsagePercent - } - if pod.GetMetrics().KVCacheUsagePercent >= max { - max = pod.GetMetrics().KVCacheUsagePercent - } - } - - for _, pod := range pods { - if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { - filtered = append(filtered, pod) - } - } - return filtered -} - -var LoRAAffinityFilter = &baseFilter{ - name: "affinity LoRA", - filter: loRASoftAffinityFilterFunc, -} - -// loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods -// with existing LoRA model affinity while allowing for load balancing through randomization. -// -// The function works by: -// 1. Separating pods into two groups: those with target model affinity and those with available capacity -// 2. Using a probability threshold to sometimes select from non-affinity pods to enable load balancing -// 3. Falling back to whatever group has pods if one group is empty -// -// Parameters: -// - logger: Logger interface for diagnostic output -// - req: LLM request containing the resolved target model -// - pods: Slice of pod metrics to filter -// -// Returns: -// - Filtered slice of pod metrics based on affinity and availability -// - Error if any issues occur during filtering -func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - - // Pre-allocate slices with estimated capacity - filtered_affinity := make([]types.Pod, 0, len(pods)) - filtered_available := make([]types.Pod, 0, len(pods)) - - // Categorize pods based on affinity and availability - for _, pod := range pods { - _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] - _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] - - if active || waiting { - filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { - filtered_available = append(filtered_available, pod) - } - } - - // Use crypto/rand for better randomization in production environments - randSource := rand.NewSource(time.Now().UnixNano()) - randGen := rand.New(randSource) - - // If both groups have pods, use probability to select which group to return - if len(filtered_affinity) > 0 && len(filtered_available) > 0 { - if randGen.Float64() < config.Conf.LoraAffinityThreshold { - return filtered_affinity - } - return filtered_available - } - - // Return whichever group has pods - if len(filtered_affinity) > 0 { - return filtered_affinity - } - - return filtered_available -} - -var HasCapacityFilter = &baseFilter{ - name: "has capacity for sheddable requests", - filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), -} - -// podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *types.LLMRequest, pod types.Pod) bool - -func queueThresholdPredicate(queueThreshold int) podPredicate { - return func(req *types.LLMRequest, pod types.Pod) bool { - return pod.GetMetrics().WaitingQueueSize <= queueThreshold - } -} - -func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { - return func(req *types.LLMRequest, pod types.Pod) bool { - return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold - } -} - -func (pp podPredicate) and(another podPredicate) podPredicate { - return func(req *types.LLMRequest, pod types.Pod) bool { - return pp(req, pod) && another(req, pod) - } -} diff --git a/pkg/epp/scheduling/plugins/prefix/linked_list.go b/pkg/epp/scheduling/plugins/prefix/linked_list.go deleted file mode 100644 index de6b6a163..000000000 --- a/pkg/epp/scheduling/plugins/prefix/linked_list.go +++ /dev/null @@ -1,81 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package prefix - -type linkedList struct { - dummyHead *node // The head of the linked list (dummy node). - tail *node // The tail of the linked list. - size int // The size of the linked list (excluding dummy head). -} - -// newLinkedList initializes a new linked list with a dummy head node. -// Using a dummy head simplifies the implementation by eliminating nil checks. -func newLinkedList() *linkedList { - dummy := &node{} // Create dummy head node - return &linkedList{ - dummyHead: dummy, - tail: dummy, - size: 0, - } -} - -type node struct { - prev *node - next *node - server ServerID - hash BlockHash -} - -// add adds a node to the end of the linked list. -func (ll *linkedList) add(n *node) { - ll.size++ - - n.prev = ll.tail - ll.tail.next = n - ll.tail = n -} - -// delete removes a node from the linked list. -// Note the method assumes the input node exists in the list. -func (ll *linkedList) delete(n *node) { - ll.size-- - n.prev.next = n.next - - // If it's the tail node - if n.next == nil { - ll.tail = n.prev - } else { - n.next.prev = n.prev - } -} - -// moveToTail moves an existing node to the end of the linked list (most recent). -func (ll *linkedList) moveToTail(n *node) { - if n.next == nil { - // Already the tail, no need to move. - return - } - - n.prev.next = n.next - n.next.prev = n.prev - - // Move it to the tail position - n.prev = ll.tail - n.next = nil - ll.tail.next = n - ll.tail = n -} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index c7494b336..079687d0a 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -68,12 +68,6 @@ type Config struct { LRUIndexerCapacity int } -var DefaultConfig = Config{ - HashBlockSize: DefaultHashBlockSize, - MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, - LRUIndexerCapacity: DefaultLRUIndexerCapacity, -} - type Plugin struct { Config indexer Indexer diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache.go b/pkg/epp/scheduling/plugins/scorer/kvcache.go index 78ea1da5e..dbb6079dc 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache.go @@ -21,7 +21,7 @@ import ( ) const ( - DefaultKVCacheScorerWeight = 2 + DefaultKVCacheScorerWeight = 1 ) type KVCacheScorer struct{} From 97693574d272528d7ffb8678cbdb752e90c678b4 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Fri, 9 May 2025 20:45:27 +0000 Subject: [PATCH 8/9] Add TODO --- cmd/epp/main.go | 2 +- pkg/epp/scheduling/config.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 81d902dc3..e674f1c20 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -207,7 +207,7 @@ func run() error { } schedulerConfig := scheduling.NewSchedulerConfig( []plugins.PreSchedule{}, - []plugins.Filter{filter.NewSheddableRequestFilter()}, + []plugins.Filter{filter.NewSheddableCapacityFilter()}, scorers, picker.NewMaxScorePicker(), []plugins.PostSchedule{}, diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 63d150d2c..e321ca2bf 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -48,6 +48,8 @@ type SchedulerConfig struct { type ConfigOption func(*SchedulerConfig) +// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/813): Replace this +// with a more generic way to add plugins. func AddPrefixPlugin(prefixConfig prefix.Config, weight int) ConfigOption { return func(cfg *SchedulerConfig) { prefixPlugin := prefix.New(prefixConfig) From 47febd5795ae8d6f9737ac6ac8cdfc2757c03150 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Fri, 9 May 2025 21:13:13 +0000 Subject: [PATCH 9/9] make linter happy --- pkg/epp/scheduling/plugins/prefix/plugin.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go index 079687d0a..6d7f03c10 100644 --- a/pkg/epp/scheduling/plugins/prefix/plugin.go +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -130,7 +130,7 @@ func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState) total := len(state.PrefixHashes) - podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 { + podScoreFunc := func(pod types.Pod) float64 { if total == 0 { return 0 } @@ -140,7 +140,7 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types scores := make(map[types.Pod]float64, len(pods)) for _, pod := range pods { - scores[pod] = podScoreFunc(ctx, pod) + scores[pod] = podScoreFunc(pod) } return scores } @@ -191,8 +191,8 @@ func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlock for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { block := prompt[i : i+cacheBlockSize] prevBlockHash := res[len(res)-1] - toHash := append(block, toBytes(prevBlockHash)...) - res = append(res, BlockHash(xxhash.Sum64(toHash))) + block = append(block, toBytes(prevBlockHash)...) + res = append(res, BlockHash(xxhash.Sum64(block))) } return res }