
Commit 2275784

Revert "[V1][Core] Fix memory issue with logits & sampling" (#13775)
Parent: befc402

2 files changed (+29, -49 lines)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 29 additions & 39 deletions
@@ -1179,43 +1179,6 @@ def _dummy_run(
         )
         return hidden_states

-    @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
-        logits = self.model.compute_logits(hidden_states, None)
-        num_reqs = logits.size(0)
-
-        dummy_tensors = lambda v: torch.full(
-            (num_reqs, ), v, device=self.device)
-
-        dummy_metadata = SamplingMetadata(
-            temperature=dummy_tensors(0.5),
-            all_greedy=False,
-            all_random=False,
-            spec_token_ids=None,
-            top_p=dummy_tensors(0.9),
-            top_k=dummy_tensors(logits.size(1) - 1),
-            min_p=None,
-            generators={},
-            max_num_logprobs=None,
-            no_penalties=True,
-            prompt_token_ids=None,
-            frequency_penalties=dummy_tensors(0.1),
-            presence_penalties=dummy_tensors(0.1),
-            repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
-            min_tokens={},
-            logit_bias=[None for _ in range(num_reqs)],
-            allowed_token_ids_mask=None,
-        )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
-        return sampler_output
-
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.

@@ -1343,11 +1306,38 @@ def profile_run(self) -> None:
                                     dummy_kv_caches)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            logits = self.model.compute_logits(hidden_states, None)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+            )
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
+            logits = None
             sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, sampler_output
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
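
For context, the code inlined above builds a full-sized batch of dummy logits and sampling parameters inside profile_run, runs the sampler once, and then deletes every temporary before the caches are cleared, so peak sampler memory is counted during profiling without staying pinned afterwards. The standalone PyTorch sketch below illustrates that allocate-once, sample-once, free-everything pattern; the function name and parameters are illustrative only, it does not use vLLM's SamplingMetadata, and a CUDA device is assumed.

import gc
import torch

def dummy_sampler_pass(num_reqs: int, vocab_size: int,
                       device: str = "cuda") -> None:
    # Allocate logits at the largest shape the sampler will ever see,
    # analogous to compute_logits() on the dummy hidden states above.
    logits = torch.randn(num_reqs, vocab_size, device=device)

    # Per-request dummy sampling parameters, in the spirit of the
    # dummy_tensors(...) helper in the diff.
    temperature = torch.full((num_reqs,), 0.5, device=device)
    top_k = vocab_size - 1

    # A minimal temperature + top-k sampling step.
    scaled = logits / temperature.unsqueeze(1)
    topk_vals, topk_idx = scaled.topk(top_k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    sampled = topk_idx.gather(1, torch.multinomial(probs, num_samples=1))

    torch.cuda.synchronize()
    # Drop every temporary so the profiling pass does not keep memory alive,
    # mirroring `del hidden_states, logits, sampler_output, dummy_metadata`.
    del logits, temperature, scaled, topk_vals, topk_idx, probs, sampled
    gc.collect()
    torch.cuda.empty_cache()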

vllm/v1/worker/gpu_worker.py

Lines changed: 0 additions & 10 deletions
@@ -211,16 +211,6 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
-
-        # Warm up sampler and preallocate memory buffer for logits and other
-        # sampling related tensors of max possible shape to avoid memory
-        # fragmentation issue.
-        # NOTE: This is called after `capture_model` on purpose to prevent
-        # memory buffers from being cleared by `torch.cuda.empty_cache`.
-        self.model_runner._dummy_sampler_run(
-            hidden_states=self.model_runner._dummy_run(
-                num_tokens=self.scheduler_config.max_num_seqs))
-
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
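
The block deleted above was the counterpart warm-up: after CUDA graph capture it ran a dummy forward pass plus sampler call at max_num_seqs so that sampler-sized buffers were preallocated; per the removed NOTE, running it after capture_model keeps torch.cuda.empty_cache from clearing those buffers. A toy sketch of that ordering is shown below; the Runner class and its methods are hypothetical stand-ins rather than the vLLM API, and a CUDA device is assumed.

import torch

class Runner:
    """Toy stand-in for the model runner; names are illustrative only."""

    def capture_model(self) -> None:
        # Placeholder for graph capture; per the removed NOTE, a
        # torch.cuda.empty_cache() in this phase releases cached blocks
        # that are not currently allocated.
        torch.cuda.empty_cache()

    def dummy_sampler_run(self, max_num_seqs: int, vocab_size: int) -> None:
        # Touch sampler-sized buffers once so later allocations of the same
        # shape can reuse this memory instead of fragmenting the pool.
        logits = torch.empty(max_num_seqs, vocab_size, device="cuda")
        del logits

runner = Runner()
runner.capture_model()
# The reverted change placed the warm-up *after* capture_model so that the
# empty_cache() call could not wipe out the freshly warmed buffers.
runner.dummy_sampler_run(max_num_seqs=256, vocab_size=32_000)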
