1717from cacheflow .worker .controller import DeviceID
1818from cacheflow .utils import Counter , get_gpu_memory , get_cpu_memory
1919
20+ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
2021app = FastAPI ()
2122
23+
2224class FastAPIFrontend :
2325 def __init__ (
2426 self ,
@@ -30,7 +32,7 @@ def __init__(
3032 dtype : str ,
3133 seed : int ,
3234 swap_space : int ,
33- max_batch_size : int ,
35+ max_num_batched_tokens : int ,
3436 num_nodes : int ,
3537 num_devices_per_node : int ,
3638 distributed_init_method : str ,
@@ -51,7 +53,7 @@ def __init__(
5153 dtype = dtype ,
5254 seed = seed ,
5355 swap_space = swap_space ,
54- max_batch_size = max_batch_size ,
56+ max_num_batched_tokens = max_num_batched_tokens ,
5557 num_nodes = num_nodes ,
5658 num_devices_per_node = num_devices_per_node ,
5759 distributed_init_method = distributed_init_method ,
@@ -68,12 +70,14 @@ async def server_step(self):
6870 self .is_server_running = True
6971 updated_seq_groups = await self .server .step .remote ()
7072 self .is_server_running = False
73+ # Notify the waiting coroutines that there are new outputs ready.
7174 for seq_group in updated_seq_groups :
7275 group_id = seq_group .group_id
7376 self .running_seq_groups [group_id ] = seq_group
7477 self .sequence_group_events [group_id ].set ()
7578
7679 async def generate (self , request_dict : Dict ):
80+ # Preprocess the request.
7781 prompt = request_dict ["prompt" ]
7882 sampling_params = SamplingParams .from_dict (request_dict )
7983 sampling_params .stop_token_ids .add (self .tokenizer .eos_token_id )
@@ -87,15 +91,27 @@ async def generate(self, request_dict: Dict):
8791 arrival_time = time .time ()
8892 group_id = next (self .seq_group_counter )
8993 seq_group = SequenceGroup (group_id , seqs , arrival_time )
94+ # Create an event to notify us that there is new output from the
95+ # cacheflow server.
9096 group_event = asyncio .Event ()
97+ self .running_seq_groups [group_id ] = seq_group
9198 self .sequence_group_events [group_id ] = group_event
99+ # Add the request into the cacheflow server's waiting queue.
92100 await self .server .add_sequence_groups .remote ([(seq_group , sampling_params )])
101+ # The cacheflow server does not have a background loop that keeps
102+ # processing incoming requests. Therefore, we need to keep kicking
103+ # the server to process the requests.
93104 while True :
105+ # Kick the server if the server is not running.
94106 if not self .is_server_running :
95107 await self .server_step ()
96- # Wait for new output. Add a 1s timeout to prevent dead lock.
97- await asyncio .wait_for (group_event .wait (), timeout = 1 )
108+ # Wait for new output. The group_event will be set in server_step
109+ # when there is new output available for the sequence group.
110+ # Added a timeout to prevent deadlock.
111+ await asyncio .wait_for (group_event .wait (), timeout = TIMEOUT_TO_PREVENT_DEADLOCK )
112+ # Reset the event to wait for the next output.
98113 group_event .clear ()
114+ # Decode and return the new outputs.
99115 seq_group = self .running_seq_groups [group_id ]
100116 all_outputs = []
101117 for seq in seq_group .seqs :
@@ -107,7 +123,16 @@ async def generate(self, request_dict: Dict):
107123 "error" : 0 ,
108124 }
109125 yield (json .dumps (ret ) + "\0 " ).encode ("utf-8" )
126+
127+ # Once finished, release the resources of the sequence group.
110128 if seq_group .is_finished ():
129+ del self .running_seq_groups [group_id ]
130+ del self .sequence_group_events [group_id ]
131+ # Kick the server if the server is not running. This ensures
132+ # that any remaining requests in the server's waiting queue
133+ # get executed.
134+ if not self .is_server_running :
135+ await self .server_step ()
111136 break
112137
113138
@@ -143,7 +168,7 @@ async def generate_stream(request: Request):
143168 dtype = args .dtype ,
144169 seed = args .seed ,
145170 swap_space = args .swap_space ,
146- max_batch_size = args .max_batch_size ,
171+ max_num_batched_tokens = args .max_num_batched_tokens ,
147172 num_nodes = num_nodes ,
148173 num_devices_per_node = num_devices_per_node ,
149174 distributed_init_method = distributed_init_method ,
0 commit comments