@@ -1202,41 +1202,98 @@ def _dummy_run(
         self,
         num_tokens: int,
     ) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
 
-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=self.max_num_tokens,
-                        dtype=self.model_config.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
 
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
-        return hidden_states
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
+
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+        return hidden_states[logit_indices]
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
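The new _dummy_run spreads num_tokens evenly across dummy requests before entering the LoRA context, gives the last request any remainder, and then gathers one hidden state per request at np.cumsum(num_scheduled_tokens) - 1. Below is a minimal, self-contained sketch of that distribution logic; max_num_seqs stands in for self.scheduler_config.max_num_seqs, and the helper name split_dummy_tokens is illustrative only, not part of the patch.

    import numpy as np

    def split_dummy_tokens(num_tokens: int, max_num_seqs: int) -> np.ndarray:
        # Cap the number of dummy requests at max_num_seqs; with fewer
        # tokens than that, use one request per token.
        num_reqs = max_num_seqs if num_tokens >= max_num_seqs else num_tokens
        per_req = [num_tokens // num_reqs] * num_reqs
        per_req[-1] += num_tokens % num_reqs  # remainder goes to the last request
        return np.array(per_req, dtype=np.int32)

    tokens = split_dummy_tokens(10, 4)
    print(tokens)                 # [2 2 2 4]
    print(np.cumsum(tokens) - 1)  # last-token index of each request: [1 3 5 9]
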
@@ -1332,60 +1389,14 @@ def profile_run(self) -> None:
         # Cache the dummy encoder outputs.
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        with self.maybe_profile_with_lora(self.lora_config,
-                                          num_scheduled_tokens):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.max_num_tokens)
-            if get_pp_group().is_last_rank:
-                hidden_states = hidden_states[logit_indices]
-                logits = self.model.compute_logits(hidden_states, None)
-                dummy_tensors = lambda v: torch.full(
-                    (num_reqs, ), v, device=self.device)
-                dummy_metadata = SamplingMetadata(
-                    temperature=dummy_tensors(0.5),
-                    all_greedy=False,
-                    all_random=False,
-                    top_p=dummy_tensors(0.9),
-                    top_k=dummy_tensors(logits.size(1) - 1),
-                    min_p=None,
-                    generators={},
-                    max_num_logprobs=None,
-                    no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits,
-                                                     dtype=torch.int64),
-                    frequency_penalties=dummy_tensors(0.1),
-                    presence_penalties=dummy_tensors(0.1),
-                    repetition_penalties=dummy_tensors(0.1),
-                    output_token_ids=[[] for _ in range(num_reqs)],
-                    min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)],
-                    allowed_token_ids_mask=None,
-                    bad_words_token_ids={},
-                )
-                sampler_output = self.model.sample(
-                    logits=logits, sampling_metadata=dummy_metadata)
-            else:
-                logits = None
-                sampler_output = None
-                dummy_metadata = None
-            torch.cuda.synchronize()
-            del hidden_states, logits, sampler_output, dummy_metadata
-            self.encoder_cache.clear()
+        hidden_states = self._dummy_run(self.max_num_tokens)
+        if get_pp_group().is_last_rank:
+            sampler_output = self._dummy_sampler_run(hidden_states)
+        else:
+            sampler_output = None
+        torch.cuda.synchronize()
+        del hidden_states, sampler_output
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
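After this refactor, profile_run only composes the two helpers: _dummy_run returns one hidden state per dummy request (already gathered at each request's last scheduled token), and _dummy_sampler_run projects those rows to logits and samples one token per request. The following toy sketch shows that shape flow with made-up sizes; the plain matmul and multinomial call stand in for compute_logits and the vLLM sampler and are not the actual implementation.

    import torch

    num_reqs, hidden_size, vocab_size = 4, 8, 32
    hidden_states = torch.randn(num_reqs, hidden_size)  # stand-in for _dummy_run output
    lm_head = torch.randn(vocab_size, hidden_size)       # stand-in for the LM head

    logits = hidden_states @ lm_head.T                   # (num_reqs, vocab_size)
    probs = torch.softmax(logits / 0.5, dim=-1)          # temperature 0.5, as in the dummy metadata
    next_tokens = torch.multinomial(probs, num_samples=1)
    print(next_tokens.shape)                             # torch.Size([4, 1]): one token per dummy request
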