@@ -107,49 +107,44 @@ def build_distributed_weight_update_requests(
107107 ) -> WeightUpdateRequests :
108108 """Build vLLM distributed weight update requests."""
109109 # vLLM uses two-step process: set metadata, then update
110+ # vLLM uses two-step process: set metadata, then update
111+ base_payload = {
112+ "names" : [pspec .name for pspec in param_specs ],
113+ "dtypes" : [pspec .dtype for pspec in param_specs ],
114+ "shapes" : [pspec .shape for pspec in param_specs ],
115+ "group_name" : meta .nccl_group_name ,
116+ }
117+
110118 if meta .use_lora :
111- return WeightUpdateRequests (
112- requests = [
113- HttpRequest (
114- endpoint = "/areal_set_update_weight_meta_lora" ,
115- payload = {
116- "names" : [pspec .name for pspec in param_specs ],
117- "dtypes" : [pspec .dtype for pspec in param_specs ],
118- "shapes" : [pspec .shape for pspec in param_specs ],
119- "lora_name" : meta .lora_name ,
120- "lora_int_id" : meta .lora_int_id ,
121- "lora_target_modules" : meta .peft_config ["target_modules" ],
122- "lora_rank" : meta .peft_config ["r" ],
123- "lora_alpha" : meta .peft_config ["lora_alpha" ],
124- "lora_bias" : meta .peft_config ["bias" ],
125- "base_model_name" : meta .base_model_name ,
126- "group_name" : meta .nccl_group_name ,
127- },
128- ),
129- HttpRequest (
130- endpoint = "/areal_update_weights_lora_xccl" ,
131- payload = {},
132- ),
133- ]
134- )
119+ lora_payload = {
120+ "lora_name" : meta .lora_name ,
121+ "lora_int_id" : meta .lora_int_id ,
122+ "lora_target_modules" : meta .peft_config ["target_modules" ],
123+ "lora_rank" : meta .peft_config ["r" ],
124+ "lora_alpha" : meta .peft_config ["lora_alpha" ],
125+ "lora_bias" : meta .peft_config ["bias" ],
126+ "base_model_name" : meta .base_model_name ,
127+ }
128+ payload = {** base_payload , ** lora_payload }
129+ meta_endpoint = "/areal_set_update_weight_meta_lora"
130+ update_endpoint = "/areal_update_weights_lora_xccl"
135131 else :
136- return WeightUpdateRequests (
137- requests = [
138- HttpRequest (
139- endpoint = "/areal_set_update_weight_meta" ,
140- payload = {
141- "names" : [pspec .name for pspec in param_specs ],
142- "dtypes" : [pspec .dtype for pspec in param_specs ],
143- "shapes" : [pspec .shape for pspec in param_specs ],
144- "group_name" : meta .nccl_group_name ,
145- },
146- ),
147- HttpRequest (
148- endpoint = "/areal_update_weights_xccl" ,
149- payload = {},
150- ),
151- ]
152- )
132+ payload = base_payload
133+ meta_endpoint = "/areal_set_update_weight_meta"
134+ update_endpoint = "/areal_update_weights_xccl"
135+
136+ return WeightUpdateRequests (
137+ requests = [
138+ HttpRequest (
139+ endpoint = meta_endpoint ,
140+ payload = payload ,
141+ ),
142+ HttpRequest (
143+ endpoint = update_endpoint ,
144+ payload = {},
145+ ),
146+ ]
147+ )
153148
154149 def build_init_weights_group_request (
155150 self , addr : str , server_idx : int , meta : WeightUpdateMeta
0 commit comments