Commit f969241

wip:
1 parent f97e0ae commit f969241

File tree

1 file changed: +7 -4 lines changed

vllm/worker/multi_step_model_runner.py

Lines changed: 7 additions & 4 deletions
@@ -10,6 +10,7 @@
 
 import torch
 
+from vllm.model_executor.layers.sampler import _get_logprobs
 from vllm import _custom_ops as ops
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
@@ -295,16 +296,18 @@ def execute_model(
         model_input.cached_outputs.append(
             ModelOutput(output[0], output_ready_event,
                         output[0].sampled_token_ids, False))
-        # make sure we dont try to serialize any GPU tensors
-        output[0].sampled_token_ids = None
-        output[0].sampled_token_probs = None
-        output[0].logprobs = None
+
         # Pythonize the output if CPU is ahead and the previous step is
         # ready.
         for model_output in model_input.cached_outputs:
             model_output.maybe_pythonize(model_input, self._copy_stream,
                                          self.pinned_sampled_token_ids)
 
+        # make sure we dont try to serialize any GPU tensors
+        output[0].sampled_token_ids = None
+        output[0].sampled_token_probs = None
+        output[0].logprobs = None
+
         model_input.current_step += 1
 
         if not get_pp_group().is_last_rank:
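The substance of the second hunk is a reordering: the GPU tensor references on the sampler output are now cleared only after the pythonize loop has run, rather than before it, so pythonization can still read them; the first hunk's newly imported `_get_logprobs` is not yet used in the lines shown, consistent with the "wip:" commit message. Below is a minimal sketch of the ordering constraint this enforces; `FakeSamplerOutput`, `pythonize`, and `finish_step` are hypothetical stand-ins for illustration, not vLLM's real classes or methods.

# Minimal sketch of the ordering the commit enforces (hypothetical names).
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeSamplerOutput:
    # Mirrors the three fields the diff sets to None.
    sampled_token_ids: Optional[torch.Tensor]
    sampled_token_probs: Optional[torch.Tensor]
    logprobs: Optional[torch.Tensor]


def pythonize(out: FakeSamplerOutput) -> list:
    # Pythonization needs the tensors: it copies them off the device
    # and converts them to plain Python objects.
    assert out.sampled_token_ids is not None, "tensors were cleared too early"
    return out.sampled_token_ids.cpu().tolist()


def finish_step(out: FakeSamplerOutput) -> list:
    # 1) Pythonize while the tensor references are still attached.
    token_ids = pythonize(out)
    # 2) Only then drop the references, so the output object can be
    #    serialized without dragging GPU tensors across processes.
    out.sampled_token_ids = None
    out.sampled_token_probs = None
    out.logprobs = None
    return token_ids


print(finish_step(FakeSamplerOutput(torch.tensor([3, 7, 42]), None, None)))
# -> [3, 7, 42]; swapping steps 1) and 2) would trip the assertion.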
