|
31 | 31 | from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, |
32 | 32 | KVCacheSpec) |
33 | 33 | from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput |
| 34 | +from vllm.v1.sample.metadata import SamplingMetadata |
34 | 35 | from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID |
35 | 36 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
36 | 37 | from vllm.v1.utils import bind_kv_cache |
@@ -1305,11 +1306,34 @@ def profile_run(self) -> None: |
1305 | 1306 | if get_pp_group().is_last_rank: |
1306 | 1307 | hidden_states = hidden_states[logit_indices] |
1307 | 1308 | logits = self.model.compute_logits(hidden_states, None) |
1308 | | - # TODO(woosuk): Consider the memory usage of the sampler. |
| 1309 | + dummy_tensors = lambda v: torch.full( |
| 1310 | + (num_reqs, ), v, device=self.device) |
| 1311 | + dummy_metadata = SamplingMetadata( |
| 1312 | + temperature=dummy_tensors(0.5), |
| 1313 | + all_greedy=False, |
| 1314 | + all_random=False, |
| 1315 | + spec_token_ids=None, |
| 1316 | + top_p=dummy_tensors(0.9), |
| 1317 | + top_k=dummy_tensors(logits.size(1) - 1), |
| 1318 | + min_p=None, |
| 1319 | + generators={}, |
| 1320 | + max_num_logprobs=None, |
| 1321 | + no_penalties=True, |
| 1322 | + prompt_token_ids=torch.ones_like(logits, dtype=torch.int64), |
| 1323 | + frequency_penalties=dummy_tensors(0.1), |
| 1324 | + presence_penalties=dummy_tensors(0.1), |
| 1325 | + repetition_penalties=dummy_tensors(0.1), |
| 1326 | + output_token_ids=[[] for _ in range(num_reqs)], |
| 1327 | + min_tokens={}, |
| 1328 | + logit_bias=[None for _ in range(num_reqs)]) |
| 1329 | + sampler_output = self.model.sample( |
| 1330 | + logits=logits, sampling_metadata=dummy_metadata) |
1309 | 1331 | else: |
1310 | 1332 | logits = None |
| 1333 | + sampler_output = None |
| 1334 | + dummy_metadata = None |
1311 | 1335 | torch.cuda.synchronize() |
1312 | | - del hidden_states, logits |
| 1336 | + del hidden_states, logits, sampler_output, dummy_metadata |
1313 | 1337 | self.encoder_cache.clear() |
1314 | 1338 | gc.collect() |
1315 | 1339 |
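This hunk resolves the TODO it deletes: `profile_run()` now constructs a dummy `SamplingMetadata` and executes the sampler once during memory profiling, so the sampler's transient allocations are counted in the peak-memory measurement used to size the KV cache. The dummy parameters are chosen to exercise the expensive paths rather than short-circuit them: `top_k` is `vocab_size - 1`, `top_p` is 0.9, and all three penalty tensors are non-zero. The sketch below illustrates the same technique in plain PyTorch. `profile_sampler_memory`, its default shapes, and the naive top-p implementation are illustrative assumptions for this note, not vLLM's actual `Sampler`.

```python
import torch


def profile_sampler_memory(num_reqs: int = 256,
                           vocab_size: int = 32000,
                           device: str = "cuda") -> int:
    """Return the peak memory (bytes) of one worst-case sampling step.

    Mirrors the idea of the diff above: run the sampler once on dummy
    inputs during profiling so its transient allocations are included
    when the engine decides how much memory is left for the KV cache.
    """
    torch.cuda.reset_peak_memory_stats(device)
    logits = torch.randn(num_reqs, vocab_size, device=device)
    # Arbitrary but non-trivial per-request parameters, analogous to the
    # dummy SamplingMetadata: values picked so no sampler branch is skipped.
    temperature = torch.full((num_reqs,), 0.5, device=device)
    top_p = torch.full((num_reqs,), 0.9, device=device)
    # Temperature scaling, top-p filtering, then multinomial sampling.
    probs = torch.softmax(logits / temperature.unsqueeze(1), dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Zero out tokens once the cumulative probability exceeds top_p.
    mask = (cumulative - sorted_probs) > top_p.unsqueeze(1)
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(sorted_probs, num_samples=1)
    _next_tokens = sorted_idx.gather(-1, sampled)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(device)


if __name__ == "__main__":
    peak = profile_sampler_memory()
    print(f"peak sampler memory: {peak / 2**20:.1f} MiB")
```

As in the diff, the temporaries are explicitly deleted (here, by letting them fall out of scope) and measured after a `torch.cuda.synchronize()`, so the reported peak reflects the sort and cumsum buffers that a batched top-p step materializes.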
|
|