
Commit fab0f8f

Addressed (most of) gemini's recommendations
1 parent 421b122 commit fab0f8f

File tree

3 files changed: +48 -49 lines


areal/engine/vllm_remote.py

Lines changed: 36 additions & 41 deletions
@@ -107,49 +107,44 @@ def build_distributed_weight_update_requests(
     ) -> WeightUpdateRequests:
         """Build vLLM distributed weight update requests."""
         # vLLM uses two-step process: set metadata, then update
+        base_payload = {
+            "names": [pspec.name for pspec in param_specs],
+            "dtypes": [pspec.dtype for pspec in param_specs],
+            "shapes": [pspec.shape for pspec in param_specs],
+            "group_name": meta.nccl_group_name,
+        }
+
         if meta.use_lora:
-            return WeightUpdateRequests(
-                requests=[
-                    HttpRequest(
-                        endpoint="/areal_set_update_weight_meta_lora",
-                        payload={
-                            "names": [pspec.name for pspec in param_specs],
-                            "dtypes": [pspec.dtype for pspec in param_specs],
-                            "shapes": [pspec.shape for pspec in param_specs],
-                            "lora_name": meta.lora_name,
-                            "lora_int_id": meta.lora_int_id,
-                            "lora_target_modules": meta.peft_config["target_modules"],
-                            "lora_rank": meta.peft_config["r"],
-                            "lora_alpha": meta.peft_config["lora_alpha"],
-                            "lora_bias": meta.peft_config["bias"],
-                            "base_model_name": meta.base_model_name,
-                            "group_name": meta.nccl_group_name,
-                        },
-                    ),
-                    HttpRequest(
-                        endpoint="/areal_update_weights_lora_xccl",
-                        payload={},
-                    ),
-                ]
-            )
+            lora_payload = {
+                "lora_name": meta.lora_name,
+                "lora_int_id": meta.lora_int_id,
+                "lora_target_modules": meta.peft_config["target_modules"],
+                "lora_rank": meta.peft_config["r"],
+                "lora_alpha": meta.peft_config["lora_alpha"],
+                "lora_bias": meta.peft_config["bias"],
+                "base_model_name": meta.base_model_name,
+            }
+            payload = {**base_payload, **lora_payload}
+            meta_endpoint = "/areal_set_update_weight_meta_lora"
+            update_endpoint = "/areal_update_weights_lora_xccl"
         else:
-            return WeightUpdateRequests(
-                requests=[
-                    HttpRequest(
-                        endpoint="/areal_set_update_weight_meta",
-                        payload={
-                            "names": [pspec.name for pspec in param_specs],
-                            "dtypes": [pspec.dtype for pspec in param_specs],
-                            "shapes": [pspec.shape for pspec in param_specs],
-                            "group_name": meta.nccl_group_name,
-                        },
-                    ),
-                    HttpRequest(
-                        endpoint="/areal_update_weights_xccl",
-                        payload={},
-                    ),
-                ]
-            )
+            payload = base_payload
+            meta_endpoint = "/areal_set_update_weight_meta"
+            update_endpoint = "/areal_update_weights_xccl"
+
+        return WeightUpdateRequests(
+            requests=[
+                HttpRequest(
+                    endpoint=meta_endpoint,
+                    payload=payload,
+                ),
+                HttpRequest(
+                    endpoint=update_endpoint,
+                    payload={},
+                ),
+            ]
+        )

     def build_init_weights_group_request(
         self, addr: str, server_idx: int, meta: WeightUpdateMeta
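
For context, the refactor above makes the ordering contract explicit: the metadata request must be sent before the update trigger. A minimal sketch of how such a request bundle might be consumed, using stand-in dataclasses rather than AReaL's real HttpRequest/WeightUpdateRequests definitions (the names and fields below are assumptions for illustration, not the actual API):

import json
import urllib.request
from dataclasses import dataclass, field

# Stand-ins for AReaL's request types; the real classes in areal/engine
# may carry more fields (these are assumptions, not the actual API).
@dataclass
class HttpRequest:
    endpoint: str
    payload: dict = field(default_factory=dict)

@dataclass
class WeightUpdateRequests:
    requests: list

def post_in_order(base_url: str, bundle: WeightUpdateRequests) -> None:
    # The two-step protocol (set metadata, then update) only works if the
    # requests are issued serially, in the order they were built.
    for req in bundle.requests:
        data = json.dumps(req.payload).encode("utf-8")
        http_req = urllib.request.Request(
            base_url + req.endpoint,
            data=data,
            headers={"Content-Type": "application/json"},
        )
        urllib.request.urlopen(http_req)  # raises on HTTP errors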

areal/thirdparty/vllm/vllm_worker_extension.py

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,5 @@
+import traceback
+
 import torch
 from vllm.logger import init_logger
 from vllm.lora.models import LoRAModel
@@ -134,6 +136,9 @@ def update_weight_xccl(self):
             return False, error_msg

     def update_weight_lora_xccl(self):
+        # NOTE: This code relies on vLLM private APIs (_adapter_manager, _registered_adapters,
+        # and _add_adapter/activate_adapter), which may change or break in newer vLLM versions.
+
         logger.info(
             f"start update lora weights by xccl, lora_name={self.areal_lora_name}, lora_int_id={self.areal_lora_int_id}",
             flush=True,
@@ -228,8 +233,6 @@ def update_weight_lora_xccl(self):
             return True, "Success"

         except Exception as e:
-            import traceback
-
             error_msg = f"Failed to update LoRA parameter via XCCL! {e}\n{traceback.format_exc()}"
             logger.error(error_msg)
             return False, error_msg
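
Because update_weight_lora_xccl reaches into vLLM internals, a defensive attribute lookup can turn version drift into a clear error instead of a bare AttributeError mid-update. A hypothetical helper, not part of this patch; the attribute names merely mirror the NOTE added above:

def _require_private_attr(obj, name: str):
    # Private vLLM attributes such as _adapter_manager or _registered_adapters
    # are not a stable API; surface version drift explicitly.
    if not hasattr(obj, name):
        raise RuntimeError(
            f"vLLM private attribute {name!r} not found on {type(obj).__name__}; "
            "this worker extension may not support the installed vLLM version."
        )
    return getattr(obj, name)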

examples/lora/gsm8k_grpo_lora_vllm.py

Lines changed: 7 additions & 6 deletions
@@ -82,23 +82,24 @@ def main(args):
     eval_rollout.initialize()

     if config.actor.weight_update_mode == "xccl":
-        weight_update_meta = WeightUpdateMeta.from_disk(
-            config.saver.experiment_name,
-            config.saver.trial_name,
-            config.saver.fileroot,
+        weight_update_meta = WeightUpdateMeta.from_fsdp_xccl(
+            allocation_mode,
             use_lora=config.actor.use_lora,
             lora_name=config.gconfig.lora_name,
             lora_int_id=1,  # hard coded for the single lora example
             base_model_name=config.actor.path,
         )
     elif config.actor.weight_update_mode == "disk":
-        weight_update_meta = WeightUpdateMeta.from_fsdp_xccl(
-            allocation_mode,
+        weight_update_meta = WeightUpdateMeta.from_disk(
+            config.saver.experiment_name,
+            config.saver.trial_name,
+            config.saver.fileroot,
             use_lora=config.actor.use_lora,
             lora_name=config.gconfig.lora_name,
             lora_int_id=1,  # hard coded for the single lora example
             base_model_name=config.actor.path,
         )
+
     else:
         raise ValueError(
             f"Invalid weight_update_mode: {config.actor.weight_update_mode}. Expected 'xccl' or 'disk'."
