diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index b23259f8c..5c65332b8 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -318,7 +318,7 @@ func loadPrefixCacheConfig() prefix.Config { return prefix.Config{ HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), - LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), + LRUCapacityPerServer: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger), } } diff --git a/go.mod b/go.mod index 192773cc6..0c02daccc 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/prometheus/client_golang v1.22.0 diff --git a/go.sum b/go.sum index 7733d5555..2d45d351f 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go index 4859357d8..716c9f265 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go @@ -20,154 +20,130 @@ import ( "context" "sync" "time" - "unsafe" - - "container/list" + lru "github.com/hashicorp/golang-lru/v2" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func newIndexer(maxCacheSize int) *indexer { - t := &indexer{ - maxCacheSize: maxCacheSize, - table: make(map[BlockHash]map[ServerID]*list.Element), - ll: list.New(), - } - go t.ReportCacheSize(time.Second) - return t -} - // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that -// prefix cached . +// prefix cached. type indexer struct { - mu sync.RWMutex - maxCacheSize int - table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server - ll *list.List // LinkedList to keep track of the order of entries + mu sync.RWMutex + hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached + podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache + maxLRUSize int } -// value is the value stored in the linked list. -type value struct { - server ServerID - hash BlockHash -} - -// Get returns the set of servers that have the given prefix hash cached. -func (i *indexer) Get(hash BlockHash) map[ServerID]bool { - i.mu.RLock() - defer i.mu.RUnlock() - res := map[ServerID]bool{} - for server := range i.table[hash] { - res[server] = true +// newIndexer initializes an indexer with size limits and starts cache size reporting. +func newIndexer(maxLRUSize int) *indexer { + ix := &indexer{ + hashToPods: make(map[BlockHash]podSet), + podToLRU: make(map[ServerID]*lru.Cache[BlockHash, struct{}]), + maxLRUSize: maxLRUSize, } - return res + + go ix.ReportLRUSize(time.Second) + return ix } -// Add adds a list of prefix hashes of a single request to the server the request was sent to. -// The intuition is that this server is likely to have the prefix cached, so next time a request -// sharing the longest prefix should be sent to the same server to take advantage of the cache hit. -func (i *indexer) Add(hashes []BlockHash, server ServerID) { +// Add adds a list of prefix hashes to the cache, tied to the server. +func (i *indexer) Add(hashes []BlockHash, pod ServerID) { i.mu.Lock() - defer i.mu.Unlock() - for _, hash := range hashes { - i.add(hash, server) + // Check if the LRU pod exist + lruForPod, exists := i.podToLRU[pod] + if !exists { + newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod)) + i.podToLRU[pod] = newLRU + lruForPod = newLRU } -} -func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) { - servers, ok := i.table[hash] - if !ok { - return nil, false + i.mu.Unlock() + + // Add to LRU (may evict) + for _, hash := range hashes { + lruForPod.Add(hash, struct{}{}) } - e, ok := servers[server] - return e, ok -} -func (i *indexer) add(hash BlockHash, server ServerID) { - e, exists := i.check(hash, server) - if exists { - i.ll.MoveToBack(e) - } else { - i.create(hash, server) + // Update hashToPods once under lock + i.mu.Lock() + for _, hash := range hashes { + pods := i.hashToPods[hash] + if pods == nil { + pods = make(podSet) + } + pods[pod] = struct{}{} + i.hashToPods[hash] = pods } + + i.mu.Unlock() } -func (i *indexer) create(hash BlockHash, server ServerID) { - for i.ll.Len() >= i.maxCacheSize { - // Evict the least recently used entry if we've exceeded the max cache size - i.evict() - } +// Get returns a set of servers that have the given prefix hash cached. +func (i *indexer) Get(hash BlockHash) podSet { + i.mu.RLock() + defer i.mu.RUnlock() - if _, ok := i.table[hash]; !ok { - i.table[hash] = make(map[ServerID]*list.Element) - } - v := &value{ - server: server, - hash: hash, + res := podSet{} + pods, ok := i.hashToPods[hash] + if !ok { + return res } - e := i.ll.PushBack(v) - i.table[hash][server] = e + + return pods } -// evict removes the least recently used entry from the cache -func (i *indexer) evict() { - oldestNode := i.ll.Front() - if oldestNode == nil { - return +// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction. +func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) { + return func(hash BlockHash, _ struct{}) { + i.mu.Lock() + defer i.mu.Unlock() + // Remove the pod from the hash→pods map + if podSet, ok := i.hashToPods[hash]; ok { + delete(podSet, pod) + if len(podSet) == 0 { + delete(i.hashToPods, hash) + } + } } - i.ll.Remove(oldestNode) - - v := oldestNode.Value.(*value) - hash := v.hash - server := v.server - // Remove from the hash map - serverMap := i.table[hash] - delete(serverMap, server) - - // If this was the last server for this hash, remove the hash entry entirely - if len(serverMap) == 0 { - delete(i.table, hash) - } - - log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server) } -// ReportCacheSize starts a goroutine that periodically reports the cache size metric -func (i *indexer) ReportCacheSize(interval time.Duration) { +// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric. +func (i *indexer) ReportLRUSize(interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { i.mu.RLock() - metrics.RecordPrefixCacheSize(int64(i.ll.Len())) - log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000) + totalEntries := 0 + maxPodEntries := 0 + maxPodName := ServerID{} + + for pod, lruCache := range i.podToLRU { + size := lruCache.Len() + totalEntries += size + if size > maxPodEntries { + maxPodEntries = size + maxPodName = pod + } + } + + numPods := len(i.podToLRU) + avg := 0.0 + if numPods > 0 { + avg = float64(totalEntries) / float64(numPods) + } + + metrics.RecordPrefixCacheSize(int64(totalEntries)) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state", + "total entries", totalEntries, + "# pods", numPods, + "avg entries per pod", avg, + "pod with max cache", maxPodName, + "max pod size", maxPodEntries, + "global max LRU cache capacity per pod", i.maxLRUSize, + ) + i.mu.RUnlock() } } - -// estimateEntrySize estimates the memory size of a cache entry in bytes. -func (i *indexer) estimateEntrySize() int { - size := 0 - - // Estimate the size of a node in the linked list. - // First get the size of the node struct via unsafe.Sizeof. - // The prev and next pointers are 8 bytes each on a 64-bit system. - // The BlockHash is a uint64, which is 8 bytes. - // The ServerID is a NamespacedName, which contains two strings (Name and Namespace). - // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length). - // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes. - size += int(unsafe.Sizeof(value{})) - // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName). - size += 2 * 63 - - // Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored. - size += 8 // Size of the BlockHash (uint64). - size += 2 * 16 // Size of the ServerID string headers (NamespacedName). - size += 2 * 63 // Size of the Name and Namespace strings in ServerID. - size += 8 // Size of the pointer to the node in the hash map. - - // Based on the above estimates, the estimated size of an entry is: - // (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes. - return size -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go index 596625d10..240985033 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go @@ -22,24 +22,23 @@ import ( ) func TestIndexer_AddAndGet(t *testing.T) { - cache := newIndexer(2) + i := newIndexer(2) hash1 := BlockHash(1) server := ServerID{Namespace: "default", Name: "server1"} - // Add an entry to the cache - cache.Add([]BlockHash{hash1}, server) + i.Add([]BlockHash{hash1}, server) // Retrieve the entry - assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry") - servers := cache.Get(hash1) + assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry") + servers := i.Get(hash1) assert.Contains(t, servers, server, "Cache should contain the added server") // Add another entry to the cache, the cache size should be incremented to 2. - cache.Add([]BlockHash{BlockHash(2)}, server) - assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry") + i.Add([]BlockHash{BlockHash(2)}, server) + assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry") // Add another entry to the cache, which should evict the first one due to max size. - cache.Add([]BlockHash{BlockHash(3)}, server) - assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry") + i.Add([]BlockHash{BlockHash(3)}, server) + assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 01903bce3..0d40746f3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -32,11 +32,6 @@ import ( const ( DefaultScorerWeight = 1 - // Attempt to return DefaultNumServersToMatch servers with their longest prefix match length. - // Why not just return the server with longest prefix match? - // It may not be the optimal choice, e.g., it may have a high queue depth. - // We optimistically search more than one to give more candidates for the scheduler to choose. - DefaultNumServersToMatch = 2 // vLLM default token block size is 16, and a good guess of average characters per token is 4. DefaultHashBlockSize = 64 // The maximum number of blocks to match. Two long requests with the same prefix up to this @@ -44,20 +39,16 @@ const ( // This parameter provides a trade-off between cache size, prefix matching speed and matching // accuracy. Use a small value if most requests are short to reduce cache size and speed up the // matching process. Use a large value if most requests are long to increase the matching accuracy. - DefaultMaxPrefixBlocks = 128 - // The indexer is an approximation to the actual prefix cache state on the model servers. + DefaultMaxPrefixBlocks = 256 + // The indexer is an approximation to the actual prefix LRU cache state on the model servers per server (pod). // A small capacity ensures a high accuracy of cache hit on the model server, but it will // increase the chance of false negatives. A high capacity does the opposite. // To properly size this, consider the sum of the total number of cache entries on all model - // servers. Consider the llama3 8B model on 3 H100 80GB GPUs. The size of the model weight is - // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each - // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 - // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 3 = 46.8K blocks, or - // roughly 50K. - // How much memory space does it require to hold the 50K block hashes? - // According to the estimates in indexer.estimateEntrySize(), the size of each entry is - // approximately 348 bytes. So in total we have 50K * 348 = 17.4MB. - DefaultLRUIndexerCapacity = 50000 + // servers. Consider the llama3 8B model on a H100 80GB GPUs. The size of the model weight is + // about 16GB. The remaining HBM used for caching prefixes is 64GB. Each + // token is about 128KB in size, so we can cache 500K tokens. Using the default block size of 16 + // in vLLM, we will have 250K / 16 = 31.25K blocks. + DefaultLRUCapacityPerServer = 31250 ) type Config struct { @@ -67,8 +58,8 @@ type Config struct { // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will // be ignored. MaxPrefixBlocksToMatch int - // Max (approximate) size of the LRU indexer in number of entries. - LRUIndexerCapacity int + // Max capacity size of the LRU indexer in number of entries per server (pod). + LRUCapacityPerServer int } type Plugin struct { @@ -76,8 +67,11 @@ type Plugin struct { indexer Indexer } +// podSet holds an pods servers that may have a specific prefix hash. +type podSet map[ServerID]struct{} + type Indexer interface { - Get(hash BlockHash) map[ServerID]bool + Get(hash BlockHash) podSet Add(hashes []BlockHash, server ServerID) } @@ -121,9 +115,18 @@ var _ framework.PostCycle = &Plugin{} // New initializes a new prefix Plugin and returns its pointer. func New(config Config) *Plugin { + capacity := config.LRUCapacityPerServer + if capacity <= 0 { + capacity = DefaultLRUCapacityPerServer + log.FromContext(context.TODO()).V(logutil.DEFAULT).Info( + "LRUCapacityPerServer is not positive, using default value", + "defaultCapacity", DefaultLRUCapacityPerServer, + ) + } + m := &Plugin{ Config: config, - indexer: newIndexer(config.LRUIndexerCapacity), + indexer: newIndexer(capacity), } return m } @@ -138,14 +141,11 @@ func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleStat loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) - numServers := DefaultNumServersToMatch - if numServers > len(pods) { - numServers = len(pods) - } state := &schedulingContextState{ PrefixHashes: hashes, - PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, numServers), + PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), } + cycleState.Write(types.StateKey(m.Name()), state) loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) // calculate the scores of pods @@ -174,29 +174,31 @@ func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, re log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state") return } + m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. -func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash, numServers int) map[ServerID]int { +func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) res := make(map[ServerID]int) // Use a greedy strategy to search from the longest prefix. // NOTE: It's possible to further optimize this with a binary search. - for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- { + for i := 0; i < len(hashes); i++ { hash := hashes[i] cachedServers := m.indexer.Get(hash) - if len(cachedServers) > 0 { + if len(cachedServers) == 0 { + break + } else { loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i) for server := range cachedServers { // Update servers with their longest prefix match. - // If we already found this server with longer prefix match, don't update it. - if _, ok := res[server]; !ok { - res[server] = i + 1 - } + res[server]++ + } } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index de6c68bbd..db1feacf4 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -18,6 +18,10 @@ package prefix import ( "context" + "fmt" + "math" + "math/rand" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -27,10 +31,11 @@ import ( ) func TestPrefixPlugin(t *testing.T) { + config := Config{ HashBlockSize: 4, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, - LRUIndexerCapacity: DefaultLRUIndexerCapacity, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, } plugin := New(config) @@ -136,3 +141,61 @@ func TestPrefixPlugin(t *testing.T) { plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1}) } + +// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. +func BenchmarkPrefixPluginStress(b *testing.B) { + blockSize := 4 + maxPrefixBlocks := 50000 + config := Config{ + HashBlockSize: blockSize, + MaxPrefixBlocksToMatch: maxPrefixBlocks, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, + } + + plugin := New(config) + types.NewCycleState() + var promptLen []int + for i := 1; i <= 1024; i++ { + promptLen = append(promptLen, i) + } + promptLen = append(promptLen, 2048, 4096, 8192, 10000, 20000, 50000) + + for _, i := range promptLen { + // Generate increasing-length random prompts + prompt := randomPrompt(4 + i) + pod := &types.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: k8stypes.NamespacedName{ + Name: fmt.Sprintf("random-pod-%d", i), + }, + }, + } + + pods := []types.Pod{pod} + req := &types.LLMRequest{ + TargetModel: "model-stress", + Prompt: prompt, + } + + // First cycle: simulate scheduling and insert prefix info into the cache + cycleState := types.NewCycleState() + plugin.Score(context.Background(), req, cycleState, pods) + plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod}) + + // Second cycle: validate internal state + state, err := plugin.getPrefixState(cycleState) + assert.NoError(b, err) + expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. + assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") + } +} + +// randomPrompt generates a pseudo-random string of length n using lowercase letters. +func randomPrompt(n int) string { + runes := []rune("abcdefghijklmnopqrstuvwxyz") + var sb strings.Builder + for i := 0; i < n; i++ { + sb.WriteRune(runes[rand.Intn(len(runes))]) + } + return sb.String() +} diff --git a/site-src/guides/epp-configuration/prefix-aware.md b/site-src/guides/epp-configuration/prefix-aware.md index a0ad8b51a..43e2ef064 100644 --- a/site-src/guides/epp-configuration/prefix-aware.md +++ b/site-src/guides/epp-configuration/prefix-aware.md @@ -4,7 +4,7 @@ The [prefix cache plugin](https://github.com/kubernetes-sigs/gateway-api-inferen takes advantage of the prefix caching (e.g., [vllm APC](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching.html)) feature of model servers, and optimizes request scheduling by placing requests sharing the longest prefixes to the same server as much as possible, while balancing the server load by considering kv-cache -and queue depth. +and queue depth. ## Enable the prefix cache plugin @@ -32,16 +32,18 @@ extremely long inputs. 128 (or 128*64=8192 characters, or roughly 2048 tokens). This is useful to tradeoff prefix match accuracy for performance. -* `PREFIX_CACHE_LRU_CAPACITY`: Maximum capacity the prefix LRU indexer in number of block hashes. Below +* `PREFIX_CACHE_LRU_CAPACITY_PER_SERVER`: Maximum capacity the prefix LRU cache in number of block hashes per server (pod). Below shows a detailed analysis on how to estimate this. + + The prefix cache plugin estimates the prefix cache indexes in model server HBMs. In the perfect scenario, EPP has the exact same prefix cache entries per model server as their HBM cache entries. If the EPP cache is smaller than HBM cache, a positive EPP cache match is more accurate, but there are more false cache misses. If the EPP cache is larger than the HBM cache, then there are more false cache hits. Therefore **the EPP prefix cache indexer size should be as close as possible to the HBM cache size.** - NOTE: EPP builds prefix cache based on characters, while model server maintains prefix cache entries + NOTE: EPP builds prefix cache based on characters, while model server maintains prefix cache entries in tokens, a conversion between character <-> token is needed. Below are the formulas to estimate the EPP prefix indexer size: @@ -63,8 +65,7 @@ shows a detailed analysis on how to estimate this. max_kv_tokens_per_server = (80GB - 16GB) / 128KB = 500,000 # assume avg_chars_per_token = 4, prefix_indexer_hash_block_size = 64 (default) # each entry is about 358KB, so the memory footrpint is abut 11 MB per server - lru_indexer_capacity_per_server = 500,000*4/64 = 31250 - lru_indexer_capacity_total = 3 * 31250 = 93750 + lru_indexer_capacity_per_server = 500,000*4/64 = 31250 ``` See the [Use Helm section](#helm) to install an inferencepool with the environment variables. @@ -83,7 +84,7 @@ $ helm install triton-llama3-8b-instruct \ --set provider.name=[none|gke] \ --set inferenceExtension.env.EXPERIMENTAL_USE_SCHEDULER_V2=true \ --set inferenceExtension.env.ENABLE_PREFIX_CACHE_SCHEDULING=true \ - --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY=93750 \ + --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY_PER_SERVER=31250 \ --set inferenceExtension.env.PREFIX_CACHE_MAX_PREFIX_BLOCKS=1024 \ oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/inferencepool --version v0 ```