
Commit 4162650

[Add] examples for disaggregated prefill
Signed-off-by: ApostaC <[email protected]>
1 parent 4730522 commit 4162650

File tree

4 files changed: +286 -0 lines changed
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL

enable_nixl: True
nixl_role: "receiver"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL

enable_nixl: True
nixl_role: "sender"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
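
The two config files above are identical except for nixl_role: the "receiver" config is presumably the decoder's and the "sender" config the prefiller's, given that the launch script below runs the prefiller as kv_producer and the decoder as kv_consumer. Both reserve a 1 GB NIXL transfer buffer on the GPU and leave the local CPU/disk caches disabled. As a quick sanity check, assuming the files are saved at the configs/ paths the script expects (configs/lmcache-prefiller-config.yaml and configs/lmcache-decoder-config.yaml), only the role line should differ:

diff configs/lmcache-prefiller-config.yaml configs/lmcache-decoder-config.yaml
# Expected: only the nixl_role line differs
# < nixl_role: "sender"
# ---
# > nixl_role: "receiver"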
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
#!/bin/bash

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <prefiller | decoder | proxy> [model]"
    exit 1
fi

if [[ $# -eq 1 ]]; then
    echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
    MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
    echo "Using model: $2"
    MODEL=$2
fi


if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 8100
    prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$prefill_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=0 \
        vllm serve $MODEL \
        --port 8100 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'

elif [[ $1 == "decoder" ]]; then
    # Decoder listens on port 8200
    decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$decode_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=1 \
        vllm serve $MODEL \
        --port 8200 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'

elif [[ $1 == "proxy" ]]; then
    # Proxy listens on port 9000
    python3 $SCRIPT_DIR/disagg_proxy_server.py \
        --host localhost \
        --port 9000 \
        --prefiller-host localhost \
        --prefiller-port 8100 \
        --decoder-host localhost \
        --decoder-port 8200

else
    echo "Invalid role: $1"
echo "Should be either prefill, decode, or proxy"
66+
exit 1
67+
fi
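
The script starts one role per invocation, so a full disaggregated-prefill setup needs three processes. A minimal launch sketch, assuming the script above is saved as disagg_example_nixl.sh next to its configs/ directory and disagg_proxy_server.py (the filename is an assumption; use whatever the file is actually called) and that GPUs 0 and 1 are free:

# Terminal 1: prefiller (kv_producer) on GPU 0, serving on port 8100
bash disagg_example_nixl.sh prefiller

# Terminal 2: decoder (kv_consumer) on GPU 1, serving on port 8200
bash disagg_example_nixl.sh decoder

# Terminal 3: proxy on port 9000, fronting the prefiller and decoder
bash disagg_example_nixl.sh proxy

# An optional second argument overrides the default model, e.g.
# bash disagg_example_nixl.sh prefiller meta-llama/Llama-3.1-8B-Instruct

Clients then talk only to the proxy on port 9000.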
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import time
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize clients
    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'

    app.state.prefill_client = httpx.AsyncClient(timeout=None,
                                                 base_url=prefiller_base_url)
    app.state.decode_client = httpx.AsyncClient(timeout=None,
                                                base_url=decoder_base_url)

    yield

    # Shutdown: Close clients
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


class StatsCalculator:

    def __init__(self):
        self._stats = []
        self._last_log_time = time.time()

    def add(self, value):
        self._stats.append(value)
        if time.time() - self._last_log_time > 5:
            self._log_stats()
            self._last_log_time = time.time()

    def _log_stats(self):
        # Print average, median, and 99th percentile
        np_arr = np.array(self._stats)
        output_str = f"\nNum requests: {len(self._stats)}" + \
            "\nPrefill node TTFT stats:" + \
            f"\n - Average (ms): {np.mean(np_arr)}" + \
            f"\n - Median (ms): {np.median(np_arr)}" + \
            f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
        print("===============================", output_str,
              "===============================")


stats_calculator = StatsCalculator()
counter = 0


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-host", type=str, default="localhost")
    parser.add_argument("--prefiller-port", type=int, default=8100)
    parser.add_argument("--decoder-host", type=str, default="localhost")
    parser.add_argument("--decoder-port", type=int, default=8200)
    args = parser.parse_args()
    return args


# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Send a request to a service using a persistent client.
    """
    req_data = req_data.copy()
    req_data['max_tokens'] = 1
    if 'max_completion_tokens' in req_data:
        req_data['max_completion_tokens'] = 1

    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    response = await client.post(endpoint, json=req_data, headers=headers)
    response.raise_for_status()
    return response


async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Asynchronously stream the response from a service using a persistent client.
    """
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


@app.post("/v1/completions")
async def handle_completions(request: Request):
    global counter, stats_calculator
    counter += 1

    st = time.time()
    try:
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(app.state.prefill_client, "/completions",
                                      req_data)

        et = time.time()
        # Convert seconds to milliseconds to match the "(ms)" stats labels
        stats_calculator.add((et - st) * 1000)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(app.state.decode_client,
                                                       "/completions",
                                                       req_data):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    global counter, stats_calculator
    counter += 1

    st = time.time()
    try:
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(app.state.prefill_client,
                                      "/chat/completions", req_data)

        et = time.time()
        # Convert seconds to milliseconds to match the "(ms)" stats labels
        stats_calculator.add((et - st) * 1000)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(app.state.decode_client,
                                                       "/chat/completions",
                                                       req_data):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - chat completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


if __name__ == '__main__':
    global global_args
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
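
The proxy exposes the usual OpenAI-compatible routes: each request is first forwarded to the prefiller with max_tokens forced to 1 (so essentially only the prefill runs and the KV cache is produced), and the real response is then streamed from the decoder. A minimal smoke test through the proxy, assuming it is listening on localhost:9000 as in the launch script and the default model is loaded:

curl -s http://localhost:9000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-3.1-8B-Instruct",
         "prompt": "The capital of France is",
         "max_tokens": 32}'

# The chat endpoint is proxied the same way:
curl -s http://localhost:9000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-3.1-8B-Instruct",
         "messages": [{"role": "user", "content": "Hello!"}],
         "max_tokens": 32}'

While requests are flowing, the proxy also prints prefill-node TTFT statistics to stdout roughly every five seconds.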
