From 4555ff3d312a2bebc8993db8e93b0e403118e63c Mon Sep 17 00:00:00 2001
From: benyi
Date: Fri, 10 Oct 2025 20:43:59 +0800
Subject: [PATCH] fix(vllm_strategy): llm.generate param 'prompt_token_ids' becomes part of 'prompts'

---
 roll/distributed/strategy/vllm_strategy.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py
index 01b5ff33..bd2959f4 100644
--- a/roll/distributed/strategy/vllm_strategy.py
+++ b/roll/distributed/strategy/vllm_strategy.py
@@ -12,10 +12,13 @@
 import torch.distributed as dist
 from torch.nn.utils.rnn import pad_sequence
 from transformers import set_seed
+import vllm
 from vllm import RequestOutput, SamplingParams
 from vllm.lora.request import LoRARequest
 from vllm.utils import random_uuid
 
+from packaging.version import Version
+
 from mcore_adapter.models.converter.convert_utils import RecvBucketManager
 from roll.distributed.executor.worker import Worker
 from roll.distributed.scheduler.protocol import DataProto
@@ -138,9 +141,18 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor:
         if "multi_modal_data" in batch.non_tensor_batch:
             vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"]
         else:
-            vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
-                input_ids=input_ids, attention_mask=attention_mask
-            )
+            # vLLM >= 0.10.2: llm.generate takes token ids via TokensPrompt in "prompts", not "prompt_token_ids"
+            if Version(vllm.__version__) >= Version("0.10.2"):
+                from vllm.inputs.data import TokensPrompt
+                prompt_token_ids_list = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
+                prompts = [TokensPrompt(prompt_token_ids=p_ids) for p_ids in prompt_token_ids_list]
+                vllm_input_args["prompts"] = prompts
+            else:
+                vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
 
         lora_requests = None
         if self.is_lora:
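
Note (not part of the patch): a minimal standalone sketch of the two llm.generate calling conventions this change selects between. The model name and token ids below are placeholders, not values taken from the repository.

# Illustrative sketch only; model name and token ids are placeholders.
from packaging.version import Version

import vllm
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
sampling_params = SamplingParams(temperature=0.7, max_tokens=16)
prompt_token_ids_list = [[151644, 872, 198], [151644, 8948, 198]]  # placeholder token ids

if Version(vllm.__version__) >= Version("0.10.2"):
    # Newer vLLM: wrap token ids in TokensPrompt and pass them via `prompts`.
    from vllm.inputs.data import TokensPrompt
    prompts = [TokensPrompt(prompt_token_ids=p) for p in prompt_token_ids_list]
    outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
else:
    # Older vLLM: token ids could be passed directly as the `prompt_token_ids` kwarg.
    outputs = llm.generate(
        prompt_token_ids=prompt_token_ids_list, sampling_params=sampling_params
    )

for request_output in outputs:
    print(request_output.outputs[0].text)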