@@ -1179,43 +1179,6 @@ def _dummy_run(
         )
         return hidden_states

-    @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
-        logits = self.model.compute_logits(hidden_states, None)
-        num_reqs = logits.size(0)
-
-        dummy_tensors = lambda v: torch.full(
-            (num_reqs, ), v, device=self.device)
-
-        dummy_metadata = SamplingMetadata(
-            temperature=dummy_tensors(0.5),
-            all_greedy=False,
-            all_random=False,
-            spec_token_ids=None,
-            top_p=dummy_tensors(0.9),
-            top_k=dummy_tensors(logits.size(1) - 1),
-            min_p=None,
-            generators={},
-            max_num_logprobs=None,
-            no_penalties=True,
-            prompt_token_ids=None,
-            frequency_penalties=dummy_tensors(0.1),
-            presence_penalties=dummy_tensors(0.1),
-            repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
-            min_tokens={},
-            logit_bias=[None for _ in range(num_reqs)],
-            allowed_token_ids_mask=None,
-        )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
-        return sampler_output
-
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
@@ -1343,11 +1306,38 @@ def profile_run(self) -> None:
                                         dummy_kv_caches)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            logits = self.model.compute_logits(hidden_states, None)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+            )
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
+            logits = None
             sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, sampler_output
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()

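Taken together, the two hunks inline the old `_dummy_sampler_run` helper into `profile_run`. The one substantive change is `prompt_token_ids`: the removed helper passed `None`, while the inlined version materializes a full `torch.ones_like(logits, dtype=torch.int64)` tensor, presumably so the penalty path allocates its real working buffers while memory is being profiled. The sketch below mimics that pattern in plain PyTorch; it is a hypothetical simplification for illustration (`profile_dummy_sample`, the penalty math, and the default sizes are all invented here), not vLLM's actual sampler.

# Hypothetical sketch, assuming plain PyTorch and a CUDA device; this is
# not vLLM's SamplingMetadata/sampler API, only the profiling pattern.
import gc

import torch


def profile_dummy_sample(num_reqs: int = 8,
                         vocab_size: int = 32000,
                         device: str = "cuda") -> None:
    # Stand-in for compute_logits(hidden_states, None).
    logits = torch.randn(num_reqs, vocab_size, device=device)

    # Dummy per-request penalty values, mirroring dummy_tensors(0.1).
    repetition_penalties = torch.full((num_reqs, ), 0.1, device=device)

    # A fully materialized prompt-token tensor (rather than None) makes
    # the penalty path allocate its real working buffers during profiling.
    prompt_token_ids = torch.ones_like(logits, dtype=torch.int64)

    # Grossly simplified penalty-plus-sampling step; a real sampler treats
    # positive and negative logits differently and adds top-k/top-p masking.
    penalized = logits / repetition_penalties.unsqueeze(1)
    probs = torch.softmax(penalized, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)

    # Mirror the cleanup in profile_run: synchronize, drop every
    # intermediate, and collect, so profiling leaves no allocations behind.
    torch.cuda.synchronize()
    del logits, repetition_penalties, prompt_token_ids, penalized, probs, sampled
    gc.collect()
    torch.cuda.empty_cache()

One detail worth noting in the diff itself: the `else` branch now binds `logits = None` and `dummy_metadata = None` because the unconditional `del hidden_states, logits, sampler_output, dummy_metadata` that follows would raise a `NameError` on non-last pipeline-parallel ranks if those names were never assigned.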