Skip to content

Commit da31b53

Browse files
JenZhao and ywang96 authored
[Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler (#13594)
Signed-off-by: Jennifer Zhao <[email protected]> Co-authored-by: Roger Wang <[email protected]>
1 parent bb78fb3 commit da31b53

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
3232
KVCacheSpec)
3333
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
34+
from vllm.v1.sample.metadata import SamplingMetadata
3435
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
3536
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
3637
from vllm.v1.utils import bind_kv_cache
@@ -1305,11 +1306,34 @@ def profile_run(self) -> None:
13051306
if get_pp_group().is_last_rank:
13061307
hidden_states = hidden_states[logit_indices]
13071308
logits = self.model.compute_logits(hidden_states, None)
1308-
# TODO(woosuk): Consider the memory usage of the sampler.
1309+
dummy_tensors = lambda v: torch.full(
1310+
(num_reqs, ), v, device=self.device)
1311+
dummy_metadata = SamplingMetadata(
1312+
temperature=dummy_tensors(0.5),
1313+
all_greedy=False,
1314+
all_random=False,
1315+
spec_token_ids=None,
1316+
top_p=dummy_tensors(0.9),
1317+
top_k=dummy_tensors(logits.size(1) - 1),
1318+
min_p=None,
1319+
generators={},
1320+
max_num_logprobs=None,
1321+
no_penalties=True,
1322+
prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
1323+
frequency_penalties=dummy_tensors(0.1),
1324+
presence_penalties=dummy_tensors(0.1),
1325+
repetition_penalties=dummy_tensors(0.1),
1326+
output_token_ids=[[] for _ in range(num_reqs)],
1327+
min_tokens={},
1328+
logit_bias=[None for _ in range(num_reqs)])
1329+
sampler_output = self.model.sample(
1330+
logits=logits, sampling_metadata=dummy_metadata)
13091331
else:
13101332
logits = None
1333+
sampler_output = None
1334+
dummy_metadata = None
13111335
torch.cuda.synchronize()
1312-
del hidden_states, logits
1336+
del hidden_states, logits, sampler_output, dummy_metadata
13131337
self.encoder_cache.clear()
13141338
gc.collect()
13151339

0 commit comments

Comments (0)