Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.worker.worker_base import WorkerWrapperBase

try:
from vllm.worker.worker_base import WorkerWrapperBase
except ModuleNotFoundError:
# https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b
from vllm.v1.worker.worker_base import WorkerWrapperBase

from verl import DataProto
from verl.third_party.vllm import VLLM_SLEEP_LEVEL
Expand Down Expand Up @@ -459,10 +463,10 @@ def _monkey_patch_compute_logits(model, vocab_size: int):

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
*args,
**kwargs,
) -> torch.Tensor:
logits = original_compute_logits(hidden_states, sampling_metadata)
logits = original_compute_logits(*args, **kwargs)
logits[..., vocab_size:] = float("-inf")
return logits

Expand Down