diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py
index 01b5ff33..bd2959f4 100644
--- a/roll/distributed/strategy/vllm_strategy.py
+++ b/roll/distributed/strategy/vllm_strategy.py
@@ -12,10 +12,13 @@ import torch.distributed as dist
 from torch.nn.utils.rnn import pad_sequence
 from transformers import set_seed
 
+import vllm
 from vllm import RequestOutput, SamplingParams
 from vllm.lora.request import LoRARequest
 from vllm.utils import random_uuid
 
+from packaging.version import Version
+
 from mcore_adapter.models.converter.convert_utils import RecvBucketManager
 from roll.distributed.executor.worker import Worker
 from roll.distributed.scheduler.protocol import DataProto
@@ -138,9 +141,18 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor:
         if "multi_modal_data" in batch.non_tensor_batch:
             vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"]
         else:
-            vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
-                input_ids=input_ids, attention_mask=attention_mask
-            )
+            # fix llm generate params for vllm 0.10.2
+            if Version("0.10.2") >= Version(vllm.__version__):
+                from vllm.inputs.data import TokensPrompt
+                prompt_token_ids_list = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
+                prompts = [TokensPrompt(prompt_token_ids=p_ids) for p_ids in prompt_token_ids_list]
+                vllm_input_args["prompts"] = prompts
+            else:
+                vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
 
         lora_requests = None
         if self.is_lora:
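
The same input-handling switch can be exercised directly against vLLM's offline LLM API, outside of ROLL. The sketch below is illustrative only: the model name, the example token ids, and the inline unpadding logic standing in for gather_unpadded_input_ids() are assumptions, while the version gate and the TokensPrompt / prompt_token_ids call shapes mirror the patch above.

    # Minimal, self-contained sketch of the call path introduced by the patch.
    # Assumptions (not taken from the repo): the model name, the example token
    # ids, and the inline unpadding stand-in for gather_unpadded_input_ids().
    import torch
    import vllm
    from packaging.version import Version
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    sampling_params = SamplingParams(max_tokens=16)

    # Left-padded batch plus attention mask, as produced upstream of generate().
    input_ids = torch.tensor([[0, 0, 11, 12, 13], [21, 22, 23, 24, 25]])
    attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

    # Stand-in for gather_unpadded_input_ids(): keep only unmasked tokens per row.
    prompt_token_ids_list = [
        ids[mask.bool()].tolist() for ids, mask in zip(input_ids, attention_mask)
    ]

    if Version("0.10.2") >= Version(vllm.__version__):
        # Same version gate as the patch: wrap each id list in a TokensPrompt
        # and hand it to generate() through the `prompts` argument.
        from vllm.inputs.data import TokensPrompt

        prompts = [TokensPrompt(prompt_token_ids=p_ids) for p_ids in prompt_token_ids_list]
        outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
    else:
        # Other side of the gate: the legacy `prompt_token_ids` keyword.
        outputs = llm.generate(
            prompt_token_ids=prompt_token_ids_list, sampling_params=sampling_params
        )

    print([out.outputs[0].text for out in outputs])

Routing the token ids through `prompts` as TokensPrompt objects keeps a single keyword argument for both text and token-id inputs, which is what the patch switches to on that side of the version check.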