from torch.distributed import P2POp, Work, batch_isend_irecv, get_global_rank

from vllm.distributed.eplb.eplb_adaptor.vllm_adaptor import VllmEplbAdaptor
-from vllm.distributed.eplb.eplb_loader.abstract_loader import BaseLoader
+from vllm.distributed.eplb.eplb_transfer.abstract_transfer import BaseTransfer
from vllm.distributed.eplb.eplb_utils.eplb_utils import (
-    get_ep_ranks_with_expert, idx_local_to_global)
+    get_ep_ranks_with_expert,
+    idx_local_to_global,
+)


-class EplbWeightLoader(BaseLoader):
+class EplbWeightTransfer(BaseTransfer):
    """
-    A concrete implementation of BaseLoader for Expert Parallel Load
+    A concrete implementation of BaseTransfer for Expert Parallel Load
    Balancing (EPLB).

-    This class is responsible for managing the transfer and update of
+    This class is responsible for managing the transfer and update of
    expert weights across different ranks during expert rearrangement.
    """

    def __init__(self, eplb_adaptor: VllmEplbAdaptor):
        """
-        Initializes the EplbWeightLoader.
+        Initializes the EplbWeightTransfer.

        Args:
            eplb_adaptor: An adaptor to interact with the vLLM model's expert
@@ -54,7 +56,7 @@ def shuffle_layer(
    ) -> None:
        """
        Performs expert weights rearrangement for a single MoE layer.
-        This method orchestrates the entire shuffling process including
+        This method orchestrates the entire shuffling process including
        local copies and inter-rank P2P communications.

        Args:
@@ -71,7 +73,7 @@ def shuffle_layer(
            expert_weights_buffer: An iterable of torch.Tensor,
                                   representing buffers for receiving expert
                                   weights. Same structure as expert_weights.
-            ep_group: The PyTorch distributed process group for
+            ep_group: The PyTorch distributed process group for
                      expert parallelism.
        """

@@ -95,39 +97,68 @@ def shuffle_layer(
                dst_global = local2global(dst)
                if is_received_locally[dst]:
                    continue
-                if old_indices[src_global] == -1 or new_indices[
-                        dst_global] == -1:
+                if old_indices[src_global] == -1 or \
+                        new_indices[dst_global] == -1:
                    continue
                if old_indices[src_global] == new_indices[dst_global]:
                    is_received_locally[dst] = True
-                    for weight, buffer in zip(expert_weights,
-                                              expert_weights_buffer):
+                    for weight, buffer in \
+                            zip(expert_weights, expert_weights_buffer):
                        buffer[dst].copy_(weight[src])

        p2p_ops: list[P2POp] = []

        # 2. Initiate sending of weights.
-        p2p_ops = self.prepare_send_p2p_ops(ep_group, ep_rank, expert_weights,
-                                            local2global, new_indices,
-                                            num_local_experts, old_indices,
-                                            p2p_ops)
+        p2p_ops = self.prepare_send_p2p_ops(
+            ep_group,
+            ep_rank,
+            expert_weights,
+            local2global,
+            new_indices,
+            num_local_experts,
+            old_indices,
+            p2p_ops
+        )

        # 3. Initiate receiving of weights.
        experts_recv_loc, p2p_ops = self.prepare_recv_p2p_ops(
-            ep_group, ep_rank, expert_weights_buffer, is_received_locally,
-            local2global, new_indices, num_local_experts, old_indices, p2p_ops)
+            ep_group,
+            ep_rank,
+            expert_weights_buffer,
+            is_received_locally,
+            local2global,
+            new_indices,
+            num_local_experts,
+            old_indices,
+            p2p_ops
+        )

        # 4. Execute the P2P operations. The real communication happens here.
        self.send_recv(p2p_ops)

        # 5. Copy the weights from the buffer back to the original weights.
-        self.update_weight(expert_weights, expert_weights_buffer,
-                           experts_recv_loc, is_received_locally, is_unchanged,
-                           local2global, new_indices, num_local_experts)
+        self.update_weight(
+            expert_weights,
+            expert_weights_buffer,
+            experts_recv_loc,
+            is_received_locally,
+            is_unchanged,
+            local2global,
+            new_indices,
+            num_local_experts
+        )

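For intuition, step 1 above only stages a local buffer copy when a destination slot's new expert already lives on this rank under its old placement; everything else falls through to the P2P path. A standalone sketch of that check with toy index arrays (hypothetical data, not part of this change):

    # Toy illustration of the step-1 local-copy detection (hypothetical).
    num_local_experts = 2
    rank = 1

    def local2global(idx: int) -> int:
        return rank * num_local_experts + idx

    old_indices = [0, 1, 2, 3]  # physical slot -> global expert id (before)
    new_indices = [0, 2, 1, 3]  # physical slot -> global expert id (after)

    for src in range(num_local_experts):
        for dst in range(num_local_experts):
            if old_indices[local2global(src)] == new_indices[local2global(dst)]:
                print(f"slot {dst} filled locally from slot {src}")
    # Rank 1 owns physical slots 2-3: expert 3 stays put (local copy),
    # while expert 1 must arrive from rank 0 via irecv.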
-    def update_weight(self, expert_weights, expert_weights_buffer,
-                      experts_recv_loc, is_received_locally, is_unchanged,
-                      local2global, new_indices, num_local_experts):
+    def update_weight(
+        self,
+        expert_weights,
+        expert_weights_buffer,
+        experts_recv_loc,
+        is_received_locally,
+        is_unchanged,
+        local2global,
+        new_indices,
+        num_local_experts
+    ):
        """
        Updates the actual expert weights from the buffer after communication.
        This is part of the `shuffle_layer` process.
@@ -152,16 +183,16 @@ def update_weight(self, expert_weights, expert_weights_buffer,
            if is_unchanged[dst]:
                continue
            if is_received_locally[dst]:
-                for weight, buffer in zip(expert_weights,
-                                          expert_weights_buffer):
+                for weight, buffer in \
+                        zip(expert_weights, expert_weights_buffer):
                    weight[dst].copy_(buffer[dst])
            else:
                expert = new_indices[local2global(dst)]
                if expert == -1:
                    continue
                src = experts_recv_loc[expert]
-                for weight, buffer in zip(expert_weights,
-                                          expert_weights_buffer):
+                for weight, buffer in \
+                        zip(expert_weights, expert_weights_buffer):
                    weight[dst].copy_(buffer[src])

    @override
@@ -177,9 +208,18 @@ def send_recv(self, p2p_ops: list[P2POp]) -> None:
        for req in reqs:
            req.wait()

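The send_recv flow above is PyTorch's batched P2P idiom: each rank collects its P2POp objects first, then launches them together with batch_isend_irecv so matched isend/irecv pairs cannot interleave into a deadlock, and finally waits on the returned requests. A minimal self-contained sketch of the same idiom (assumes an initialized process group; not part of this change):

    import torch
    import torch.distributed as dist

    def exchange(tensor: torch.Tensor, peer: int) -> torch.Tensor:
        # Post a matched isend/irecv pair with the peer, launched as a batch.
        recv_buf = torch.empty_like(tensor)
        ops = [
            dist.P2POp(dist.isend, tensor, peer),
            dist.P2POp(dist.irecv, recv_buf, peer),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()  # block until the batched transfers complete
        return recv_buf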
-    def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
-                             is_received_locally, local2global, new_indices,
-                             num_local_experts, old_indices, p2p_ops):
+    def prepare_recv_p2p_ops(
+        self,
+        ep_group,
+        ep_rank,
+        expert_weights_buffer,
+        is_received_locally,
+        local2global,
+        new_indices,
+        num_local_experts,
+        old_indices,
+        p2p_ops
+    ):
        """
        Prepares irecv operations for experts that need to be received
        from other ranks as part of the `shuffle_layer` process.
@@ -199,7 +239,7 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,

        Returns:
            A tuple containing:
-            - experts_recv_loc: A dictionary mapping global expert ID to the
+            - experts_recv_loc: A dictionary mapping global expert ID to the
              local buffer index where it will be received.
            - p2p_ops: The updated list of P2POp objects.
        """
@@ -225,7 +265,8 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
            # Determine which rank this rank receives the expert from
            if not ranks_to_send:
                raise ValueError(
-                    f"Expert {expert} is needed but no rank has it.")
+                    f"Expert {expert} is needed but no rank has it."
+                )
            num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
            recver_pos = ranks_to_recv.index(ep_rank)
            remainder_start = len(ranks_to_send) * num_dst_per_sender
@@ -240,13 +281,22 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
                    torch.distributed.irecv,
                    weight[dst],
                    src_global,
-                ) for weight in expert_weights_buffer
+                )
+                for weight in expert_weights_buffer
            ]
        return experts_recv_loc, p2p_ops

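The num_dst_per_sender / remainder_start arithmetic above spreads receivers evenly across the ranks that already hold the expert, wrapping any leftover receivers back onto the first senders. The final sender selection falls outside this hunk; a plausible standalone sketch of the distribution (an assumption based on the visible variables, not the PR's code):

    def pick_sender(ranks_to_send: list[int], ranks_to_recv: list[int],
                    ep_rank: int) -> int:
        # Each sender serves an equal block of receivers; receivers past
        # the even split (the remainder) wrap around to the first senders.
        num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
        recver_pos = ranks_to_recv.index(ep_rank)
        remainder_start = len(ranks_to_send) * num_dst_per_sender
        if recver_pos < remainder_start:
            return ranks_to_send[recver_pos // num_dst_per_sender]
        return ranks_to_send[recver_pos - remainder_start]

    # e.g. 2 senders, 5 receivers: positions 0-1 -> sender 0,
    # positions 2-3 -> sender 1, and leftover position 4 -> sender 0.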
-    def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
-                             local2global, new_indices, num_local_experts,
-                             old_indices, p2p_ops):
+    def prepare_send_p2p_ops(
+        self,
+        ep_group,
+        ep_rank,
+        expert_weights,
+        local2global,
+        new_indices,
+        num_local_experts,
+        old_indices,
+        p2p_ops
+    ):
        """
        Prepares isend operations for experts that need to be sent
        to other ranks as part of the `shuffle_layer` process.
@@ -255,7 +305,7 @@ def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
            ep_group: The PyTorch distributed process group.
            ep_rank: Current rank in the EP group.
            expert_weights: The actual expert weight tensors to be sent.
-            local2global: Partial function for local
+            local2global: Partial function for local
                to global index conversion.
            new_indices: The new global expert indices mapping.
            num_local_experts: Number of local experts.
@@ -302,7 +352,8 @@ def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
                    torch.distributed.isend,
                    weight[src],
                    dst_global,
-                ) for weight in expert_weights
+                )
+                for weight in expert_weights
            ]
        return p2p_ops

@@ -314,17 +365,22 @@ def prepare_send(self, expert_send_info, layer_id):
        setting up communication for a specific layer.

        Args:
-            expert_send_info: A list of tuples, where each tuple is
+            expert_send_info: A list of tuples, where each tuple is
                (destination_rank, global_expert_id_to_send).
            layer_id: The ID of the MoE layer for which experts are being sent.
        """
        for dst_rank, global_expert_id_to_send in expert_send_info:
-            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
-                layer_id][global_expert_id_to_send].item()
-            for src_tensor in self.eplb_adaptor.expert_param_per_layer[
-                layer_id][local_expert_id]:
+            local_expert_id = \
+                self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][
+                    global_expert_id_to_send
+                ].item()
+            for src_tensor in \
+                    self.eplb_adaptor.expert_param_per_layer[layer_id][
+                        local_expert_id
+                    ]:
                self.comm_op_list.append(
-                    dist.P2POp(dist.isend, src_tensor, dst_rank))
+                    dist.P2POp(dist.isend, src_tensor, dst_rank)
+                )

    @override
    def prepare_recv(self, expert_recv_info, updated_expert_map):
@@ -354,9 +410,9 @@ def prepare_recv(self, expert_recv_info, updated_expert_map):
            self.recv_expert_list.append(
                (local_expert_to_replace, buffer_tensor_id))

-    def generate_expert_d2d_transfer_task(self, expert_send_info,
-                                          expert_recv_info, updated_expert_map,
-                                          layer_id):
+    def generate_expert_d2d_transfer_task(
+        self, expert_send_info, expert_recv_info, updated_expert_map, layer_id
+    ):
        """
        Generates the expert data-to-data transfer tasks (send and
        receive operations) for a given layer based on the provided
@@ -407,28 +463,32 @@ def update_expert_map_and_weight(self):
                req.wait()

        # update expert_map
-        #解耦adaptor与loader
-        self.eplb_adaptor.do_update_expert_map(self.layer_id,
-                                               self.updated_expert_map)
+        # decouple adaptor and transfer
+        self.eplb_adaptor.do_update_expert_map(
+            self.layer_id,
+            self.updated_expert_map
+        )

        # update log2phy_map
-        self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
-                                                self.updated_log2phy_map)
+        self.eplb_adaptor.do_update_log2phy_map(
+            self.layer_id,
+            self.updated_log2phy_map
+        )

        # update expert weight
        buffer_tensor_id = 0
        for recv_expert_info in self.recv_expert_list:
            local_expert_to_replace, buffer_tensor_id = recv_expert_info
-            self.eplb_adaptor.do_update_expert_weight(self.layer_id,
-                                                      local_expert_to_replace,
-                                                      buffer_tensor_id)
+            self.eplb_adaptor.do_update_expert_weight(
+                self.layer_id, local_expert_to_replace, buffer_tensor_id
+            )

        self.clear_update_data()

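Taken together, the task-based path is a two-phase cycle per layer: generate_expert_d2d_transfer_task stages the P2P ops and metadata, and update_expert_map_and_weight blocks on the communication, commits the maps and weights, and resets state. A hypothetical driver loop over that API (the plan objects and their fields are illustrative assumptions, not part of this PR):

    # Hypothetical per-layer rebalance loop; `plans` and its fields are
    # illustrative only.
    transfer = EplbWeightTransfer(eplb_adaptor)
    for layer_id, plan in enumerate(plans):
        # Phase 1: stage isend/irecv tasks for this layer's expert moves.
        transfer.generate_expert_d2d_transfer_task(
            plan.expert_send_info,
            plan.expert_recv_info,
            plan.updated_expert_map,
            layer_id,
        )
        # Phase 2: wait on the batch, then commit expert_map, log2phy_map,
        # and weights, and clear per-cycle state.
        transfer.update_expert_map_and_weight()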
    def clear_update_data(self):
        """
        Clears the internal lists and temporary data used for the current
-        update cycle. This prepares the loader for the next rearrangement
+        update cycle. This prepares the transfer for the next rearrangement
        cycle.
        """
        if self.comm_op_list is not None: