|
43 | 43 |
|
44 | 44 | if current_platform.is_cuda_alike(): |
45 | 45 | from .fused_batched_moe import BatchedTritonExperts |
46 | | - from .fused_moe import TritonExperts, fused_experts |
| 46 | + from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record, |
| 47 | + fused_experts) |
47 | 48 | if has_pplx(): |
48 | 49 | from .pplx_prepare_finalize import (PplxPrepareAndFinalize, |
49 | 50 | pplx_hidden_dim_scale_bytes) |
|
55 | 56 | fused_experts = None # type: ignore |
56 | 57 | FusedMoEPermuteExpertsUnpermute = None # type: ignore |
57 | 58 | FusedMoEPrepareAndFinalize = None # type: ignore |
| 59 | + |
def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor, expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
        indices_type: Optional[torch.dtype]) -> torch.Tensor:
    """Fallback for platforms without EPLB support.

    No logical->physical replica mapping and no load recording is
    performed: `expert_load_view`, `logical_to_physical_map`, and
    `logical_replica_count` are accepted only for signature parity with
    the CUDA implementation and are intentionally unused.

    Args:
        topk_ids: Selected (logical) expert ids per token.
        expert_load_view: Unused on this path.
        logical_to_physical_map: Unused on this path.
        logical_replica_count: Unused on this path.
        indices_type: Optional dtype required by downstream
            dispatch/combine kernels (e.g. a 32-bit integer type).

    Returns:
        `topk_ids`, cast to `indices_type` when one is requested.
    """
    # The caller asserts `topk_ids.dtype == indices_type or
    # indices_type is None` after this call, and the CUDA path performs
    # this same conversion — so honor the requested dtype here too
    # instead of returning the tensor completely untouched.
    if indices_type is not None and topk_ids.dtype != indices_type:
        topk_ids = topk_ids.to(dtype=indices_type)
    return topk_ids
| 67 | + |
| 68 | + |
58 | 69 | if is_rocm_aiter_moe_enabled(): |
59 | 70 | from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 |
60 | 71 | rocm_aiter_grouped_topk as grouped_topk) |
@@ -1616,55 +1627,13 @@ def select_experts( |
1616 | 1627 | assert logical_to_physical_map is not None |
1617 | 1628 | assert logical_replica_count is not None |
1618 | 1629 |
|
1619 | | - # 1. Convert the logical expert ids to physical expert ids |
1620 | | - # Directly select a random replica for each logical expert |
1621 | | - |
1622 | | - # TODO: maybe optimize this by using specified kernels, |
1623 | | - # or compute pseudo-random indices by modulo |
1624 | | - |
1625 | | - # In case `indices_type` is not `torch.long` or `torch.int`, |
1626 | | - # e.g. `torch.uint32` as required by dispatch/combine kernels |
1627 | | - topk_ids_long = topk_ids.long() |
1628 | | - replica_indices = ( |
1629 | | - torch.rand_like(topk_ids, dtype=torch.float) * |
1630 | | - logical_replica_count[topk_ids_long]).long().unsqueeze(-1) |
1631 | | - physical_ids = logical_to_physical_map[topk_ids_long].gather( |
1632 | | - -1, replica_indices).squeeze(-1) |
1633 | | - |
1634 | | - topk_ids = physical_ids |
1635 | | - |
1636 | | - # 2. Record expert load metrics. |
1637 | | - |
1638 | | - # TODO(bowen): When using `FusedMoEModularKernel`, this |
1639 | | - # can be done in a more unified way, since |
1640 | | - # `FusedMoEPrepareAndFinalize` will return the expert |
1641 | | - # token count, in some cases directly from the kernel. |
1642 | | - # However, now there are many code paths not using |
1643 | | - # the modular kernel, e.g. calling `fused_experts`, |
1644 | | - # so we decide to keep the logic here. |
1645 | | - # |
1646 | | - # If later refactor moved all the MoE kernel calls |
1647 | | - # to the modular kernel, we can move this logic there |
1648 | | - # to achieve better efficiency. |
1649 | | - |
1650 | | - # `expert_load_view`: (num_physical_experts,) |
1651 | | - |
1652 | | - topk_ids_flatten = topk_ids.flatten() |
1653 | | - |
1654 | | - # Performance optimization: |
1655 | | - # `masked_fill` is significantly faster than `masked_select` |
1656 | | - invalid_mask = topk_ids_flatten < 0 |
1657 | | - # Replace invalid expert ids with 0 (just a dummy position) |
1658 | | - # to avoid out-of-bounds errors in scatter_add_ |
1659 | | - index = topk_ids_flatten.masked_fill_(invalid_mask, 0) |
1660 | | - # `src` is the valid mask, which is 1 for valid and 0 for invalid |
1661 | | - src = ~invalid_mask |
1662 | | - |
1663 | | - expert_load_view.scatter_add_(dim=0, |
1664 | | - index=index.long(), |
1665 | | - src=src.to(expert_load_view)) |
1666 | | - |
1667 | | - topk_ids = topk_ids.to(dtype=indices_type) |
| 1630 | + topk_ids = eplb_map_to_physical_and_record( |
| 1631 | + topk_ids=topk_ids, |
| 1632 | + expert_load_view=expert_load_view, |
| 1633 | + logical_to_physical_map=logical_to_physical_map, |
| 1634 | + logical_replica_count=logical_replica_count, |
| 1635 | + indices_type=indices_type, |
| 1636 | + ) |
1668 | 1637 |
|
1669 | 1638 | assert topk_ids.dtype == indices_type or indices_type is None |
1670 | 1639 |
|
|
0 commit comments