
Commit dfac092

Author: 白永斌 (committed)
fix pre-commit & rename 'weight loader' to 'weight transfer'
Signed-off-by: 白永斌 <[email protected]>
1 parent ba60019 commit dfac092

File tree

6 files changed, +142 -78 lines changed


tests/distributed/eplb_utils/test_eplb_utils.py

Lines changed: 10 additions & 7 deletions
@@ -6,8 +6,11 @@
 
 # Import the functions to test
 from vllm.distributed.eplb.eplb_utils.eplb_utils import (
-    determine_default_log2phy_map, generate_log2phy_map,
-    get_ep_ranks_with_expert, global_idx_to_rank, idx_global_to_local,
+    determine_default_log2phy_map,
+    generate_log2phy_map,
+    get_ep_ranks_with_expert,
+    global_idx_to_rank,
+    idx_global_to_local,
     idx_local_to_global)
 
 
@@ -64,10 +67,10 @@ def test_global_idx_to_rank_edge():
     # Global index = 0 (minimum valid value)
     assert global_idx_to_rank(global_idx=0, local_cnt=100) == 0
     # Global index = local_cnt*N - 1 (last index of rank N-1)
-    assert global_idx_to_rank(global_idx=99,
-                              local_cnt=100) == 0  # 0*100 ≤99 <1*100
-    assert global_idx_to_rank(global_idx=199,
-                              local_cnt=100) == 1  # 1*100 ≤199 <2*100
+    assert global_idx_to_rank(global_idx=99, local_cnt=100) == 0
+    # 0*100 ≤99 <1*100
+    assert global_idx_to_rank(global_idx=199, local_cnt=100) == 1
+    # 1*100 ≤199 <2*100
 
 
 # ----------------------- Test get_ep_ranks_with_expert -----------------------
@@ -214,4 +217,4 @@ def mock_choice(arr):
     # Expert_map_all for rank2: [-1,-1,-1,-1,0,1,2] (3 local experts)
     # After generate_log2phy_map: all ranks get full expert values
     expected = torch.tensor([0, 1, 3, 4, 6, 7, 8], dtype=torch.int32)
-    assert torch.equal(log2phy_map_rank2, expected)
+    assert torch.equal(log2phy_map_rank2, expected)
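
The edge cases asserted above come down to simple block arithmetic: rank r owns the contiguous range of global indices [r*local_cnt, (r+1)*local_cnt). The following is a minimal sketch consistent with those assertions; it is a hypothetical re-implementation written for illustration, not the code under test.

def global_idx_to_rank(global_idx: int, local_cnt: int) -> int:
    # Rank r owns the contiguous block [r * local_cnt, (r + 1) * local_cnt).
    return global_idx // local_cnt

def idx_global_to_local(global_idx: int, local_cnt: int) -> int:
    # Offset of a global expert index within its owning rank.
    return global_idx % local_cnt

def idx_local_to_global(local_idx: int, local_cnt: int, ep_rank: int) -> int:
    # Inverse direction: a local slot on a given EP rank back to a global index.
    return ep_rank * local_cnt + local_idx

assert global_idx_to_rank(global_idx=99, local_cnt=100) == 0   # 0*100 <= 99 < 1*100
assert global_idx_to_rank(global_idx=199, local_cnt=100) == 1  # 1*100 <= 199 < 2*100
assert idx_local_to_global(local_idx=5, local_cnt=100, ep_rank=1) == 105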

vllm/distributed/eplb/eplb_adaptor/vllm_adaptor.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 
 
 class VllmEplbAdaptor:
-
     def __init__(self, model, **args):
         self.model = model
         self.expert_map_per_layer = dict()

vllm/distributed/eplb/eplb_loader/abstract_loader.py renamed to vllm/distributed/eplb/eplb_transfer/abstract_transfer.py

Lines changed: 2 additions & 2 deletions
@@ -4,9 +4,9 @@
 from abc import ABC, abstractmethod
 
 
-class BaseLoader(ABC):
+class BaseTransfer(ABC):
     """
-    Abstract base class for a loader component responsible for managing
+    Abstract base class for a transfer component responsible for managing
     expert weights and their transfer/update mechanisms in an Expert Parallel
     (EP) system.
 
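Judging from the @override markers on the concrete methods further down in this commit, the abstract surface of BaseTransfer looks roughly like the sketch below. This is inferred from the diff for orientation only and is not the full contents of abstract_transfer.py.

from abc import ABC, abstractmethod

class BaseTransfer(ABC):
    """Sketch: owns expert-weight movement between EP ranks."""

    @abstractmethod
    def prepare_send(self, expert_send_info, layer_id):
        """Queue isend ops for experts this rank must ship out."""

    @abstractmethod
    def prepare_recv(self, expert_recv_info, updated_expert_map):
        """Queue irecv ops into buffer tensors for incoming experts."""

    @abstractmethod
    def send_recv(self, p2p_ops):
        """Launch the queued P2P ops and wait for completion."""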
vllm/distributed/eplb/eplb_loader/eplb_weight_loader.py renamed to vllm/distributed/eplb/eplb_transfer/eplb_weight_transfer.py

Lines changed: 117 additions & 57 deletions
@@ -12,23 +12,25 @@
 from torch.distributed import P2POp, Work, batch_isend_irecv, get_global_rank
 
 from vllm.distributed.eplb.eplb_adaptor.vllm_adaptor import VllmEplbAdaptor
-from vllm.distributed.eplb.eplb_loader.abstract_loader import BaseLoader
+from vllm.distributed.eplb.eplb_transfer.abstract_transfer import BaseTransfer
 from vllm.distributed.eplb.eplb_utils.eplb_utils import (
-    get_ep_ranks_with_expert, idx_local_to_global)
+    get_ep_ranks_with_expert,
+    idx_local_to_global,
+)
 
 
-class EplbWeightLoader(BaseLoader):
+class EplbWeightTransfer(BaseTransfer):
     """
-    A concrete implementation of BaseLoader for Expert Parallel Load
+    A concrete implementation of BaseTransfer for Expert Parallel Load
     Balancing (EPLB).
 
-    This class is responsible for managing the transfer and update of 
+    This class is responsible for managing the transfer and update of
     expert weights across different ranks during expert rearrangement.
     """
 
     def __init__(self, eplb_adaptor: VllmEplbAdaptor):
         """
-        Initializes the EplbWeightLoader.
+        Initializes the EplbWeightTransfer.
 
         Args:
             eplb_adaptor: An adaptor to interact with the vLLM model's expert
@@ -54,7 +56,7 @@ def shuffle_layer(
     ) -> None:
         """
         Performs expert weights rearrangement for a single MoE layer.
-        This method orchestrates the entire shuffling process including 
+        This method orchestrates the entire shuffling process including
         local copies and inter-rank P2P communications.
 
         Args:
@@ -71,7 +73,7 @@ def shuffle_layer(
             expert_weights_buffer: An iterable of torch.Tensor,
                                    representing buffers for receiving expert
                                    weights. Same structure as expert_weights.
-            ep_group: The PyTorch distributed process group for 
+            ep_group: The PyTorch distributed process group for
                       expert parallelism.
         """
 
@@ -95,39 +97,68 @@ def shuffle_layer(
                 dst_global = local2global(dst)
                 if is_received_locally[dst]:
                     continue
-                if old_indices[src_global] == -1 or new_indices[
-                        dst_global] == -1:
+                if old_indices[src_global] == -1 or \
+                        new_indices[dst_global] == -1:
                     continue
                 if old_indices[src_global] == new_indices[dst_global]:
                     is_received_locally[dst] = True
-                    for weight, buffer in zip(expert_weights,
-                                              expert_weights_buffer):
+                    for weight, buffer in \
+                            zip(expert_weights, expert_weights_buffer):
                         buffer[dst].copy_(weight[src])
 
         p2p_ops: list[P2POp] = []
 
         # 2. Initiate sending of weights.
-        p2p_ops = self.prepare_send_p2p_ops(ep_group, ep_rank, expert_weights,
-                                            local2global, new_indices,
-                                            num_local_experts, old_indices,
-                                            p2p_ops)
+        p2p_ops = self.prepare_send_p2p_ops(
+            ep_group,
+            ep_rank,
+            expert_weights,
+            local2global,
+            new_indices,
+            num_local_experts,
+            old_indices,
+            p2p_ops
+        )
 
         # 3. Initiate receiving of weights.
         experts_recv_loc, p2p_ops = self.prepare_recv_p2p_ops(
-            ep_group, ep_rank, expert_weights_buffer, is_received_locally,
-            local2global, new_indices, num_local_experts, old_indices, p2p_ops)
+            ep_group,
+            ep_rank,
+            expert_weights_buffer,
+            is_received_locally,
+            local2global,
+            new_indices,
+            num_local_experts,
+            old_indices,
+            p2p_ops
+        )
 
         # 4. Execute the P2P operations. The real communication happens here.
         self.send_recv(p2p_ops)
 
         # 5. Copy the weights from the buffer back to the original weights.
-        self.update_weight(expert_weights, expert_weights_buffer,
-                           experts_recv_loc, is_received_locally, is_unchanged,
-                           local2global, new_indices, num_local_experts)
+        self.update_weight(
+            expert_weights,
+            expert_weights_buffer,
+            experts_recv_loc,
+            is_received_locally,
+            is_unchanged,
+            local2global,
+            new_indices,
+            num_local_experts
+        )
 
-    def update_weight(self, expert_weights, expert_weights_buffer,
-                      experts_recv_loc, is_received_locally, is_unchanged,
-                      local2global, new_indices, num_local_experts):
+    def update_weight(
+        self,
+        expert_weights,
+        expert_weights_buffer,
+        experts_recv_loc,
+        is_received_locally,
+        is_unchanged,
+        local2global,
+        new_indices,
+        num_local_experts
+    ):
         """
         Updates the actual expert weights from the buffer after communication.
         This is part of the `shuffle_layer` process.
@@ -152,16 +183,16 @@ def update_weight(self, expert_weights, expert_weights_buffer,
             if is_unchanged[dst]:
                 continue
             if is_received_locally[dst]:
-                for weight, buffer in zip(expert_weights,
-                                          expert_weights_buffer):
+                for weight, buffer in \
+                        zip(expert_weights, expert_weights_buffer):
                     weight[dst].copy_(buffer[dst])
             else:
                 expert = new_indices[local2global(dst)]
                 if expert == -1:
                     continue
                 src = experts_recv_loc[expert]
-                for weight, buffer in zip(expert_weights,
-                                          expert_weights_buffer):
+                for weight, buffer in \
+                        zip(expert_weights, expert_weights_buffer):
                     weight[dst].copy_(buffer[src])
 
     @override
@@ -177,9 +208,18 @@ def send_recv(self, p2p_ops: list[P2POp]) -> None:
         for req in reqs:
             req.wait()
 
-    def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
-                             is_received_locally, local2global, new_indices,
-                             num_local_experts, old_indices, p2p_ops):
+    def prepare_recv_p2p_ops(
+        self,
+        ep_group,
+        ep_rank,
+        expert_weights_buffer,
+        is_received_locally,
+        local2global,
+        new_indices,
+        num_local_experts,
+        old_indices,
+        p2p_ops
+    ):
         """
         Prepares irecv operations for experts that need to be received
         from other ranks as part of the `shuffle_layer` process.
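
For readers unfamiliar with the batched point-to-point primitives used by send_recv and the prepare_* helpers above, the underlying torch.distributed pattern is the standard one below. It is a generic, self-contained illustration rather than code from this file, and it assumes a process group has already been initialized (for example via torchrun with two ranks).

import torch
import torch.distributed as dist

# Queue isend/irecv descriptors, launch them as one batch, then wait on the
# returned requests. This is the same shape as send_recv() in the diff above.
send_buf = torch.ones(4)
recv_buf = torch.empty(4)
peer = (dist.get_rank() + 1) % dist.get_world_size()

p2p_ops = [
    dist.P2POp(dist.isend, send_buf, peer),
    dist.P2POp(dist.irecv, recv_buf, peer),
]
reqs = dist.batch_isend_irecv(p2p_ops)
for req in reqs:
    req.wait()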
@@ -199,7 +239,7 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
 
         Returns:
             A tuple containing:
-            - experts_recv_loc: A dictionary mapping global expert ID to the 
+            - experts_recv_loc: A dictionary mapping global expert ID to the
               local buffer index where it will be received.
             - p2p_ops: The updated list of P2POp objects.
         """
@@ -225,7 +265,8 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
             # Calculate the rank to recv by this rank
             if not ranks_to_send:
                 raise ValueError(
-                    f"Expert {expert} is needed but no rank has it.")
+                    f"Expert {expert} is needed but no rank has it."
+                )
             num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
             recver_pos = ranks_to_recv.index(ep_rank)
             remainder_start = len(ranks_to_send) * num_dst_per_sender
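
The three quantities computed at the end of this hunk split the receiving ranks evenly across the sending ranks, with any remainder handled separately. A small standalone illustration of that arithmetic follows; the helper itself is hypothetical, written only to mirror the names in the diff.

def split_receivers(ranks_to_send: list[int], ranks_to_recv: list[int]):
    # Example: 2 senders and 5 receivers gives num_dst_per_sender = 5 // 2 = 2
    # and remainder_start = 2 * 2 = 4, so receivers at positions 0-3 are split
    # evenly between the senders and the receiver at position 4 is a leftover.
    num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
    remainder_start = len(ranks_to_send) * num_dst_per_sender
    evenly_assigned = ranks_to_recv[:remainder_start]
    leftovers = ranks_to_recv[remainder_start:]
    return num_dst_per_sender, evenly_assigned, leftovers

assert split_receivers([0, 1], [2, 3, 4, 5, 6]) == (2, [2, 3, 4, 5], [6])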
@@ -240,13 +281,22 @@ def prepare_recv_p2p_ops(self, ep_group, ep_rank, expert_weights_buffer,
                     torch.distributed.irecv,
                     weight[dst],
                     src_global,
-                ) for weight in expert_weights_buffer
+                )
+                for weight in expert_weights_buffer
             ]
         return experts_recv_loc, p2p_ops
 
-    def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
-                             local2global, new_indices, num_local_experts,
-                             old_indices, p2p_ops):
+    def prepare_send_p2p_ops(
+        self,
+        ep_group,
+        ep_rank,
+        expert_weights,
+        local2global,
+        new_indices,
+        num_local_experts,
+        old_indices,
+        p2p_ops
+    ):
         """
         Prepares isend operations for experts that need to be sent
         to other ranks as part of the `shuffle_layer` process.
@@ -255,7 +305,7 @@ def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
             ep_group: The PyTorch distributed process group.
             ep_rank: Current rank in the EP group.
             expert_weights: The actual expert weight tensors to be sent.
-            local2global: Partial function for local 
+            local2global: Partial function for local
                           to global index conversion.
             new_indices: The new global expert indices mapping.
             num_local_experts: Number of local experts.
@@ -302,7 +352,8 @@ def prepare_send_p2p_ops(self, ep_group, ep_rank, expert_weights,
                     torch.distributed.isend,
                     weight[src],
                     dst_global,
-                ) for weight in expert_weights
+                )
+                for weight in expert_weights
             ]
         return p2p_ops
 
@@ -314,17 +365,22 @@ def prepare_send(self, expert_send_info, layer_id):
         setting up communication for a specific layer.
 
         Args:
-            expert_send_info: A list of tuples, where each tuple is 
+            expert_send_info: A list of tuples, where each tuple is
                               (destination_rank, global_expert_id_to_send).
             layer_id: The ID of the MoE layer for which experts are being sent.
         """
         for dst_rank, global_expert_id_to_send in expert_send_info:
-            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
-                layer_id][global_expert_id_to_send].item()
-            for src_tensor in self.eplb_adaptor.expert_param_per_layer[
-                    layer_id][local_expert_id]:
+            local_expert_id = \
+                self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][
+                    global_expert_id_to_send
+                ].item()
+            for src_tensor in \
+                    self.eplb_adaptor.expert_param_per_layer[layer_id][
+                        local_expert_id
+                    ]:
                 self.comm_op_list.append(
-                    dist.P2POp(dist.isend, src_tensor, dst_rank))
+                    dist.P2POp(dist.isend, src_tensor, dst_rank)
+                )
 
     @override
     def prepare_recv(self, expert_recv_info, updated_expert_map):
@@ -354,9 +410,9 @@ def prepare_recv(self, expert_recv_info, updated_expert_map):
             self.recv_expert_list.append(
                 (local_expert_to_replace, buffer_tensor_id))
 
-    def generate_expert_d2d_transfer_task(self, expert_send_info,
-                                          expert_recv_info, updated_expert_map,
-                                          layer_id):
+    def generate_expert_d2d_transfer_task(
+        self, expert_send_info, expert_recv_info, updated_expert_map, layer_id
+    ):
         """
         Generates the expert data-to-data transfer tasks (send and
         receive operations) for a given layer based on the provided
@@ -407,28 +463,32 @@ def update_expert_map_and_weight(self):
                 req.wait()
 
         # update expert_map
-        #解耦adaptor与loader
-        self.eplb_adaptor.do_update_expert_map(self.layer_id,
-                                               self.updated_expert_map)
+        # decouple adaptor and transfer
+        self.eplb_adaptor.do_update_expert_map(
+            self.layer_id,
+            self.updated_expert_map
+        )
 
         # update log2phy_map
-        self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
-                                                self.updated_log2phy_map)
+        self.eplb_adaptor.do_update_log2phy_map(
+            self.layer_id,
+            self.updated_log2phy_map
+        )
 
         # update expert weight
         buffer_tensor_id = 0
         for recv_expert_info in self.recv_expert_list:
             local_expert_to_replace, buffer_tensor_id = recv_expert_info
-            self.eplb_adaptor.do_update_expert_weight(self.layer_id,
-                                                      local_expert_to_replace,
-                                                      buffer_tensor_id)
+            self.eplb_adaptor.do_update_expert_weight(
+                self.layer_id, local_expert_to_replace, buffer_tensor_id
+            )
 
         self.clear_update_data()
 
     def clear_update_data(self):
         """
         Clears the internal lists and temporary data used for the current
-        update cycle. This prepares the loader for the next rearrangement
+        update cycle. This prepares the transfer for the next rearrangement
         cycle.
         """
         if self.comm_op_list is not None:
