diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 1e2425a7d3c..b48132378a3 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -547,7 +547,8 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]): ) self.vllm_config = all_kwargs[0]["vllm_config"] if self.lora_config: - self.vllm_config.lora_config = LoRAConfig(**self.lora_config) + lora_dtype = getattr(torch, self.config.dtype) + self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs)