Skip to content

Commit 510b581

Browse files
authored
1 parent 2f72666 commit 510b581

2 files changed

Lines changed: 31 additions & 12 deletions

File tree

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ const (
5555
// token is about 128KB in size, so we can cache 500K tokens. Using the default block size of 16
5656
// in vLLM, we will have 250K / 16 = 31.25K blocks.
5757
DefaultLRUCapacityPerServer = 31250
58+
// In P/D disaggregation mode, the prefill and decode are usually represented as two different scheduling profiles to pick
59+
// the prefill and decode endpoints. This constant defines the prefill profile name to ensure that the index is updated
60+
// for the prefill endpoint and not only for the primary endpoint that will initially handle the request.
61+
// This is hardcoded for now until we land on a canonical approach for plugins to identify prefill and decode endpoints
62+
// (See https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/2080)
63+
Experimental_DefaultPrefillProfile = "prefill"
5864

5965
PrefixCachePluginType = "prefix-cache-scorer"
6066
)
@@ -269,10 +275,10 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
269275
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
270276
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
271277
targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile
278+
servers := []Server{p.makeServer(targetPod)}
272279

273-
gpuBlocks := p.config.LRUCapacityPerServer
274-
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
275-
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
280+
if pr, exists := schedulingResult.ProfileResults[Experimental_DefaultPrefillProfile]; exists && len(pr.TargetPods) > 0 {
281+
servers = append(servers, p.makeServer(pr.TargetPods[0]))
276282
}
277283

278284
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
@@ -288,10 +294,9 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
288294
// WaitGroup is added to the Plugin struct to allow waiting in tests.
289295
p.wg.Add(1)
290296
go func() {
291-
p.indexer.Add(state.PrefixHashes, Server{
292-
ServerID(targetPod.GetPod().NamespacedName),
293-
gpuBlocks,
294-
})
297+
for _, s := range servers {
298+
p.indexer.Add(state.PrefixHashes, s)
299+
}
295300
p.wg.Done()
296301
}()
297302

@@ -302,6 +307,17 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
302307
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
303308
}
304309

310+
func (p *Plugin) makeServer(targetPod types.Pod) Server {
311+
gpuBlocks := p.config.LRUCapacityPerServer
312+
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
313+
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
314+
}
315+
return Server{
316+
ServerID(targetPod.GetPod().NamespacedName),
317+
gpuBlocks,
318+
}
319+
}
320+
305321
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
306322
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
307323
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func TestPrefixPluginCompletion(t *testing.T) {
4949

5050
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()}
5151
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()}
52-
pods := []types.Pod{pod1, pod2}
52+
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, MetricsState: backendmetrics.NewMetricsState()}
53+
pods := []types.Pod{pod1, pod2, pod3}
5354

5455
// First request.
5556
req1 := &types.LLMRequest{
@@ -72,11 +73,12 @@ func TestPrefixPluginCompletion(t *testing.T) {
7273
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
7374
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
7475

75-
// Simulate pod1 was picked.
76+
// Simulate pod1 was picked and pod3 was picked as a prefill node.
7677
schedulingResult := &types.SchedulingResult{
7778
PrimaryProfileName: "default",
7879
ProfileResults: map[string]*types.ProfileRunResult{
79-
"default": {TargetPods: []types.Pod{pod1}},
80+
"default": {TargetPods: []types.Pod{pod1}},
81+
Experimental_DefaultPrefillProfile: {TargetPods: []types.Pod{pod3}},
8082
},
8183
}
8284
plugin.PreRequest(context.Background(), req1, schedulingResult)
@@ -131,8 +133,9 @@ func TestPrefixPluginCompletion(t *testing.T) {
131133
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
132134
// Total hashes = 2 (the first one is for the prefix with model)
133135
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
134-
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
136+
assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix")
135137
assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match")
138+
assert.Equal(t, 0.5, scores[pod3], "score should be 0.5 - the model and the first prefix block match on the prefill node")
136139
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
137140

138141
schedulingResult = &types.SchedulingResult{
@@ -191,7 +194,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
191194
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
192195
// Total hashes = 3 (the first one is for the prefix with model)
193196
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
194-
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
197+
assert.Equal(t, 2, len(state.PrefixCacheServers), "pod1 and pod3 should have cached the aaaa prefix")
195198
assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match")
196199
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
197200

0 commit comments

Comments (0)