Commit a5f5674

YaoJiayi and Akshat-Tripathi authored and committed
[Feature] Support KV cache offloading and disagg prefill with LMCache connector. (vllm-project#12953)
1 parent 77c117b commit a5f5674

File tree: 5 files changed (+310, -2 lines)
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates example usage of CPU offloading
with LMCache.

Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import time

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable the local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set the local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"

# This example script runs two requests with a shared prefix.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
    shared_prompt + "Hello, my name is",
]
second_prompt = [
    shared_prompt + "Tell me a very long story",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
          kv_transfer_config=ktc,
          max_model_len=8000,
          enable_chunked_prefill=False,
          gpu_memory_utilization=0.8)

outputs = llm.generate(first_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
print("First request done.")

time.sleep(1)

outputs = llm.generate(second_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
print("Second request done.")

# Clean up the LMCache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
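
One way to see the offload at work, sketched below, is to time the two `generate` calls and compare: with the shared prefix stored in LMCache's CPU backend, the second call should spend noticeably less time in prefill. This timing snippet is not part of the commit; it assumes `llm`, `first_prompt`, `second_prompt`, and `sampling_params` are the objects defined in the example above.

# Rough timing sketch (hypothetical addition, not part of the commit).
# Assumes `llm`, `first_prompt`, `second_prompt`, and `sampling_params`
# come from the example script above.
import time

start = time.perf_counter()
llm.generate(first_prompt, sampling_params)
first_elapsed = time.perf_counter() - start

start = time.perf_counter()
llm.generate(second_prompt, sampling_params)
second_elapsed = time.perf_counter() - start

# The second request shares the long prefix with the first, so its KV should
# be retrieved from the LMCache CPU backend instead of being recomputed.
print(f"first request:  {first_elapsed:.2f}s")
print(f"second request: {second_elapsed:.2f}s")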
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates example usage of disaggregated prefilling
with LMCache.
We launch 2 vLLM instances (GPU 0 for prefill and GPU 1 for decode)
and an additional LMCache server.
The KV cache is transferred in the following manner:
vLLM prefill node -> LMCache server -> vLLM decode node.

Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# The port on which to start the LMCache server
port = 8100
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable the local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set the local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for the LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and the LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"


def run_prefill(prefill_done, prompts):
    # We use GPU 0 for the prefill node.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
    )
    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print("Prefill node is finished.")
    prefill_done.set()

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_decode(prefill_done, prompts, timeout=1):
    # We use GPU 1 for the decode node.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
    )
    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    print("Waiting for prefill node to finish...")
    prefill_done.wait()
    time.sleep(timeout)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_lmcache_server(port):
    server_proc = subprocess.Popen([
        "python", "-m", "lmcache.experimental.server", "localhost",
        str(port)
    ])
    return server_proc


if __name__ == "__main__":

    prompts = [
        "Hello, how are you?" * 1000,
    ]

    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
    decode_process = Process(target=run_decode, args=(prefill_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    # Start the prefill node
    prefill_process.start()

    # Start the decode node
    decode_process.start()

    # Clean up the processes
    decode_process.join()
    prefill_process.terminate()
    lmcache_server_process.terminate()
    lmcache_server_process.wait()
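
A practical note on the example above: the LMCache server is launched with `subprocess.Popen` and the two vLLM processes start immediately afterwards, so on a slow machine the first KV store could conceivably race the server start-up. The helper below is a hedged sketch (a hypothetical addition, not part of the commit) that waits for the server's port to accept connections before the prefill and decode processes are started.

# Hypothetical helper (not in the commit): block until the LMCache server's
# TCP port accepts connections, so the vLLM nodes never race its start-up.
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return  # Server is accepting connections.
        except OSError:
            time.sleep(0.5)
    raise TimeoutError(f"LMCache server not reachable on {host}:{port}")


# Usage in the example's __main__ block, right after run_lmcache_server(port):
#   wait_for_port("localhost", port)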

vllm/distributed/kv_transfer/kv_connector/factory.py
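
The diff below registers the new connector with the KV connector factory by name, module path, and class name, which suggests the import happens lazily when the connector is actually created, keeping `lmcache` an optional dependency. A simplified, self-contained sketch of that lazy-registration pattern (illustrative only, not vLLM's actual factory code):

# Simplified sketch of a register-by-name connector factory: entries map a
# name to a module path and class name, and the module is imported only when
# the connector is created.
import importlib


class _ConnectorRegistry:
    _registry: dict[str, tuple[str, str]] = {}

    @classmethod
    def register(cls, name: str, module_path: str, class_name: str) -> None:
        cls._registry[name] = (module_path, class_name)

    @classmethod
    def create(cls, name: str, *args, **kwargs):
        module_path, class_name = cls._registry[name]
        # Deferred import: optional dependencies such as lmcache are only
        # required when the corresponding connector is actually used.
        connector_cls = getattr(importlib.import_module(module_path),
                                class_name)
        return connector_cls(*args, **kwargs)


_ConnectorRegistry.register(
    "LMCacheConnector",
    "vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
    "LMCacheConnector")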

Lines changed: 5 additions & 0 deletions
@@ -48,3 +48,8 @@ def create_connector(cls, rank: int, local_rank: int,
     "MooncakeConnector",
     "vllm.distributed.kv_transfer.kv_connector.simple_connector",
     "SimpleConnector")
+
+KVConnectorFactory.register_connector(
+    "LMCacheConnector",
+    "vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
+    "LMCacheConnector")
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference

The LMCacheConnector can (1) transfer KV caches between a prefill vLLM worker
(KV cache producer) and a decode vLLM worker (KV cache consumer) using LMCache;
and (2) offload and share KV caches.
"""

from typing import TYPE_CHECKING, List, Tuple, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class LMCacheConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        self.transfer_config = config.kv_transfer_config
        self.vllm_config = config

        from lmcache.experimental.cache_engine import LMCacheEngineBuilder
        from lmcache.integration.vllm.utils import ENGINE_NAME
        from lmcache.integration.vllm.vllm_adapter import (
            RetrieveStatus, StoreStatus, init_lmcache_engine,
            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(config.model_config,
                                          config.parallel_config,
                                          config.cache_config)
        self.lmcache_engine_name = ENGINE_NAME
        self.lmcache_engine_builder = LMCacheEngineBuilder

        self.model_config = config.model_config
        self.parallel_config = config.parallel_config
        self.cache_config = config.cache_config
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.lmcache_should_store = lmcache_should_store
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        hidden_or_intermediate_states = None

        # TODO (Jiayi): Need to support chunked prefill
        retrieve_status = self.retrieve_status.PREFILL

        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
            model_executable, model_input, self.cache_config, kv_caches,
            retrieve_status)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        num_reqs = 0
        seq_group_list = model_input.sampling_metadata.seq_groups
        assert seq_group_list is not None
        for seq_group in seq_group_list:
            seq_ids = seq_group.seq_ids
            for seq_id in seq_ids:
                num_reqs += 1

        # TODO (Jiayi): Only normal prefill is supported for now
        store_status = self.lmcache_should_store(model_input)
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            self.cache_config,
            model_executable,
            model_input,
            kv_caches,
            store_status,
        )

    def close(self):
        self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
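
The connector above implements a three-method contract: try to retrieve KV on receive (and possibly bypass the forward pass entirely via `bypass_model_exec`), store KV after prefill on send, and release resources on close. The toy below is a self-contained illustration of that control flow with stand-in types; it is not vLLM or LMCache code.

# Toy illustration of the connector contract (stand-in code, not vLLM's API):
# recv may signal that the forward pass can be bypassed; send stores the
# result; close releases the backing store.
from typing import Any, Dict, List, Tuple


class ToyConnector:

    def __init__(self) -> None:
        self._store: Dict[str, str] = {}

    def recv_kv_caches_and_hidden_states(
            self, model_input: str,
            kv_caches: List[Any]) -> Tuple[Any, bool, str]:
        cached = self._store.get(model_input)
        # bypass_model_exec is True when the cache already holds this input.
        return cached, cached is not None, model_input

    def send_kv_caches_and_hidden_states(self, model_input: str,
                                         kv_caches: List[Any],
                                         hidden: str) -> None:
        self._store[model_input] = hidden

    def close(self) -> None:
        self._store.clear()


connector = ToyConnector()
for prompt in ["shared prefix", "shared prefix"]:
    hidden, bypass, prompt = connector.recv_kv_caches_and_hidden_states(
        prompt, [])
    if not bypass:
        hidden = f"hidden({prompt})"  # stand-in for the real forward pass
    connector.send_kv_caches_and_hidden_states(prompt, [], hidden)
    print("bypassed forward pass" if bypass else "ran forward pass")
connector.close()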

vllm/distributed/parallel_state.py

Lines changed: 2 additions & 2 deletions
@@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
         return

     if all([
-            vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
-            is None
+            vllm_config.kv_transfer_config.is_kv_transfer_instance,
+            _KV_TRANSFER is None
     ]):
         _KV_TRANSFER = kv_transfer.KVTransferAgent(
             rank=get_world_group().rank,
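
This guard change matters for the offloading case: a `kv_both` instance, as used in the CPU-offload example earlier, is a KV-transfer instance but presumably never forms a separate KV parallel group, so the old `need_kv_parallel_group` check would not have created the `KVTransferAgent` for it. The snippet below is a hedged sketch, assuming the behaviour at this commit, of checking the property the new guard relies on.

# Sketch (assumption: behaviour at this commit) showing that the role used by
# the offloading example satisfies the new guard condition.
from vllm.config import KVTransferConfig

ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
# Expected to be True for this configuration, which is what
# ensure_kv_transfer_initialized now checks before creating the agent.
print(ktc.is_kv_transfer_instance)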
