@@ -302,8 +302,7 @@ def dummy_run(self,
                 break
 
     def generate_token_ids(self,
-                           sampled_token_ids: Union[torch.Tensor,
-                                                    list[np.ndarray]],
+                           sampled_token_ids: torch.Tensor | list[list[int]],
                            sampling_metadata: SamplingMetadata = None,
                            scheduler_output: SchedulerOutput = None,
                            spec_decode_metadata: SpecDecodeMetadata = None,
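Reviewer note: the annotation switch means `sampled_token_ids` now arrives either as a padded `torch.Tensor` or as plain per-request `list[list[int]]` rows rather than `np.ndarray`s. A minimal sketch of the two accepted shapes (values and row count invented for illustration, not taken from this PR):

    import torch

    # Padded-batch path: one fixed-width row per request.
    as_tensor = torch.tensor([[11, 12], [21, 0]], dtype=torch.int32)

    # Unpadded path: one plain Python list per request; a row may be
    # empty if that request sampled nothing this step.
    as_lists: list[list[int]] = [[11, 12], []]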
@@ -380,7 +379,6 @@ def generate_token_ids(self,
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            assert isinstance(sampled_token_ids, list)
             # NOTE: Currently, MTP-fullgraph is incompatible with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
@@ -439,7 +437,7 @@ def _get_attn_metadata(self, attn_metadata):
     def _prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        sampled_token_ids: list[np.ndarray],
+        sampled_token_ids: list[list[int]],
         num_draft_tokens: list[int],
     ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
@@ -897,7 +895,7 @@ def _prepare_input_kernel(self, out_ptr: torch.Tensor,
 
     def prepare_next_token_ids_cpu(
         self,
-        sampled_token_ids: list[np.ndarray],
+        sampled_token_ids: list[list[int]],
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         num_scheduled_tokens: dict[str, int],
@@ -912,7 +910,7 @@ def prepare_next_token_ids_cpu(
         req_ids = gpu_input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids.shape[0] > 0:
+            if token_ids:
                 # Common case.
                 next_token_id = token_ids[-1]
             else:
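This emptiness check is the crux of the type change: `if token_ids:` is the idiomatic test for a non-empty `list[int]`, whereas the removed `shape[0] > 0` only made sense for `np.ndarray` rows. Plain truthiness would not have worked on the old arrays; a quick illustration (values invented):

    import numpy as np

    row_list: list[int] = [11, 12]
    row_array = np.array([11, 12])

    assert row_list                 # non-empty list is truthy
    assert row_array.shape[0] > 0   # arrays need an explicit length check
    # bool(row_array) raises "ValueError: The truth value of an array
    # with more than one element is ambiguous."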
@@ -923,7 +921,7 @@ def prepare_next_token_ids_cpu(
                 seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
                     req_id]
                 next_token_id = req_state.get_token_id(seq_len)
-            next_token_ids.append(next_token_id.item())
+            next_token_ids.append(next_token_id)
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.input_ids.device)
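Dropping `.item()` follows from the same change: indexing a `list[int]` already yields a Python `int`, and this hunk appears to assume `req_state.get_token_id` likewise returns a plain `int`, so the collected list feeds `torch.tensor` directly. A small sanity sketch (values invented):

    import torch

    next_token_ids = [11, 21, 31]  # plain ints; no .item() needed
    t = torch.tensor(next_token_ids, dtype=torch.int32)
    assert t.tolist() == next_token_ids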