Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

from verl import DataProto
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.ray_utils import ray_noset_visible_devices
from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout

Expand Down Expand Up @@ -459,8 +460,7 @@ def get_zeromq_address(self):
def init_worker(self, all_kwargs: list[dict[str, Any]]):
"""Initialize worker engine."""
all_kwargs[0]["rank"] = int(os.environ["RANK"])
all_kwargs[0]["local_rank"] = 0

all_kwargs[0]["local_rank"] = 0 if not ray_noset_visible_devices() else int(os.environ.get("RAY_LOCAL_RANK", 0))
self.vllm_config = all_kwargs[0]["vllm_config"]
self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
self.inference_engine.init_worker(all_kwargs)
Expand Down