|
| 1 | +import torch |
| 2 | +import torch.distributed as dist |
| 3 | +from typing import List, Tuple, Optional, Union |
| 4 | + |
| 5 | +from deep_ep import Buffer, EventOverlap |
| 6 | + |
| 7 | +# Communication buffer (will allocate at runtime) |
| 8 | +_buffer: Optional[Buffer] = None |
| 9 | + |
| 10 | +# Set the number of SMs to use |
| 11 | +# NOTES: this is a static variable |
| 12 | +# Buffer.set_num_sms(24) |
| 13 | + |
| 14 | + |
| 15 | +# You may call this function at the framework initialization |
| 16 | +def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer: |
| 17 | + global _buffer |
| 18 | + |
| 19 | + # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests |
| 20 | + num_nvl_bytes, num_rdma_bytes = 0, 0 |
| 21 | + for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())): |
| 22 | + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) |
| 23 | + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) |
| 24 | + |
| 25 | + # Allocate a buffer if not existed or not enough buffer size |
| 26 | + if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes: |
| 27 | + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) |
| 28 | + return _buffer |
| 29 | + |
| 30 | + |
| 31 | +def get_hidden_bytes(x: torch.Tensor) -> int: |
| 32 | + t = x[0] if isinstance(x, tuple) else x |
| 33 | + return t.size(1) * max(t.element_size(), 2) |
| 34 | + |
| 35 | + |
| 36 | +def dispatch_forward( |
| 37 | + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], |
| 38 | + topk_idx: torch.Tensor, |
| 39 | + topk_weights: torch.Tensor, |
| 40 | + num_experts: int, |
| 41 | + previous_event: Optional[EventOverlap] = None, |
| 42 | + async_finish: bool = False, |
| 43 | + allocate_on_comm_stream: bool = False |
| 44 | +) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]: |
| 45 | + # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency |
| 46 | + # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please |
| 47 | + # refer to the docs of `Buffer.dispatch` |
| 48 | + global _buffer |
| 49 | + |
| 50 | + # Calculate layout before actual dispatch |
| 51 | + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = _buffer.get_dispatch_layout( |
| 52 | + topk_idx, |
| 53 | + num_experts, |
| 54 | + previous_event=previous_event, |
| 55 | + async_finish=async_finish, |
| 56 | + allocate_on_comm_stream=allocate_on_comm_stream |
| 57 | + ) |
| 58 | + |
| 59 | + # Do MoE dispatch |
| 60 | + # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph |
| 61 | + # For more advanced usages, please refer to the docs of the `dispatch` function |
| 62 | + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = _buffer.dispatch( |
| 63 | + x, |
| 64 | + topk_idx=topk_idx, |
| 65 | + topk_weights=topk_weights, |
| 66 | + num_tokens_per_rank=num_tokens_per_rank, |
| 67 | + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, |
| 68 | + is_token_in_rank=is_token_in_rank, |
| 69 | + num_tokens_per_expert=num_tokens_per_expert, |
| 70 | + previous_event=previous_event, |
| 71 | + async_finish=async_finish, |
| 72 | + allocate_on_comm_stream=allocate_on_comm_stream |
| 73 | + ) |
| 74 | + |
| 75 | + # For event management, please refer to the docs of the `EventOverlap` class |
| 76 | + return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event |
| 77 | + |
| 78 | + |
| 79 | +def dispatch_backward( |
| 80 | + grad_recv_x: torch.Tensor, |
| 81 | + grad_recv_topk_weights: torch.Tensor, |
| 82 | + handle: Tuple, |
| 83 | + previous_event: Optional[EventOverlap] = None, |
| 84 | + async_finish: bool = False, |
| 85 | + allocate_on_comm_stream: bool = False |
| 86 | +) -> Tuple[torch.Tensor, torch.Tensor, EventOverlap]: |
| 87 | + global _buffer |
| 88 | + |
| 89 | + # The backward process of MoE dispatch is actually a combine |
| 90 | + # For more advanced usages, please refer to the docs of the `combine` function |
| 91 | + combined_grad_x, combined_grad_recv_topk_weights, event = _buffer.combine( |
| 92 | + grad_recv_x, |
| 93 | + handle, |
| 94 | + topk_weights=grad_recv_topk_weights, |
| 95 | + previous_event=previous_event, |
| 96 | + async_finish=async_finish, |
| 97 | + allocate_on_comm_stream=allocate_on_comm_stream |
| 98 | + ) |
| 99 | + |
| 100 | + # For event management, please refer to the docs of the `EventOverlap` class |
| 101 | + return combined_grad_x, combined_grad_recv_topk_weights, event |
| 102 | + |
| 103 | + |
| 104 | +def combine_forward( |
| 105 | + x: torch.Tensor, |
| 106 | + handle: Tuple, |
| 107 | + previous_event: Optional[EventOverlap] = None, |
| 108 | + async_finish: bool = False, |
| 109 | + allocate_on_comm_stream: bool = False |
| 110 | +) -> Tuple[torch.Tensor, EventOverlap]: |
| 111 | + global _buffer |
| 112 | + |
| 113 | + # Do MoE combine |
| 114 | + # For more advanced usages, please refer to the docs of the `combine` function |
| 115 | + combined_x, _, event = _buffer.combine( |
| 116 | + x, |
| 117 | + handle, |
| 118 | + async_finish=async_finish, |
| 119 | + previous_event=previous_event, |
| 120 | + allocate_on_comm_stream=allocate_on_comm_stream) |
| 121 | + |
| 122 | + # For event management, please refer to the docs of the `EventOverlap` class |
| 123 | + return combined_x, event |
| 124 | + |
| 125 | + |
| 126 | +def combine_backward( |
| 127 | + grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], |
| 128 | + handle: Tuple, |
| 129 | + previous_event: Optional[EventOverlap] = None, |
| 130 | + async_finish: bool = False, |
| 131 | + allocate_on_comm_stream: bool = False |
| 132 | +) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]: |
| 133 | + global _buffer |
| 134 | + |
| 135 | + # The backward process of MoE combine is actually a dispatch |
| 136 | + # For more advanced usages, please refer to the docs of the `dispatch` function |
| 137 | + grad_x, _, _, _, _, event = _buffer.dispatch( |
| 138 | + grad_combined_x, |
| 139 | + handle=handle, |
| 140 | + async_finish=async_finish, |
| 141 | + previous_event=previous_event, |
| 142 | + allocate_on_comm_stream=allocate_on_comm_stream |
| 143 | + ) |
| 144 | + |
| 145 | + # For event management, please refer to the docs of the `EventOverlap` class |
| 146 | + return grad_x, event |
0 commit comments