
Commit eb92af3

Add test for internode latency.
1 parent 89e33c0 commit eb92af3

5 files changed

Lines changed: 331 additions & 202 deletions


tests/alltoall.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union

from deep_ep import Buffer, EventOverlap

# Communication buffer (allocated at runtime)
_buffer: Optional[Buffer] = None

# Set the number of SMs to use
# NOTES: this is a static variable
# Buffer.set_num_sms(24)


# You may call this function at framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer

    # NOTES: you may also replace `get_*_config` with auto-tuned results from the tests
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)

    # Allocate a buffer if none exists yet or the existing one is too small
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


def get_hidden_bytes(x: torch.Tensor) -> int:
    t = x[0] if isinstance(x, tuple) else x
    return t.size(1) * max(t.element_size(), 2)


def dispatch_forward(
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
    # NOTES: the optional `previous_event` is a captured CUDA event that the dispatch kernel
    # should wait on; it is useful for communication-computation overlap. For more information,
    # please refer to the docs of `Buffer.dispatch`
    global _buffer

    # Calculate layout before actual dispatch
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = _buffer.get_dispatch_layout(
        topk_idx,
        num_experts,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=allocate_on_comm_stream
    )

    # Do MoE dispatch
    # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graphs
    # For more advanced usage, please refer to the docs of the `dispatch` function
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = _buffer.dispatch(
        x,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        is_token_in_rank=is_token_in_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=allocate_on_comm_stream
    )

    # For event management, please refer to the docs of the `EventOverlap` class
    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event


def dispatch_backward(
        grad_recv_x: torch.Tensor,
        grad_recv_topk_weights: torch.Tensor,
        handle: Tuple,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
    global _buffer

    # The backward pass of MoE dispatch is actually a combine
    # For more advanced usage, please refer to the docs of the `combine` function
    combined_grad_x, combined_grad_recv_topk_weights, event = _buffer.combine(
        grad_recv_x,
        handle,
        topk_weights=grad_recv_topk_weights,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=allocate_on_comm_stream
    )

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_grad_x, combined_grad_recv_topk_weights, event


def combine_forward(
        x: torch.Tensor,
        handle: Tuple,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False
) -> Tuple[torch.Tensor, EventOverlap]:
    global _buffer

    # Do MoE combine
    # For more advanced usage, please refer to the docs of the `combine` function
    combined_x, _, event = _buffer.combine(
        x,
        handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=allocate_on_comm_stream)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_x, event


def combine_backward(
        grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        handle: Tuple,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
    global _buffer

    # The backward pass of MoE combine is actually a dispatch
    # For more advanced usage, please refer to the docs of the `dispatch` function
    grad_x, _, _, _, _, event = _buffer.dispatch(
        grad_combined_x,
        handle=handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=allocate_on_comm_stream
    )

    # For event management, please refer to the docs of the `EventOverlap` class
    return grad_x, event
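
Note on usage: the helpers above follow DeepEP's intended flow of dispatch -> expert compute -> combine. The sketch below is a hypothetical caller-side example, not part of this commit; `group`, `x`, `topk_idx`, `topk_weights`, `num_experts`, and `run_experts` are assumed to be provided by the surrounding MoE layer.

import alltoall

def moe_forward(group, x, topk_idx, topk_weights, num_experts, run_experts):
    # Lazily (re)allocate the communication buffer for the current hidden size
    alltoall.get_buffer(group, alltoall.get_hidden_bytes(x))

    # Scatter tokens to the ranks that own their selected experts
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, _ = \
        alltoall.dispatch_forward(x, topk_idx, topk_weights, num_experts)

    # Run the local experts on the received tokens (user-provided callable)
    expert_out = run_experts(recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list)

    # Return the expert outputs to the original ranks and token order
    combined_x, _ = alltoall.combine_forward(expert_out, handle)
    return combined_x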

tests/run_test_internode.sh

Lines changed: 26 additions & 19 deletions
@@ -4,29 +4,37 @@ WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun
 export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/torch_py310_yiqun
 export PATH=${PYTHONPATH}/bin:${PATH}
 
-python -c "import torch; print(torch.__version__)"
-
 export PYTHONPATH=${WORK_ROOT}/PaPerf:$PYTHONPATH
 
 #export NVSHMEM_DIR=$ROOT_DIR/third-party/nvshmem
 #export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
 
-export MASTER_ADDR=10.54.95.204
-export MASTER_PORT=8367
-export WORLD_SIZE=8
-
-START_NODE=0
-END_NODE=$((${START_NODE} + ${WORLD_SIZE}))
-export RANK=$(($PADDLE_TRAINER_ID - ${START_NODE}))
-
-if [ ${PADDLE_TRAINER_ID} -lt ${START_NODE} ]; then
-  echo "$PADDLE_TRAINER_ID exit"
-  exit
-elif [ ${PADDLE_TRAINER_ID} -ge ${END_NODE} ]; then
-  echo "$PADDLE_TRAINER_ID exit"
-  exit
+START_RANK=46
+END_RANK=54
+
+if [[ ${PADDLE_TRAINER_ID} -lt $START_RANK ]]; then
+  exit 0
+fi
+
+if [[ ${PADDLE_TRAINER_ID} -ge $END_RANK ]]; then
+  exit 0
 fi
 
+rank=$(($PADDLE_TRAINER_ID - $START_RANK))
+nnodes=$(($END_RANK - $START_RANK))
+echo "rank: ${rank}, nnodes: ${nnodes}"
+
+python -c "import torch; print(torch.__version__)"
+
+#master=`cat /root/paddlejob/workspace/hostfile | head -n 1 | awk '{print $1}'`
+export MASTER_ADDR="10.95.238.87" # 46
+#master="10.95.238.99" # 48
+#master="10.95.237.154" # 32
+#master="10.95.244.212" # 8
+export MASTER_PORT=8367
+export WORLD_SIZE=$nnodes
+export RANK=$rank
+
 export NCCL_DEBUG=WARN
 #export NVSHMEM_DEBUG=DEBUG
 #export NVSHMEM_DEBUG=TRACE
@@ -46,7 +54,7 @@ export NVSHMEM_IB_TRAFFIC_CLASS=162
 #export NVSHMEM_IB_ENABLE_IBGDA=true
 #export NVSHMEM_DISABLE_P2P=1
 export NVSHMEM_BOOTSTRAP=UID
-export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME==xgbe0
+export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=xgbe0
 #export NVSHMEM_BOOTSTRAP_UID_SOCK_FAMILY=AF_INET
 
 #export NVSHMEM_DEBUG=INFO
@@ -57,5 +65,4 @@ export PATH=/opt/nvidia/nsight-systems/2025.1.1/bin:$PATH
 
 rm -rf core.*
 
-${nsys_args} python test_internode.py
-#${nsys_args} python test_simple.py
+${nsys_args} python test_internode_latency.py
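
Note on the launch setup: this script selects nodes with PADDLE_TRAINER_ID in [START_RANK, END_RANK), derives the node rank and node count from it, and exports MASTER_ADDR, MASTER_PORT, WORLD_SIZE (number of nodes), and RANK (node rank) for rendezvous. The `utils.init_dist` used by the test is not part of this commit; the sketch below only illustrates, under that assumption, how such an initializer typically consumes these variables.

# Hedged sketch only; the real utils.init_dist may differ.
import os
import torch
import torch.distributed as dist

def init_dist(local_rank: int, num_local_ranks: int):
    node_rank = int(os.getenv('RANK', 0))        # node rank exported by the script
    num_nodes = int(os.getenv('WORLD_SIZE', 1))  # number of nodes, not processes
    dist.init_process_group(
        backend='nccl',
        init_method=f"tcp://{os.getenv('MASTER_ADDR')}:{os.getenv('MASTER_PORT')}",
        rank=node_rank * num_local_ranks + local_rank,
        world_size=num_nodes * num_local_ranks,
    )
    torch.cuda.set_device(local_rank)
    group = dist.new_group(list(range(num_nodes * num_local_ranks)))
    return dist.get_rank(), dist.get_world_size(), group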

tests/test_internode_latency.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
import os
import sys
import time
import numpy as np

import torch
import torch.distributed as dist

# noinspection PyUnresolvedReferences
import alltoall
import utils
from utils import init_dist, create_grouped_scores

try:
    from paperf import profile_torch
    has_paperf = True
except ImportError:
    has_paperf = False


def print_tensor_info(t, name):
    #print(f"-- {name}: data_ptr={t.untyped_storage().data_ptr()}, shape={t.size()}, dtype={t.dtype}")
    print(f"-- {name}: shape={t.size()}, dtype={t.dtype}")


def test_main(local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, group: dist.ProcessGroup, use_random_input, dump_input):
    # Settings
    num_tokens = 4096
    hidden = 7168
    num_topk_groups = min(num_nodes, 4)
    num_topk = 8
    num_experts = (256 // num_ranks) * num_ranks

    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)

    if use_random_input:
        # Random data
        x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
        x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        #x_e4m3 = per_token_cast_to_fp8(x)

        scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
        group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
        group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
        masked_scores = create_grouped_scores(scores, group_idx, num_nodes)

        topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
        topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
        topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')

        if dump_input:
            utils.dump(x, 'x', local_rank)
            utils.dump(x_pure_rand, 'x_pure_rand', local_rank)
            #utils.dump(x_e4m3, 'x_e4m3', local_rank)

            utils.dump(topk_idx, 'topk_idx', local_rank)
            utils.dump(topk_weights, 'topk_weights', local_rank)
            utils.dump(topk_weights_pure_rand, 'topk_weights_pure_rand', local_rank)
    else:
        x = utils.load("x", local_rank)
        x_pure_rand = utils.load("x_pure_rand", local_rank)
        #x_e4m3 = utils.load("x_e4m3", local_rank, "tuple")

        topk_idx = utils.load("topk_idx", local_rank)
        topk_weights = utils.load("topk_weights", local_rank)
        topk_weights_pure_rand = utils.load("topk_weights_pure_rand", local_rank)

    profile = False
    profile = profile and has_paperf

    # Test bfloat16
    buffer = alltoall.get_buffer(group, alltoall.get_hidden_bytes(x))

    if profile:
        profile_torch.switch_profile(0, 0, 1)

    num_warmups = 100
    num_tests = 1000

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]

    for i in range(num_warmups + num_tests):
        if i == num_warmups:
            group.barrier()
            torch.cuda.synchronize()
            cpu_start = time.time()

        if i >= num_warmups:
            # Record the start of a measured iteration
            batch_start = time.time()
            start_events[i - num_warmups].record()

        #group.barrier()
        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, dispatch_event = alltoall.dispatch_forward(
            x=x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_experts=num_experts,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False
        )

        combined_x, event = alltoall.combine_forward(
            x=recv_x,
            handle=handle,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False
        )

        if i >= num_warmups:
            # Only measured iterations record an end event and per-iteration CPU time
            end_events[i - num_warmups].record()
            batch_time = time.time() - batch_start
            if local_rank == 0:
                print(f"-- {i - num_warmups}-th running, cpu_time: {batch_time:.5f} s")

    torch.cuda.synchronize()
    group.barrier()

    cpu_runtime = time.time() - cpu_start
    avg_cpu_time = cpu_runtime / num_tests

    gpu_times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
    avg_gpu_time = np.average(gpu_times)

    print(f"-- rank: {rank}, avg_cpu_time: {avg_cpu_time:.5f} s, avg_gpu_time: {avg_gpu_time:.5f} s")

    torch.cuda.synchronize()
    group.barrier()

    avg_cpu_time_all_ranks = [None, ] * num_ranks
    avg_gpu_time_all_ranks = [None, ] * num_ranks
    dist.all_gather_object(avg_cpu_time_all_ranks, avg_cpu_time, group=group)
    dist.all_gather_object(avg_gpu_time_all_ranks, avg_gpu_time, group=group)
    if rank == 0:
        avg_cpu_time = np.average(np.array(avg_cpu_time_all_ranks))
        avg_gpu_time = np.average(np.array(avg_gpu_time_all_ranks))
        print(f"-- avg_cpu_time_of_all_ranks: {avg_cpu_time:.5f} s, avg_gpu_time_of_all_ranks: {avg_gpu_time:.5f} s")


def test_loop(local_rank: int, num_local_ranks: int):
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)

    assert num_local_ranks == 8 and num_ranks > 8
    torch.manual_seed(rank)

    use_random_input = True
    dump_input = False

    test_main(local_rank, num_local_ranks, num_ranks, num_nodes, rank, group, use_random_input, dump_input)


if __name__ == '__main__':
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
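
Note on the measurement: each iteration times one dispatch_forward + combine_forward round trip twice, once with host wall-clock time (meaningful here because async_finish=False makes the calls blocking) and once with paired CUDA events, and the per-rank averages are then gathered with all_gather_object. Averages can hide tail latency; the snippet below is an optional post-processing sketch, not part of this commit, that summarizes the same `gpu_times` array (in seconds) with percentiles.

import numpy as np

def summarize_latency(gpu_times: np.ndarray) -> dict:
    # Mean, median, tail, and worst-case latency of the measured iterations
    return {
        'mean_s': float(np.mean(gpu_times)),
        'p50_s': float(np.percentile(gpu_times, 50)),
        'p99_s': float(np.percentile(gpu_times, 99)),
        'max_s': float(np.max(gpu_times)),
    }

# Example: on each rank, print(summarize_latency(gpu_times)) after the timing loop.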
