Skip to content

Commit 1097c78

Browse files
authored
Fix potential bugs in FastAPI frontend and add comments (vllm-project#28)
1 parent 6376304 commit 1097c78

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

cacheflow/http_frontend/fastapi_frontend.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from cacheflow.worker.controller import DeviceID
1818
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
1919

20+
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
2021
app = FastAPI()
2122

23+
2224
class FastAPIFrontend:
2325
def __init__(
2426
self,
@@ -30,7 +32,7 @@ def __init__(
3032
dtype: str,
3133
seed: int,
3234
swap_space: int,
33-
max_batch_size: int,
35+
max_num_batched_tokens: int,
3436
num_nodes: int,
3537
num_devices_per_node: int,
3638
distributed_init_method: str,
@@ -51,7 +53,7 @@ def __init__(
5153
dtype=dtype,
5254
seed=seed,
5355
swap_space=swap_space,
54-
max_batch_size=max_batch_size,
56+
max_num_batched_tokens=max_num_batched_tokens,
5557
num_nodes=num_nodes,
5658
num_devices_per_node=num_devices_per_node,
5759
distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ async def server_step(self):
6870
self.is_server_running = True
6971
updated_seq_groups = await self.server.step.remote()
7072
self.is_server_running = False
73+
# Notify the waiting coroutines that there are new outputs ready.
7174
for seq_group in updated_seq_groups:
7275
group_id = seq_group.group_id
7376
self.running_seq_groups[group_id] = seq_group
7477
self.sequence_group_events[group_id].set()
7578

7679
async def generate(self, request_dict: Dict):
80+
# Preprocess the request.
7781
prompt = request_dict["prompt"]
7882
sampling_params = SamplingParams.from_dict(request_dict)
7983
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
@@ -87,15 +91,27 @@ async def generate(self, request_dict: Dict):
8791
arrival_time = time.time()
8892
group_id = next(self.seq_group_counter)
8993
seq_group = SequenceGroup(group_id, seqs, arrival_time)
94+
# Create an event to notify us that there is new output from the
95+
# cacheflow server.
9096
group_event = asyncio.Event()
97+
self.running_seq_groups[group_id] = seq_group
9198
self.sequence_group_events[group_id] = group_event
99+
# Add the request into the cacheflow server's waiting queue.
92100
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
101+
# The cacheflow server does not have a background loop that keeps
102+
# processing incoming requests. Therefore, we need to keep kicking
103+
# the server to process the requests.
93104
while True:
105+
# Kick the server if the server is not running.
94106
if not self.is_server_running:
95107
await self.server_step()
96-
# Wait for new output. Add a 1s timeout to prevent dead lock.
97-
await asyncio.wait_for(group_event.wait(), timeout=1)
108+
# Wait for new output. The group_event will be set in server_step
109+
# when there is new output available for the sequence group.
110+
# Added a timeout to prevent deadlock.
111+
await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
112+
# Reset the event to wait for the next output.
98113
group_event.clear()
114+
# Decode and return new outputs
99115
seq_group = self.running_seq_groups[group_id]
100116
all_outputs = []
101117
for seq in seq_group.seqs:
@@ -107,7 +123,16 @@ async def generate(self, request_dict: Dict):
107123
"error": 0,
108124
}
109125
yield (json.dumps(ret) + "\0").encode("utf-8")
126+
127+
# Once finished, release the resources of the sequence group.
110128
if seq_group.is_finished():
129+
del self.running_seq_groups[group_id]
130+
del self.sequence_group_events[group_id]
131+
# Kick the server if the server is not running. This is to
132+
# prevent that there are still requests in server's waiting
133+
# queue to be executed.
134+
if not self.is_server_running:
135+
await self.server_step()
111136
break
112137

113138

@@ -143,7 +168,7 @@ async def generate_stream(request: Request):
143168
dtype=args.dtype,
144169
seed=args.seed,
145170
swap_space=args.swap_space,
146-
max_batch_size=args.max_batch_size,
171+
max_num_batched_tokens=args.max_num_batched_tokens,
147172
num_nodes=num_nodes,
148173
num_devices_per_node=num_devices_per_node,
149174
distributed_init_method=distributed_init_method,

0 commit comments

Comments (0)