18 changes: 15 additions & 3 deletions roll/distributed/strategy/vllm_strategy.py
@@ -12,10 +12,13 @@
 import torch.distributed as dist
 from torch.nn.utils.rnn import pad_sequence
 from transformers import set_seed
+import vllm
 from vllm import RequestOutput, SamplingParams
 from vllm.lora.request import LoRARequest
 from vllm.utils import random_uuid
 
+from packaging.version import Version
+
 from mcore_adapter.models.converter.convert_utils import RecvBucketManager
 from roll.distributed.executor.worker import Worker
 from roll.distributed.scheduler.protocol import DataProto
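
A note on the new `packaging.version.Version` import: release strings cannot be compared lexically, so the version gate in the next hunk relies on `packaging`. A minimal sketch, assuming `packaging` is installed; the version strings are illustrative:

```python
from packaging.version import Version

# Semver-aware comparison: 0.10.2 is newer than 0.9.1.
assert Version("0.10.2") >= Version("0.9.1")

# Plain string comparison gets the same pair wrong ("1" sorts before "9").
assert not ("0.10.2" >= "0.9.1")
```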
@@ -138,9 +141,18 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor:
if "multi_modal_data" in batch.non_tensor_batch:
vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"]
else:
vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
input_ids=input_ids, attention_mask=attention_mask
)
+            # fix llm.generate params for vLLM 0.10.2: wrap token ids in TokensPrompt and pass them via `prompts`
+            if Version("0.10.2") >= Version(vllm.__version__):
+                from vllm.inputs.data import TokensPrompt
+                prompt_token_ids_list = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
+                prompts = [TokensPrompt(prompt_token_ids=p_ids) for p_ids in prompt_token_ids_list]
+                vllm_input_args["prompts"] = prompts
+            else:
+                vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )

         lora_requests = None
         if self.is_lora:
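
For context on the change itself: as written, the gate routes vLLM releases up to and including 0.10.2 through the new path, wrapping each unpadded token id list in a `TokensPrompt` and passing the list through `prompts`, while later releases keep the legacy `prompt_token_ids` keyword. A minimal standalone sketch of the same two call shapes, not part of this PR; the model name, token ids, and sampling settings are placeholders, and running it requires a working vLLM install:

```python
import vllm
from packaging.version import Version
from vllm import LLM, SamplingParams
from vllm.inputs.data import TokensPrompt

llm = LLM(model="facebook/opt-125m")  # illustrative small model
sampling_params = SamplingParams(temperature=0.0, max_tokens=8)

# Stand-in for gather_unpadded_input_ids(): per-prompt token id lists without padding.
prompt_token_ids_list = [[2, 100, 101, 102], [2, 200, 201]]

if Version("0.10.2") >= Version(vllm.__version__):
    # Same shape the PR adopts: each id list becomes a TokensPrompt passed via `prompts`.
    prompts = [TokensPrompt(prompt_token_ids=p_ids) for p_ids in prompt_token_ids_list]
    outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
else:
    # Legacy keyword the PR keeps for other vLLM releases.
    outputs = llm.generate(prompt_token_ids=prompt_token_ids_list, sampling_params=sampling_params)

for request_output in outputs:
    print(request_output.outputs[0].token_ids)
```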