diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 1e2425a7d3c..03f083e4e93 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -52,8 +52,12 @@
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.worker.worker_base import WorkerWrapperBase
+
+try:
+    from vllm.worker.worker_base import WorkerWrapperBase
+except ModuleNotFoundError:
+    # https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b
+    from vllm.v1.worker.worker_base import WorkerWrapperBase
 
 from verl import DataProto
 from verl.third_party.vllm import VLLM_SLEEP_LEVEL
@@ -459,10 +463,10 @@ def _monkey_patch_compute_logits(model, vocab_size: int):
 
     def compute_logits(
         self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        logits = original_compute_logits(hidden_states, sampling_metadata)
+        logits = original_compute_logits(*args, **kwargs)
         logits[..., vocab_size:] = float("-inf")
         return logits
 
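
For context, here is a minimal, standalone sketch of the pattern the second hunk relies on: wrapping an existing compute_logits with a *args/**kwargs signature so the monkey patch keeps working even when the upstream vLLM signature changes (e.g. the removal of the SamplingMetadata parameter). The FakeLMHead class, monkey_patch_compute_logits function, and the vocab sizes below are hypothetical stand-ins, not verl or vLLM code; they only illustrate the idea under those assumptions.

    import torch


    class FakeLMHead:
        """Hypothetical stand-in for a model object exposing compute_logits."""

        def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Project to a padded vocab that is larger than the tokenizer's real vocab.
            return hidden_states @ torch.randn(hidden_states.shape[-1], 16)


    def monkey_patch_compute_logits(model, vocab_size: int) -> None:
        original_compute_logits = model.compute_logits

        # Accepting *args/**kwargs forwards whatever arguments the caller passes,
        # so the patch survives upstream signature changes.
        def compute_logits(*args, **kwargs) -> torch.Tensor:
            logits = original_compute_logits(*args, **kwargs)
            # Mask out padded vocab entries beyond the real vocab size.
            logits[..., vocab_size:] = float("-inf")
            return logits

        model.compute_logits = compute_logits


    model = FakeLMHead()
    monkey_patch_compute_logits(model, vocab_size=10)
    out = model.compute_logits(torch.randn(2, 16))
    assert torch.isinf(out[..., 10:]).all()

The import change in the first hunk follows the same compatibility idea at module level: try the pre-move path first, and fall back to the vllm.v1 location introduced by the linked upstream commit.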