@@ -42,8 +42,8 @@ class SchedulingBudget:
4242 """
4343 token_budget : int
4444 max_num_seqs : int
45- _requeset_ids_num_batched_tokens : Set [int ] = field (default_factory = set )
46- _requeset_ids_num_curr_seqs : Set [int ] = field (default_factory = set )
45+ _requeset_ids_num_batched_tokens : Set [str ] = field (default_factory = set )
46+ _requeset_ids_num_curr_seqs : Set [str ] = field (default_factory = set )
4747 _num_batched_tokens : int = 0
4848 _num_curr_seqs : int = 0
4949
@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
133133 return (not self .scheduled_seq_groups and not self .blocks_to_swap_in
134134 and not self .blocks_to_swap_out and not self .blocks_to_copy )
135135
136- def _sort_by_lora_ids (self ) -> bool :
136+ def _sort_by_lora_ids (self ):
137137 self .scheduled_seq_groups = sorted (
138138 self .scheduled_seq_groups ,
139139 key = lambda g : (g .seq_group .lora_int_id , g .seq_group .request_id ))
@@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
337337 self .free_seq (seq )
338338
339339 def has_unfinished_seqs (self ) -> bool :
340- return self .waiting or self .running or self .swapped
340+ return len (self .waiting ) != 0 or len (self .running ) != 0 or len (
341+ self .swapped ) != 0
341342
342343 def get_num_unfinished_seq_groups (self ) -> int :
343344 return len (self .waiting ) + len (self .running ) + len (self .swapped )
@@ -404,7 +405,7 @@ def _schedule_running(
404405 budget .subtract_num_seqs (seq_group .request_id ,
405406 num_running_seqs )
406407 if curr_loras is not None and seq_group .lora_int_id > 0 :
407- curr_loras .pop (seq_group .lora_int_id )
408+ curr_loras .remove (seq_group .lora_int_id )
408409
409410 if running_queue :
410411 # Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
496497 now = time .time ()
497498 swapped_queue = policy .sort_by_priority (now , swapped_queue )
498499
499- leftover_swapped = deque ()
500+ leftover_swapped : Deque [ SequenceGroup ] = deque ()
500501 while swapped_queue :
501502 seq_group = swapped_queue [0 ]
502503
@@ -507,7 +508,9 @@ def _schedule_swapped(
507508 lora_int_id = 0
508509 if self .lora_enabled :
509510 lora_int_id = seq_group .lora_int_id
510- if (lora_int_id > 0 and lora_int_id not in curr_loras
511+ assert curr_loras is not None
512+ assert self .lora_config is not None
513+ if (lora_int_id > 0 and (lora_int_id not in curr_loras )
511514 and len (curr_loras ) >= self .lora_config .max_loras ):
512515 # We don't have a space for another LoRA, so
513516 # we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
593596 # Copy the queue so that the input queue is not modified.
594597 waiting_queue = deque ([s for s in waiting_queue ])
595598
596- leftover_waiting_sequences = deque ()
599+ leftover_waiting_sequences : Deque [ SequenceGroup ] = deque ()
597600 while self ._passed_delay (time .time ()) and waiting_queue :
598601 seq_group = waiting_queue [0 ]
599602
@@ -635,6 +638,8 @@ def _schedule_prefills(
635638 lora_int_id = 0
636639 if self .lora_enabled :
637640 lora_int_id = seq_group .lora_int_id
641+ assert curr_loras is not None
642+ assert self .lora_config is not None
638643 if (self .lora_enabled and lora_int_id > 0
639644 and lora_int_id not in curr_loras
640645 and len (curr_loras ) >= self .lora_config .max_loras ):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
780785 token_budget = self .scheduler_config .max_num_batched_tokens ,
781786 max_num_seqs = self .scheduler_config .max_num_seqs ,
782787 )
783- curr_loras = set ()
788+ curr_loras : Set [ int ] = set ()
784789
785790 remaining_waiting , prefills = (self .waiting ,
786791 SchedulerPrefillOutputs .create_empty ())
@@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
10871092
10881093 def _get_num_new_tokens (self , seq_group : SequenceGroup ,
10891094 status : SequenceStatus , enable_chunking : bool ,
1090- budget : SchedulingBudget ) -> Tuple [ int , bool ] :
1095+ budget : SchedulingBudget ) -> int :
10911096 """Get the next new tokens to compute for a given sequence group
10921097 that's in a given `status`.
10931098
0 commit comments