Commit 947ce5f

kfallah authored and techkang committed
[rollout,vllm] fix: Add LoRA Loading to Async vLLM (volcengine#3639)
### What does this PR do?

Currently, async vLLM with AgentWorkerLoop throws an error when `update_weights` is called with LoRA weights. This PR expands AgentWorkerLoop support to LoRAs.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
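As a usage sketch of the change (the `server` handle and payload names here are illustrative assumptions, not verl API; only the `peft_config` and `base_sync_done` keywords come from the diff below):

```python
# Sketch only: `server` stands for an async vLLM rollout server handle and the
# weight payloads for tensors gathered from the trainer.
async def sync_weights(server, base_weights, lora_weights, peft_config):
    # Phase 1: full base-model sync; without peft_config, update_weights()
    # takes the existing patched model.load_weights() path.
    await server.update_weights(base_weights)

    # Phase 2: adapter-only sync; peft_config plus base_sync_done=True now
    # routes the tensors through vLLM's add_lora() instead of raising.
    await server.update_weights(
        lora_weights,
        peft_config=peft_config,  # e.g. a peft.LoraConfig dataclass
        base_sync_done=True,
    )
```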
1 parent e5af9fb commit 947ce5f

File tree: 1 file changed (+23 −7)

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 23 additions & 7 deletions
@@ -50,7 +50,7 @@
 from tensordict import TensorDict
 from torch.distributed.device_mesh import DeviceMesh
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -479,10 +479,12 @@ def __init__(
         device_mesh: DeviceMesh,
     ):
         super().__init__(config, model_config, device_mesh)
-
         self.tokenizer = model_config.tokenizer
         self.inference_engine: WorkerWrapperBase = None
         self.address = self._init_zeromq()
+        self.lora_config = (
+            {"max_loras": 1, "max_lora_rank": model_config.lora_rank} if model_config.lora_rank > 0 else {}
+        )
 
         # https://github.com/vllm-project/vllm/issues/25171
         if config.layered_summon or config.expert_parallel_size > 1:
@@ -536,7 +538,6 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
         """Initialize worker engine."""
         if not torch.distributed.is_initialized():
             initialize_global_process_group_ray()
-
         all_kwargs[0]["rank"] = int(os.environ["RANK"])
         device_name = "NPU" if is_npu_available else "GPU"
         all_kwargs[0]["local_rank"] = (
@@ -545,6 +546,8 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
             else int(ray.get_runtime_context().get_accelerator_ids()[device_name][0])
         )
         self.vllm_config = all_kwargs[0]["vllm_config"]
+        if self.lora_config:
+            self.vllm_config.lora_config = LoRAConfig(**self.lora_config)
         self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
         self.inference_engine.init_worker(all_kwargs)

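Read together, the `__init__` hunk above and this `_init_worker` hunk are one piece of plumbing: derive LoRA engine arguments from the model config, then attach them to vLLM's config before the worker is built. A minimal sketch, with `lora_rank` as an assumed stand-in for `model_config.lora_rank`:

```python
from vllm.config import LoRAConfig

lora_rank = 32  # assumed stand-in for model_config.lora_rank; <= 0 disables LoRA

# __init__: request LoRA support only when a positive rank is configured.
lora_kwargs = {"max_loras": 1, "max_lora_rank": lora_rank} if lora_rank > 0 else {}

if lora_kwargs:
    # _init_worker: attaching a LoRAConfig to the engine's VllmConfig is what
    # allows the worker to accept add_lora() requests later.
    lora_config = LoRAConfig(**lora_kwargs)
    print(lora_config.max_loras, lora_config.max_lora_rank)  # -> 1 32
```

`max_loras=1` fits this setup, presumably because only the single adapter being trained ever needs to be resident in the engine.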
@@ -582,11 +585,24 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         Args:
             weights: A generator that yields the name of the weight tensor and the tensor itself.
         """
-        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+        peft_config, base_sync_done = kwargs.get("peft_config", None), kwargs.get("base_sync_done", False)
+        if peft_config and base_sync_done:
+            lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
+            lora_request = TensorLoRARequest(
+                lora_name=f"{lora_int_id}",
+                lora_int_id=lora_int_id,
+                lora_path="simon_lora_path",
+                peft_config=asdict(peft_config),
+                lora_tensors=weights,
+            )
+            self.inference_engine.worker.add_lora(lora_request)
+            logger.info(f"vLLM load weights, loaded_params: {len(weights)}")
+        else:
+            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
 
-        model = self.inference_engine.worker.model_runner.model
-        patch_vllm_moe_model_weight_loader(model)
-        model.load_weights(weights)
+            model = self.inference_engine.worker.model_runner.model
+            patch_vllm_moe_model_weight_loader(model)
+            model.load_weights(weights)
 
     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""

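One design note on the new branch: vLLM keys adapters by `lora_int_id`, so the commit mints a fresh time-derived ID on every sync, presumably so the updated weights register as a new adapter rather than hitting a cached one. A standalone sketch of the ID scheme (the helper name is illustrative):

```python
import time

def fresh_lora_int_id() -> int:
    # Mirror of the commit's scheme: nanosecond clock modulo 2**31 - 1 yields
    # a positive 31-bit ID that is effectively unique per weight sync.
    return int(time.time_ns() % 0x7FFFFFFF)

print(fresh_lora_int_id())
```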