diff --git a/tests/alltoall.py b/tests/alltoall.py
new file mode 100644
index 00000000..d4847336
--- /dev/null
+++ b/tests/alltoall.py
@@ -0,0 +1,146 @@
+import torch
+import torch.distributed as dist
+from typing import List, Tuple, Optional, Union
+
+from deep_ep import Buffer, EventOverlap
+
+# Communication buffer (will allocate at runtime)
+_buffer: Optional[Buffer] = None
+
+# Set the number of SMs to use
+# NOTES: this is a static variable
+# Buffer.set_num_sms(24)
+
+
+# You may call this function at the framework initialization
+def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
+    global _buffer
+
+    # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
+    num_nvl_bytes, num_rdma_bytes = 0, 0
+    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
+        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
+        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
+
+    # Allocate a buffer if one does not exist yet, or if the existing one is too small
+    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
+        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
+    return _buffer
+
+
+def get_hidden_bytes(x: torch.Tensor) -> int:
+    t = x[0] if isinstance(x, tuple) else x
+    return t.size(1) * max(t.element_size(), 2)
+
+
+def dispatch_forward(
+    x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+    topk_idx: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_experts: int,
+    previous_event: Optional[EventOverlap] = None,
+    async_finish: bool = False,
+    allocate_on_comm_stream: bool = False
+) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
+    # NOTES: an optional `previous_event` is a captured CUDA event that the dispatch kernel should
+    # wait on as a dependency; this can be useful for communication-computation overlap.
+    # For more information, please refer to the docs of `Buffer.dispatch`
+    global _buffer
+
+    # Calculate layout before actual dispatch
+    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = _buffer.get_dispatch_layout(
+        topk_idx,
+        num_experts,
+        previous_event=previous_event,
+        async_finish=async_finish,
+        allocate_on_comm_stream=allocate_on_comm_stream
+    )
+
+    # Do MoE dispatch
+    # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
+    # For more advanced usages, please refer to the docs of the `dispatch` function
+    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = _buffer.dispatch(
+        x,
+        topk_idx=topk_idx,
+        topk_weights=topk_weights,
+        num_tokens_per_rank=num_tokens_per_rank,
+        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+        is_token_in_rank=is_token_in_rank,
+        num_tokens_per_expert=num_tokens_per_expert,
+        previous_event=previous_event,
+        async_finish=async_finish,
+        allocate_on_comm_stream=allocate_on_comm_stream
+    )
+
+    # For event management, please refer to the docs of the `EventOverlap` class
+    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
+
+
+def dispatch_backward(
+    grad_recv_x: torch.Tensor,
+    grad_recv_topk_weights: torch.Tensor,
+    handle: Tuple,
+    previous_event: Optional[EventOverlap] = None,
+    async_finish: bool = False,
+    allocate_on_comm_stream: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
+    global _buffer
+
+    # The backward process of MoE dispatch is actually a combine
+    # For more advanced usages, please refer to the docs of the `combine` function
+    combined_grad_x, combined_grad_recv_topk_weights, event = _buffer.combine(
+        grad_recv_x,
+        handle,
+        topk_weights=grad_recv_topk_weights,
+        previous_event=previous_event,
+        async_finish=async_finish,
+        allocate_on_comm_stream=allocate_on_comm_stream
+    )
+
+    # For event management, please refer to the docs of the `EventOverlap` class
+    return combined_grad_x, combined_grad_recv_topk_weights, event
+
+
+def combine_forward(
+    x: torch.Tensor,
+    handle: Tuple,
+    previous_event: Optional[EventOverlap] = None,
+    async_finish: bool = False,
+    allocate_on_comm_stream: bool = False
+) -> Tuple[torch.Tensor, EventOverlap]:
+    global _buffer
+
+    # Do MoE combine
+    # For more advanced usages, please refer to the docs of the `combine` function
+    combined_x, _, event = _buffer.combine(
+        x,
+        handle,
+        async_finish=async_finish,
+        previous_event=previous_event,
+        allocate_on_comm_stream=allocate_on_comm_stream)
+
+    # For event management, please refer to the docs of the `EventOverlap` class
+    return combined_x, event
+
+
+def combine_backward(
+    grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+    handle: Tuple,
+    previous_event: Optional[EventOverlap] = None,
+    async_finish: bool = False,
+    allocate_on_comm_stream: bool = False
+) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
+    global _buffer
+
+    # The backward process of MoE combine is actually a dispatch
+    # For more advanced usages, please refer to the docs of the `dispatch` function
+    grad_x, _, _, _, _, event = _buffer.dispatch(
+        grad_combined_x,
+        handle=handle,
+        async_finish=async_finish,
+        previous_event=previous_event,
+        allocate_on_comm_stream=allocate_on_comm_stream
+    )
+
+    # For event management, please refer to the docs of the `EventOverlap` class
+    return grad_x, event
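A minimal usage sketch of the helpers above, mirroring how `tests/test_internode_latency.py` (added later in this patch) drives them; `group`, `x`, `topk_idx`, `topk_weights`, and `num_experts` are assumed to exist exactly as in that test:

import alltoall

# Allocate (or reuse) the module-level communication buffer, sized from the hidden dimension of `x`
buffer = alltoall.get_buffer(group, alltoall.get_hidden_bytes(x))

# Dispatch: scatter each token to the ranks that own its top-k experts
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, _ = \
    alltoall.dispatch_forward(x, topk_idx, topk_weights, num_experts)

# ... run the local experts on `recv_x` here ...

# Combine: gather the expert outputs back to the original token order, reusing the dispatch handle
combined_x, _ = alltoall.combine_forward(recv_x, handle)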
diff --git a/tests/run_test_internode.sh b/tests/run_test_internode.sh
new file mode 100644
index 00000000..09fc7ccf
--- /dev/null
+++ b/tests/run_test_internode.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun
+export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/torch_py310_yiqun
+export PATH=${PYTHONPATH}/bin:${PATH}
+
+export PYTHONPATH=${WORK_ROOT}/PaPerf:$PYTHONPATH
+
+#export NVSHMEM_DIR=$ROOT_DIR/third-party/nvshmem
+#export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
+
+START_RANK=46
+END_RANK=54
+
+if [[ ${PADDLE_TRAINER_ID} -lt $START_RANK ]]; then
+    exit 0
+fi
+
+if [[ ${PADDLE_TRAINER_ID} -ge $END_RANK ]]; then
+    exit 0
+fi
+
+rank=$(($PADDLE_TRAINER_ID - $START_RANK))
+nnodes=$(($END_RANK - $START_RANK))
+echo "rank: ${rank}, nnodes: ${nnodes}"
+
+python -c "import torch; print(torch.__version__)"
+
+#master=`cat /root/paddlejob/workspace/hostfile | head -n 1 | awk '{print $1}'`
+export MASTER_ADDR="10.95.238.87" # 46
+#master="10.95.238.99" # 48
+#master="10.95.237.154" # 32
+#master="10.95.244.212" # 8
+export MASTER_PORT=8367
+export WORLD_SIZE=$nnodes
+export RANK=$rank
+
+export NCCL_DEBUG=WARN
+#export NVSHMEM_DEBUG=DEBUG
+#export NVSHMEM_DEBUG=TRACE
+
+# Settings to keep the cluster stable; unrelated to performance
+export NCCL_IB_QPS_PER_CONNECTION=8
+#export NCCL_IB_TIMEOUT=22
+export NCCL_IB_GID_INDEX=3
+export NCCL_NVLS_ENABLE=0
+# Enable adaptive routing (AR)
+#export NCCL_IB_ADAPTIVE_ROUTING=1
+
+export NCCL_IB_GID_INDEX=3
+export NVSHMEM_IB_GID_INDEX=3
+export NVSHMEM_IB_TRAFFIC_CLASS=162
+
+#export NVSHMEM_IB_ENABLE_IBGDA=true
+#export NVSHMEM_DISABLE_P2P=1
+export NVSHMEM_BOOTSTRAP=UID
+export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=xgbe0
+#export NVSHMEM_BOOTSTRAP_UID_SOCK_FAMILY=AF_INET
+
+#export NVSHMEM_DEBUG=INFO
+
+export PATH=/opt/nvidia/nsight-systems/2025.1.1/bin:$PATH
+#nsys_args="nsys profile --stats true -w true -t cuda,nvtx,cudnn,cublas --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_simple_kernel_${WORLD_SIZE}.torch"
+#nsys_args="nsys profile --stats true -w true -t cuda,nvtx --nic-metrics=true --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_internode_${WORLD_SIZE}nodes_rank${RANK}.torch"
+
+rm -rf core.*
+
+${nsys_args} python test_internode_latency.py
diff --git a/tests/run_test_intranode.sh b/tests/run_test_intranode.sh
new file mode 100644
index 00000000..b373201b
--- /dev/null
+++ b/tests/run_test_intranode.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun
+export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/torch_py310_yiqun
+export PATH=${PYTHONPATH}/bin:${PATH}
+
+export MASTER_ADDR=10.54.98.83
+export MASTER_PORT=8364
+export WORLD_SIZE=1
+export RANK=$PADDLE_TRAINER_ID
+
+python test_intranode.py
diff --git a/tests/test_internode.py b/tests/test_internode.py
index 73c3bbd6..f0bfffe1 100644
--- a/tests/test_internode.py
+++ b/tests/test_internode.py
@@ -1,37 +1,66 @@
 import os
+import sys
 import time
 import torch
 import torch.distributed as dist
 # noinspection PyUnresolvedReferences
 import deep_ep
+import utils
 from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back

 # Test compatibility with low latency functions
 import test_low_latency
+try:
+    from paperf import profile_torch
+    has_paperf = True
+except ImportError:
+    has_paperf = False

-def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
+def 
test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, use_random_input, dump_input, dump_output): # Settings num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks assert num_experts % num_ranks == 0 and num_local_ranks == 8 if local_rank == 0: print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) - # Random data - x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank - x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - x_e4m3 = per_token_cast_to_fp8(x) - scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 - group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) - group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices - masked_scores = create_grouped_scores(scores, group_idx, num_nodes) - topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] - topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank - topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + if use_random_input: + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + x_e4m3 = per_token_cast_to_fp8(x) + + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank + topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + + if dump_input: + utils.dump(x, 'x', local_rank) + utils.dump(x_pure_rand, 'x_pure_rand', local_rank) + utils.dump(x_e4m3, 'x_e4m3', local_rank) + + utils.dump(topk_idx, 'topk_idx', local_rank) + utils.dump(topk_weights, 'topk_weights', local_rank) + utils.dump(topk_weights_pure_rand, 'topk_weights_pure_rand', local_rank) + else: + x = utils.load("x", local_rank) + x_pure_rand = utils.load("x_pure_rand", local_rank) + x_e4m3 = utils.load("x_e4m3", local_rank, "tuple") + + topk_idx = utils.load("topk_idx", local_rank) + topk_weights = utils.load("topk_weights", local_rank) + topk_weights_pure_rand = utils.load("topk_weights_pure_rand", local_rank) + rank_idx = topk_idx // (num_experts // num_ranks) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) + rdma_rank_idx = rank_idx // num_local_ranks rdma_rank_idx.masked_fill_(rank_idx == -1, -1) inplace_unique(rdma_rank_idx, num_nodes) @@ -42,44 +71,79 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in inplace_unique(rdma_idx, num_nodes) num_rdma_token_sent = rdma_idx.ne(-1).sum().item() - # Expert meta - num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda') - for i in range(num_experts): - num_tokens_per_expert[i] = (topk_idx == i).sum() - gbl_num_tokens_per_expert = 
num_tokens_per_expert.clone() - dist.all_reduce(gbl_num_tokens_per_expert, group=group) - - # Rank layout meta - num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda') - num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda') - token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda') - for i in range(num_ranks): - num_tokens_per_rank[i] = (rank_idx == i).sum() - token_sel = (rank_idx == i).max(dim=-1)[0] - count = token_sel.sum().item() - tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] - tokens[:count] = torch.sort(tokens[:count])[0] - token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda') - for i in range(num_nodes): - num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() - token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) - is_token_in_rank = token_idx_in_rank >= 0 - gbl_num_tokens_per_rank = num_tokens_per_rank.clone() - dist.all_reduce(gbl_num_tokens_per_rank, group=group) + current_node = rank // num_local_ranks + mask_rdma_only = (rdma_idx != current_node) & (rdma_idx != -1) + num_rdma_only_token_sent = mask_rdma_only.sum().item() + print(f"-- [local_rank={local_rank}, rank={rank}] num_rdma_token_sent: {num_rdma_token_sent}, num_rdma_token_sent_rdma_only: {num_rdma_only_token_sent}") + + if use_random_input: + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda') + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda') + num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda') + token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda') + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).max(dim=-1)[0] + count = token_sel.sum().item() + tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] + tokens[:count] = torch.sort(tokens[:count])[0] + token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda') + for i in range(num_nodes): + num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() + token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + if dump_input: + utils.dump(num_tokens_per_rank, 'num_tokens_per_rank', local_rank) + utils.dump(num_tokens_per_rdma_rank, 'num_tokens_per_rdma_rank', local_rank) + utils.dump(is_token_in_rank, 'is_token_in_rank', local_rank) + utils.dump(num_tokens_per_expert, 'num_tokens_per_expert', local_rank) + utils.dump(gbl_num_tokens_per_rank, 'gbl_num_tokens_per_rank', local_rank) + utils.dump(gbl_num_tokens_per_expert, 'gbl_num_tokens_per_expert', local_rank) + else: + num_tokens_per_rank = utils.load('num_tokens_per_rank', local_rank) + num_tokens_per_rdma_rank = utils.load('num_tokens_per_rdma_rank', local_rank) + is_token_in_rank = utils.load('is_token_in_rank', local_rank) + num_tokens_per_expert = utils.load('num_tokens_per_expert', local_rank) + gbl_num_tokens_per_rank = utils.load('gbl_num_tokens_per_rank', local_rank) + gbl_num_tokens_per_expert = 
utils.load('gbl_num_tokens_per_expert', local_rank) + + ############################################################################################################ + # get_dispatch_layout + ############################################################################################################ ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ buffer.get_dispatch_layout(topk_idx, num_experts) + + if dump_output: + utils.dump(ref_num_tokens_per_rank, 'ref_num_tokens_per_rank', local_rank) + utils.dump(ref_num_tokens_per_rdma_rank, 'ref_num_tokens_per_rdma_rank', local_rank) + utils.dump(ref_num_tokens_per_expert, 'ref_num_tokens_per_expert', local_rank) + utils.dump(ref_is_token_in_rank, 'ref_is_token_in_rank', local_rank) + assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) - t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + + t = bench(group, lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] if local_rank == 0: print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) print() group.barrier() time.sleep(1) + ############################################################################################################ + # Config rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) @@ -98,20 +162,46 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): for async_mode in (False, True): for current_x in (x_pure_rand, x, x_e4m3): for with_topk in (False, True): + dtype_str = "FP8" if isinstance(current_x, tuple) else "BF16" + dump_prefix = f'{dtype_str}_{"with" if with_topk else "without"}_top-k_async_{async_mode}_previous_{previous_mode}_' if local_rank == 0: - print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') - dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, - 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + print(f'[testing] Running with {dtype_str}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') + + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } if with_topk: - dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + dispatch_args.update({ + 'topk_idx': topk_idx, + 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + }) + if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () + + if dump_output: + utils.dump(recv_x, f'{dump_prefix}recv_x', 
local_rank) + utils.dump(recv_topk_idx, f'{dump_prefix}recv_topk_idx', local_rank) + utils.dump(recv_topk_weights, f'{dump_prefix}recv_topk_weights', local_rank) + utils.dump(recv_num_tokens_per_expert_list, f'{dump_prefix}recv_num_tokens_per_expert_list', local_rank) + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks recv_gbl_rank_prefix_sum = handle[-4] + + if dump_output: + utils.dump(recv_gbl_rank_prefix_sum, f"{dump_prefix}recv_gbl_rank_prefix_sum", local_rank) + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list if current_x is not x_pure_rand: @@ -133,20 +223,40 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () + + if dump_output: + utils.dump(recv_x, f'{dump_prefix}recv_x_wo_topk', local_rank) + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x if current_x is not x_pure_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) # Test combine + if not use_random_input: + recv_x = utils.load(f"{dump_prefix}recv_x_combine_input", local_rank) + combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} + if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) + + if dump_input: + utils.dump(recv_x, f'{dump_prefix}recv_x_combine_input', local_rank) + if with_topk: + utils.dump(recv_topk_weights, f'{dump_prefix}recv_topk_weights_input', local_rank) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () + + if dump_output: + utils.dump(combined_x, f'{dump_prefix}combined_x', local_rank) + utils.dump(combined_topk_weights, f'{dump_prefix}combined_topk_weights', local_rank) + check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) ref_x = x_pure_rand if current_x is x_pure_rand else x assert calc_diff(check_x, ref_x) < 5e-6 @@ -157,62 +267,161 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): # For later tuning dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 + dispatch_bf16_rdma_only_send_bytes = num_rdma_only_token_sent * hidden * 2 dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes + combine_bf16_rdma_only_recv_bytes = dispatch_bf16_rdma_only_send_bytes if local_rank == 0: print(' passed', flush=True) if local_rank == 0: print() + def print_tensor_info(t, name): + #print(f"-- {name}: data_ptr={t.untyped_storage().data_ptr()}, shape={t.size()}, dtype={t.dtype}") + print(f"-- {name}: shape={t.size()}, dtype={t.dtype}") + + profile = False + profile = profile and has_paperf + + if profile: + profile_torch.switch_profile(0, 0, 1) + # Tune dispatch performance best_dispatch_results = None fp8_factor = (1 + 4 / 128) / 2 for current_x in (x_e4m3, x): - best_time, best_results = 1e10, None + dtype_str = "FP8" if isinstance(current_x, tuple) else "BF16" + if profile: + profile_torch.push_record_event(f"Tune_Dispatch_{dtype_str}") + 
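# NOTES (illustrative sketch, not part of the patch): the updated `utils.bench` used in this sweep takes
# the process group and returns (avg_gpu_time, min_gpu_time, max_gpu_time, avg_cpu_time), so each candidate
# chunk-size config below is scored roughly as:
#
#     avg_t, _, _, cpu_t = bench(group, lambda: buffer.dispatch(**tune_args))
#     rdma_gbps      = rdma_send_bytes / 1e9 / avg_t       # RDMA plus forwarded NVL traffic
#     rdma_only_gbps = rdma_only_send_bytes / 1e9 / avg_t  # pure inter-node traffic
#     nvl_gbps       = nvl_recv_bytes / 1e9 / avg_t        # bytes received over NVLink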
+ best_time, best_cpu_time, best_results = 1e10, 1e10, None + rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes + rdma_only_send_bytes = (dispatch_bf16_rdma_only_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_only_send_bytes nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes + for nvl_chunk_size in range(4, 33, 4): for rdma_chunk_size in range(4, 33, 4): + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + profile_torch.push_record_event(f"Dispatch_{dtype_str}_Config({config_str})") + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': current_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.dispatch(**tune_args))[0] + result_times = bench(group, lambda: buffer.dispatch(**tune_args)) + t = result_times[0] + cpu_t = result_times[3] + + if profile: + profile_torch.pop_record_event() + if t < best_time: best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_cpu_time = cpu_t + if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ') + rdma_send_GBs = rdma_send_bytes / 1e9 / t + rdma_only_send_GBs = rdma_only_send_bytes / 1e9 / t + nvl_recv_GBs = nvl_recv_bytes / 1e9 / t + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_GBs:.2f} GB/s (RDMA + NVL), {rdma_only_send_GBs:.2f} GB/s (RDMA), {nvl_recv_GBs:.2f} GB/s (NVL) (time: {t:.5f} s, cpu_time: {cpu_t:.5f} s)') + + if profile: + profile_torch.pop_record_event() + if local_rank == 0: - print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)') + rdma_send_GBs = rdma_send_bytes / 1e9 / best_time + rdma_only_send_GBs = rdma_only_send_bytes / 1e9 / best_time + nvl_recv_GBs = nvl_recv_bytes / 1e9 / best_time + print(f'[tuning] Best dispatch ({dtype_str}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_GBs:.2f} GB/s (RDMA + NVL), {rdma_only_send_GBs:.2f} (RDMA), {nvl_recv_GBs:.2f} GB/s (NVL) (time: {best_time:.5f} s, cpu_time: {best_cpu_time:.5f} s)') print() if isinstance(current_x, tuple): + if profile: + profile_torch.push_record_event("Gather_Best_Config") + # Gather FP8 the best config from rank 0 best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda') all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) best_dispatch_results = all_best_fp8_results_list[0].tolist() + + if profile: + profile_torch.pop_record_event() + + group.barrier() + print(f"===============================================================") + + config_str = f"sms={best_dispatch_results[0]},nvl={best_dispatch_results[1]},{nvl_buffer_size},rdma={best_dispatch_results[2]},{rdma_buffer_size}" + if profile: + 
profile_torch.push_record_event(f"Best_Dispatch_BF16_Config({config_str})") + dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size) + #dispatch_config = deep_ep.Config(24, 20, 512, 32, 128) dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, 'config': dispatch_config if dispatch_config is not None else config} - recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + #if local_rank == 0: + # print_tensor_info(x, "x") + # print_tensor_info(num_tokens_per_rank, "num_tokens_per_rank") + # print_tensor_info(num_tokens_per_rdma_rank, "num_tokens_per_rdma_rank") + # print_tensor_info(is_token_in_rank, "is_token_in_rank") + # print_tensor_info(num_tokens_per_expert, "num_tokens_per_expert") + # print(f"-- dispatch_args: {dispatch_args}") + # print(f"-- dispatch_config: {best_dispatch_results[0]}, {best_dispatch_results[1]}, {nvl_buffer_size}, {best_dispatch_results[2]}, {rdma_buffer_size}") + + for i in range(1): + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + if profile: + profile_torch.pop_record_event() + + #if local_rank == 0: + # print_tensor_info(recv_x, "recv_x") + + if profile: + profile_torch.push_record_event(f"Tune_Combine_BF16") # Tune combine performance - best_time, best_results = 1e10, None + best_time, best_cpu_time, best_results = 1e10, 1e10, None for nvl_chunk_size in range(1, 5, 1): for rdma_chunk_size in range(8, 33, 4): + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + profile_torch.push_record_event(f"Combine_BF16_Config({config_str})") + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': recv_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.combine(**tune_args))[0] + result_times = bench(group, lambda: buffer.combine(**tune_args)) + t = result_times[0] + cpu_t = result_times[3] + + if profile: + profile_torch.pop_record_event() + if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ') + combine_bf16_rdma_recv_GBs = combine_bf16_rdma_recv_bytes / 1e9 / t + combine_bf16_rdma_only_recv_GBs = combine_bf16_rdma_only_recv_bytes / 1e9 / t + combine_bf16_nvl_send_GBs = combine_bf16_nvl_send_bytes / 1e9 / t + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_GBs:.2f} GB/s (RDMA + NVL), {combine_bf16_rdma_only_recv_GBs:.2f} GB/s (RDMA), {combine_bf16_nvl_send_GBs:.2f} GB/s (NVL) (time: {t:.5f} s, cpu_time: {cpu_t:.5f} s)') if t < best_time: best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_cpu_time = cpu_t + + if profile: + profile_torch.pop_record_event() + + if profile: + profile_torch.switch_profile(1, 0, 1) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)') + combine_bf16_rdma_recv_GBs = combine_bf16_rdma_recv_bytes / 1e9 / best_time + combine_bf16_rdma_only_recv_GBs = 
combine_bf16_rdma_only_recv_bytes / 1e9 / best_time + combine_bf16_nvl_send_GBs = combine_bf16_nvl_send_bytes / 1e9 / best_time + print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_GBs:.2f} GB/s (RDMA + NVL), {combine_bf16_rdma_only_recv_GBs:.2f} GB/s (RDMA), {combine_bf16_nvl_send_GBs:.2f} GB/s (NVL) (time: {best_time:.5f} s, cpu_time: {best_cpu_time:.5f} s)') print() @@ -229,8 +438,11 @@ def test_loop(local_rank: int, num_local_ranks: int): assert num_local_ranks == 8 and num_ranks > 8 torch.manual_seed(rank) + use_random_input = True + dump_input = False + dump_output = False for i in (24, ): - test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) + test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, use_random_input, dump_input, dump_output) if local_rank == 0: print() diff --git a/tests/test_internode_latency.py b/tests/test_internode_latency.py new file mode 100644 index 00000000..275ce8ba --- /dev/null +++ b/tests/test_internode_latency.py @@ -0,0 +1,159 @@ +import os +import sys +import time +import numpy as np + +import torch +import torch.distributed as dist + +# noinspection PyUnresolvedReferences +import alltoall +import utils +from utils import init_dist, create_grouped_scores + +try: + from paperf import profile_torch + has_paperf = True +except ImportError: + has_paperf = False + + +def print_tensor_info(t, name): + #print(f"-- {name}: data_ptr={t.untyped_storage().data_ptr()}, shape={t.size()}, dtype={t.dtype}") + print(f"-- {name}: shape={t.size()}, dtype={t.dtype}") + + +def test_main(local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, group: dist.ProcessGroup, use_random_input, dump_input): + # Settings + num_tokens = 4096 + hidden = 7168 + num_topk_groups = min(num_nodes, 4) + num_topk = 8 + num_experts = (256 // num_ranks) * num_ranks + + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) + + if use_random_input: + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + #x_e4m3 = per_token_cast_to_fp8(x) + + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank + topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + + if dump_input: + utils.dump(x, 'x', local_rank) + utils.dump(x_pure_rand, 'x_pure_rand', local_rank) + #utils.dump(x_e4m3, 'x_e4m3', local_rank) + + utils.dump(topk_idx, 'topk_idx', local_rank) + utils.dump(topk_weights, 'topk_weights', local_rank) + utils.dump(topk_weights_pure_rand, 'topk_weights_pure_rand', local_rank) + else: + x = utils.load("x", local_rank) + x_pure_rand = utils.load("x_pure_rand", local_rank) + #x_e4m3 = utils.load("x_e4m3", local_rank, "tuple") + + topk_idx = 
utils.load("topk_idx", local_rank) + topk_weights = utils.load("topk_weights", local_rank) + topk_weights_pure_rand = utils.load("topk_weights_pure_rand", local_rank) + + + profile = False + profile = profile and has_paperf + + # test bfloat16 + buffer = alltoall.get_buffer(group, alltoall.get_hidden_bytes(x)) + + if profile: + profile_torch.switch_profile(0, 0, 1) + + num_warmups = 100 + num_tests = 1000 + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + + for i in range(num_warmups + num_tests): + if i == num_warmups: + group.barrier() + torch.cuda.synchronize() + cpu_start = time.time() + + if i >= num_warmups: + # Record + batch_start = time.time() + start_events[i - num_warmups].record() + + #group.barrier() + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, dispatch_event = alltoall.dispatch_forward( + x=x, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False + ) + + combined_x, event = alltoall.combine_forward( + x=recv_x, + handle=handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False + ) + + end_events[i - num_warmups].record() + batch_time = time.time() - batch_start + if local_rank == 0: + print(f"-- {i - num_warmups}-th running, cpu_time: {batch_time:.5f} s") + torch.cuda.synchronize() + group.barrier() + + cpu_runtime = time.time() - cpu_start + avg_cpu_time = cpu_runtime / num_tests + + gpu_times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] + avg_gpu_time = np.average(gpu_times) + + print(f"-- rank: {rank}, avg_cpu_time: {avg_cpu_time:.5f} s, avg_gpu_time: {avg_gpu_time:.5f} s") + + torch.cuda.synchronize() + group.barrier() + + avg_cpu_time_all_ranks = [None, ] * num_ranks + avg_gpu_time_all_ranks = [None, ] * num_ranks + dist.all_gather_object(avg_cpu_time_all_ranks, avg_cpu_time, group=group) + dist.all_gather_object(avg_gpu_time_all_ranks, avg_gpu_time, group=group) + if rank == 0: + avg_cpu_time = np.average(np.array(avg_cpu_time_all_ranks)) + avg_gpu_time = np.average(np.array(avg_gpu_time_all_ranks)) + print(f"-- avg_cpu_time_of_all_ranks: {avg_cpu_time:.5f} s, avg_gpu_time_of_all_ranks: {avg_gpu_time:.5f} s") + + +def test_loop(local_rank: int, num_local_ranks: int): + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + + assert num_local_ranks == 8 and num_ranks > 8 + torch.manual_seed(rank) + + use_random_input = True + dump_input = False + + test_main(local_rank, num_local_ranks, num_ranks, num_nodes, rank, group, use_random_input, dump_input) + + +if __name__ == '__main__': + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) diff --git a/tests/test_internode_perf.py b/tests/test_internode_perf.py new file mode 100644 index 00000000..4ed0ff07 --- /dev/null +++ b/tests/test_internode_perf.py @@ -0,0 +1,242 @@ +import os +import sys +import time +import numpy as np + +import torch +import torch.distributed as dist + +import deep_ep +import utils +from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back + +import alltoall + +try: + from paperf import profile_torch + has_paperf = True +except ImportError: + has_paperf = False + +profile = True +profile = profile and 
has_paperf + + +def init_random_tensors(rank, num_nodes, num_tokens, hidden, num_topk_groups, num_topk, num_experts, dump_input=False): + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + x_e4m3 = per_token_cast_to_fp8(x) + + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank + topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + + if dump_input: + utils.dump(x, 'x', local_rank) + utils.dump(x_pure_rand, 'x_pure_rand', local_rank) + utils.dump(x_e4m3, 'x_e4m3', local_rank) + + utils.dump(topk_idx, 'topk_idx', local_rank) + utils.dump(topk_weights, 'topk_weights', local_rank) + utils.dump(topk_weights_pure_rand, 'topk_weights_pure_rand', local_rank) + + return x, x_pure_rand, x_e4m3, topk_idx, topk_weights, topk_weights_pure_rand + + +def load_dumped_tensors(rank, num_tokens, hidden, num_topk_groups, num_topk, num_experts): + def _load_tensor(rank, name, idx, typehint="tensor"): + dump_dir = "/root/paddlejob/workspace/env_run/liuyiqun/outputs/ds_8nodes" + filename = f"{dump_dir}/{idx}_{name}_rank{rank}.npy" + if typehint == "tensor": + x_np = np.load(filename) + if x_np.dtype == np.uint16: + x = torch.tensor(x_np, device='cuda').view(torch.bfloat16) + elif x_np.dtype in [np.float32, np.int32, np.int64, np.bool, np.int8]: + x = torch.tensor(x_np, device='cuda') + else: + assert False, f'{name}: {x_np.dtype}' + return x + else: + assert False, f'invalid typehint: {typehint}' + + input_tensors = [] + for i in range(20): + x = _load_tensor(rank, "dispatch_x", i + 1) + topk_idx = _load_tensor(rank, "topk_idx", i + 1) + topk_weights = _load_tensor(rank, "topk_weights", i + 1) + input_tensors.append({"x": x, "topk_idx": topk_idx, "topk_weights": topk_weights}) + + return input_tensors + + +def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, use_random_input, dump_input): + # Settings + num_tokens = 4096 + hidden = 7168 + num_topk_groups = min(num_nodes, 4) + num_topk = 8 + num_experts = (256 // num_ranks) * num_ranks + + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) + + if use_random_input: + if profile: + profile_torch.push_record_event(f"init_random_tensors") + + x, x_pure_rand, x_e4m3, topk_idx, topk_weights, topk_weights_pure_rand = init_random_tensors(rank, num_nodes, num_tokens, hidden, num_topk_groups, num_topk, num_experts, dump_input) + + if profile: + profile_torch.pop_record_event() + else: + if profile: + profile_torch.push_record_event(f"load_dumped_tensors") + + input_tensors = load_dumped_tensors(rank, num_tokens, hidden, num_topk_groups, num_topk, num_experts) + + x = input_tensors[0]["x"] + topk_idx = input_tensors[0]["topk_idx"] + topk_weights = input_tensors[0]["topk_weights"] + + if 
profile: + profile_torch.pop_record_event() + + if buffer is None: + #buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=False) + buffer = alltoall.get_buffer(group, hidden * 2) + + if profile: + profile_torch.push_record_event(f"get_dispatch_layout") + + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, _ = \ + buffer.get_dispatch_layout(topk_idx, num_experts) + + t = bench(group, lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) + print() + + if profile: + profile_torch.pop_record_event() + + if profile: + profile_torch.push_record_event(f"barrier") + + torch.distributed.barrier() + #group.barrier() + + if profile: + profile_torch.pop_record_event() + + time.sleep(1) + + # Config + # rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + rdma_buffer_size = 128 + nvl_buffer_size = 288 + + current_x = x + handle = None + + nvl_chunk_size = 20 + rdma_chunk_size = 28 + + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + profile_torch.push_record_event(f"Dispatch_Config({config_str})") + + #dispatch_config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) + dispatch_config = deep_ep.Buffer.get_dispatch_config(group.size()) + + dispatch_num_nvl_bytes = dispatch_config.get_nvl_buffer_size_hint(hidden * 2, group.size()) + dispatch_num_rdma_bytes = dispatch_config.get_rdma_buffer_size_hint(hidden * 2, group.size()) + + if handle is not None: + dispatch_args = {'x': current_x, 'handle': handle, 'config': config} + else: + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'topk_idx': topk_idx, + 'topk_weights': topk_weights, + 'config': dispatch_config + } + result_times = bench(group, lambda: buffer.dispatch(**dispatch_args)) + gpu_time = result_times[0] + cpu_time = result_times[3] + + if profile: + profile_torch.pop_record_event() + + if local_rank == 0: + print(f'[rank={rank}] Dispatch: SMs {num_sms}, nvl_chunk_size {nvl_chunk_size}, nvl_buffer_size {nvl_buffer_size}, rdma_chunk_size {rdma_chunk_size}, rdma_buffer_size {rdma_buffer_size}, num_nvl_bytes {dispatch_num_nvl_bytes}, num_rdma_bytes {dispatch_num_rdma_bytes}; gpu_time: {gpu_time:.5f} s, cpu_time: {cpu_time:.5f} s') + + #group.barrier() + + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + nvl_chunk_size = 1 + rdma_chunk_size = 20 + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + profile_torch.push_record_event(f"Combine_Config({config_str})") + + #combine_config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) + combine_config = deep_ep.Buffer.get_combine_config(group.size()) + + combine_num_nvl_bytes = combine_config.get_nvl_buffer_size_hint(hidden * 2, group.size()) + combine_num_rdma_bytes = combine_config.get_rdma_buffer_size_hint(hidden * 2, group.size()) + + combine_args = { + 'x': recv_x, + 'handle': handle, + 'config': combine_config + } + result_times = bench(group, lambda: buffer.combine(**combine_args)) + gpu_time = result_times[0] + cpu_time = result_times[3] + + if profile: + 
profile_torch.pop_record_event() + + if local_rank == 0: + print(f'[rank={rank}] Combine: SMs {num_sms}, nvl_chunk_size {nvl_chunk_size}, nvl_buffer_size {nvl_buffer_size}, rdma_chunk_size {rdma_chunk_size}, rdma_buffer_size {rdma_buffer_size}, num_nvl_bytes {combine_num_nvl_bytes}, num_rdma_bytes {combine_num_rdma_bytes}; gpu_time: {gpu_time:.5f} s, cpu_time: {cpu_time:.5f} s') + + +def test_loop(local_rank: int, num_local_ranks: int): + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + + if profile: + profile_torch.switch_profile(0, 0, 1) + + #buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=False) + + assert num_local_ranks == 8 and num_ranks > 8 + torch.manual_seed(rank) + + use_random_input = False + dump_input = False + for i in (20, ): + test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, None, group, use_random_input, dump_input) + + if profile: + profile_torch.switch_profile(1, 0, 1) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == '__main__': + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) diff --git a/tests/utils.py b/tests/utils.py index a5743663..afd81b84 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import os import sys +import time import numpy as np import torch import torch.distributed as dist @@ -12,6 +13,7 @@ def init_dist(local_rank: int, num_local_ranks: int): port = int(os.getenv('MASTER_PORT', '8361')) num_nodes = int(os.getenv('WORLD_SIZE', 1)) node_rank = int(os.getenv('RANK', 0)) + print(f"num_nodes: {num_nodes}, node_rank: {node_rank}") assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 dist.init_process_group( @@ -71,7 +73,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro return (scores * mask).view(num_tokens, num_experts) -def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): +def bench(group, fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') @@ -86,6 +88,11 @@ def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): # Testing start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + + group.barrier() + torch.cuda.synchronize() + + cpu_start = time.time() for i in range(num_tests): # Record start_events[i].record() @@ -94,9 +101,10 @@ def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): if post_fn is not None: post_fn() torch.cuda.synchronize() + cpu_runtime = time.time() - cpu_start times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] - return np.average(times), np.min(times), np.max(times) + return np.average(times), np.min(times), np.max(times), cpu_runtime / num_tests class empty_suppress: @@ -190,3 +198,64 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: def hash_tensor(t: torch.Tensor): return t.view(torch.int64).sum().item() + + +def dump(x, name, local_rank): + #print(f'wsm debug {name} {x}') + #dtype2str = {torch.float32: "_orgi_fp32", torch.int: "_orgi_int", torch.int64: "_orgi_int64", torch.bool: "_orgi_bool", torch.bfloat16: "_orgi_bf16", torch.float8_e4m3fn: "_orgi_fp8"} + name = 
"/root/paddlejob/workspace/env_run/liuyiqun/outputs/torch_dump/" + name + print(name) + if isinstance(x, torch.Tensor): + if x.dtype==torch.float32 or x.dtype==torch.int or x.dtype==torch.int64 or x.dtype==torch.bool: + y = x.cpu().numpy() + elif x.dtype==torch.bfloat16: + y = x.view(torch.uint16).cpu().numpy() + elif x.dtype==torch.float8_e4m3fn: + y = x.view(torch.uint8).cpu().numpy() + else: + assert False, f'{name}: {x.dtype} {x}' + #name += dtype2str[x.dtype] + np.save(f"{name}_rank{local_rank}.npy",y) + elif isinstance(x, tuple): + y, y_scale = x + assert y.dtype==torch.float8_e4m3fn + assert y_scale.dtype==torch.float32 + y_dump = y.view(torch.uint8).cpu().numpy() + y_scale_dump = y_scale.cpu().numpy() + # np.save(f"{name}{dtype2str[y.dtype]}_value_rank{local_rank}.npy", y_dump) + # np.save(f"{name}{dtype2str[y_scale.dtype]}_scale_rank{local_rank}.npy", y_scale_dump) + + np.save(f"{name}_value_rank{local_rank}.npy", y_dump) + np.save(f"{name}_scale_rank{local_rank}.npy", y_scale_dump) + elif isinstance(x, list): + y = np.asarray(x) + np.save(f"{name}_rank{local_rank}.npy", y) + elif x is None: + np.save(f"{name}_rank{local_rank}.npy", np.zeros(5)) + else: + assert False, f'{name}: {x}' + + +def load(name, local_rank, typehint="tensor"): + dump_dir = '/root/paddlejob/workspace/env_run/liuyiqun/outputs/torch_dump' + name = dump_dir + "/" + name + print(f"[local_rank={local_rank}] load {name}") + # orig_dtype = retrive_dtype(name) + # name += dtype2str[x.dtype] + if typehint == "tensor": + x_np = np.load(f'{name}_rank{local_rank}.npy') + if x_np.dtype == np.uint16: + x = torch.tensor(x_np, device='cuda').view(torch.bfloat16) + elif x_np.dtype == np.uint8: + x = torch.tensor(x_np, device='cuda').view(torch.float8_e4m3fn) + else: + x = torch.tensor(x_np, device='cuda') + return x + elif typehint == "tuple": + y_np = np.load(f'{name}_value_rank{local_rank}.npy') + y_scale_np = np.load(f'{name}_scale_rank{local_rank}.npy') + y = torch.tensor(y_np, device='cuda').view(torch.float8_e4m3fn) + y_scale = torch.tensor(y_scale_np, device='cuda') + return (y, y_scale) + else: + assert False, f'invalid typehint: {typehint}' diff --git a/tests_paddle/compare.py b/tests_paddle/compare.py new file mode 100644 index 00000000..9afab597 --- /dev/null +++ b/tests_paddle/compare.py @@ -0,0 +1,61 @@ +import os +import time +import numpy as np +from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back + + +def load(name, local_rank, typehint="tensor"): + if typehint=="tensor": + x_np = np.load(f'{name}_rank{local_rank}.npy') + return x_np + elif typehint=="tuple": + y_np = np.load(f'{name}_value_rank{local_rank}.npy') + y_scale_np = np.load(f'{name}_scale_rank{local_rank}.npy') + return (y_np, y_scale_np) + else: + assert False, f'invalid typehint: {typehint}' + + +def load_cmp(name, local_rank, typehint="tensor"): + if typehint == "tensor": + ref = load("torch_dump/"+name, local_rank, typehint) + x = load("paddle_dump/"+name, local_rank, typehint) + np.testing.assert_array_equal(x, ref, err_msg=f"{name} missmatch", strict=True) + elif typehint == "tuple": + ref1, ref2 = load("torch_dump/"+name, local_rank, typehint) + x1, x2 = load("paddle_dump/"+name, local_rank, typehint) + np.testing.assert_array_equal(x1, ref1, err_msg=f"{name} missmatch", strict=True) + np.testing.assert_array_equal(x2, ref2, err_msg=f"{name} missmatch", strict=True) + else: + assert False, f'invalid typehint: {typehint}' + + +def test_main(local_rank: int): + 
load_cmp("ref_num_tokens_per_rank", local_rank) + load_cmp("ref_num_tokens_per_expert", local_rank) + load_cmp("ref_is_token_in_rank", local_rank) + + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x_type in ("hack", ("hack", "hack")): + for with_topk in (False, True): + if local_rank == 0: + print(f'[testing] Running with {"FP8" if isinstance(current_x_type, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') + dump_prefix = f'{"FP8" if isinstance(current_x_type, tuple) else "BF16"}_{"with" if with_topk else "without"}_top-k_async_{async_mode}_previous_{previous_mode}_' + + load_cmp(f"{dump_prefix}recv_x", local_rank, "tuple" if isinstance(current_x_type, tuple) else "tensor") + load_cmp(f"{dump_prefix}recv_topk_idx", local_rank) + load_cmp(f"{dump_prefix}recv_topk_weights", local_rank) + load_cmp(f"{dump_prefix}recv_num_tokens_per_expert_list", local_rank) + + load_cmp(f"{dump_prefix}rank_prefix_matrix", local_rank) + if not with_topk: + load_cmp(f"{dump_prefix}recv_x_wo_topk", local_rank, "tuple" if isinstance(current_x_type, tuple) else "tensor") + + load_cmp(f"{dump_prefix}combined_x", local_rank) + load_cmp(f"{dump_prefix}combined_topk_weights", local_rank) + + +if __name__ == '__main__': + for i in range(8): + test_main(i) diff --git a/tests_paddle/fused_a2a.py b/tests_paddle/fused_a2a.py new file mode 100644 index 00000000..ce1ad1be --- /dev/null +++ b/tests_paddle/fused_a2a.py @@ -0,0 +1,384 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import paddle.distributed.communication.deep_ep as deep_ep + + HAVE_DEEP_EP = True +except ImportError: + HAVE_DEEP_EP = False + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group + +from paperf import profile_paddle + +import sys +import numpy as np + +_buffer = None + + +def print_tensor_info(t, name): + print(f"-- {name}: data_ptr={t.data_ptr():#x}, shape={t.shape}, dtype={t.dtype}, md5sum: {t._md5sum()}") + + +def get_hidden_bytes(x: paddle.Tensor) -> int: + """Calculate the number of hidden bytes for a tensor. + + Args: + x (paddle.Tensor): Input tensor + + Returns: + int: Number of hidden bytes + """ + return x.shape[1] * max(x.element_size(), 2) + + +def get_buffer(group: Group, hidden_bytes: int): + """Get or create a buffer for all-to-all communication. 
+ + Args: + group (paddle.distributed.ProcessGroup): Process group for communication + hidden_bytes (int): Number of hidden bytes needed + + Returns: + Buffer: Communication buffer + """ + global _buffer + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + deep_ep.Buffer.get_dispatch_config(group.world_size), + deep_ep.Buffer.get_combine_config(group.world_size), + ): + # Split long line for PEP8 compliance + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.world_size), num_nvl_bytes) + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.world_size), num_rdma_bytes) + + # Allocate buffer if not existed or not enough buffer + # NOTES: the adaptive routing configuration of the network **must be off** + if ( + _buffer is None + or _buffer.group != group + or _buffer.num_nvl_bytes < num_nvl_bytes + or _buffer.num_rdma_bytes < num_rdma_bytes + ): + print(f"-- group.world_size: {group.world_size}, num_nvl_bytes: {num_nvl_bytes}, num_rdma_bytes: {num_rdma_bytes}") + _buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes) + return _buffer + + +def fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Forward pass of fused dispatch.""" + + # Calculate layout before actual dispatch + if isinstance(x, tuple): + timer_name_suffix = "_tuple" + buffer = get_buffer(group, get_hidden_bytes(x[0])) + else: + timer_name_suffix = "" + buffer = get_buffer(group, get_hidden_bytes(x)) + + profile_paddle.push_record_event(f"dispatch_forward-get_dispatch_layout{timer_name_suffix}") + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event_, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + + assert token_probs.dtype == paddle.float32 + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + profile_paddle.push_record_event(f"dispatch_forward-dispatch{timer_name_suffix}") + (recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event, time_stamp) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + + return recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event + + +def fused_dispatch_backward_func( + grad_output, + grad_token_probs, + group, + handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Backward pass of fused dispatch.""" + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + profile_paddle.push_record_event("dispatch_backward-combine") + grad_x, grad_token_probs, event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + return grad_x, None, 
grad_token_probs + + +def fused_combine_forward_func( + x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Forward pass of fused combine.""" + buffer = get_buffer(group, get_hidden_bytes(x)) + + profile_paddle.push_record_event("combine_forward-combine") + combined_x, _, event = buffer.combine( + x, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + return combined_x + + +def fused_combine_backward_func( + grad_output, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Backward pass of fused combine.""" + if isinstance(grad_output, tuple): + buffer = get_buffer(group, get_hidden_bytes(grad_output[0])) + + profile_paddle.push_record_event("combine_backward-dispatch_tuple") + grad_x, _, _, _, _, event, _ = buffer.dispatch( + (grad_output[0].contiguous(), grad_output[1].contiguous()), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + else: + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + profile_paddle.push_record_event("combine_backward-dispatch") + grad_x, _, _, _, _, event, _ = buffer.dispatch( + grad_output.contiguous(), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + profile_paddle.pop_record_event() + return grad_x + + +class FusedDispatch(PyLayer): + """Fused dispatch operation for MoE routing combining computation and communication.""" + + @staticmethod + def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): + """Forward pass of fused dispatch.""" + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, token_indices, token_probs, num_experts, group, previous_event + ) + + ctx.group = group + ctx.handle = states["handle"] + ctx.event = event + + return recv_x, recv_token_probs, states + + @staticmethod + def backward(ctx, grad_output, grad_token_probs): + """Backward pass of fused dispatch.""" + return fused_dispatch_backward_func(grad_output, grad_token_probs, ctx.group, ctx.handle) + + +class FusedCombine(PyLayer): + """Fused combine operation for MoE output combining computation and communication.""" + + @staticmethod + def forward(ctx, x, group, states, previous_event=None): + """Forward pass of fused combine.""" + combined_x = fused_combine_forward_func(x, group, states, previous_event) + + ctx.handle = states["handle"] + ctx.group = group + ctx.previous_event = previous_event + + return combined_x + + @staticmethod + def backward(ctx, grad_output): + """Backward pass of fused combine.""" + return fused_combine_backward_func(grad_output, ctx.group, ctx.handle, ctx.previous_event) + + +if HAVE_DEEP_EP: + + def fused_dispatch(x, token_indices, token_probs, num_experts, group: Group, previous_event=None): + """Perform fused dispatch operation if deep_ep is available. 
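+        This wraps `FusedDispatch.apply` so the dispatch takes part in autograd; the backward pass runs the matching combine over the handle saved during the forward pass.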
+ + Args: + x: Input tensor [num_tokens, hidden_size] + token_indices: Token routing indices [num_tokens, topk] + token_probs: Token routing probabilities [num_tokens, topk] + num_experts: Number of experts + group: Process group + previous_event: Previous CUDA event + + Returns: + Result of FusedDispatch + """ + return FusedDispatch.apply(x.contiguous(), token_indices, token_probs, num_experts, group, previous_event) + + def fused_combine(x, group, handle, previous_event=None): + """Perform fused combine operation if deep_ep is available. + + Args: + x: Input tensor + group: Process group + handle: Communication handle + previous_event: Previous CUDA event + + Returns: + Result of FusedCombine + """ + states = dict() + states["handle"] = handle + return FusedCombine.apply(x, group, states, previous_event) + +else: + fused_dispatch = None + fused_combine = None + + +class DispatchNode: + def __init__(self, name="dispatch"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward( + self, + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + profile_paddle.push_record_event(f"DispatchNode_forward") + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.group = group + self.handle = states["handle"] + self.event = event + profile_paddle.pop_record_event() + + return recv_x, recv_token_probs, states + + def backward(self, grad_output, grad_token_probs, previous_event=None, async_finish=False): + """Backward pass of fused dispatch.""" + profile_paddle.push_record_event(f"DispatchNode_backward") + out = fused_dispatch_backward_func( + grad_output, + grad_token_probs, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + ) + self.reset_statue() + profile_paddle.pop_record_event() + return out + + +class CombineNode: + def __init__(self, name="combine"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward(self, x, group, handle, previous_event=None, async_finish=False): + """Forward pass of fused combine.""" + profile_paddle.push_record_event(f"CombineNode_forward") + states = dict() + states["handle"] = handle + combined_x = fused_combine_forward_func( + x, group, states, previous_event=previous_event, async_finish=async_finish + ) + + self.handle = handle + self.group = group + self.previous_event = previous_event + profile_paddle.pop_record_event() + + return combined_x + + def backward(self, grad_output, previous_event=None, async_finish=False): + """Backward pass of fused combine.""" + profile_paddle.push_record_event(f"CombineNode_backward") + out = fused_combine_backward_func( + grad_output, self.group, self.handle, previous_event=previous_event, async_finish=async_finish + ) + self.reset_statue() + profile_paddle.pop_record_event() + return out diff --git a/tests_paddle/run_test_internode.sh b/tests_paddle/run_test_internode.sh new file mode 100644 index 00000000..d1506f82 --- /dev/null +++ b/tests_paddle/run_test_internode.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +#bash kill_process.sh + +WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun +export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/paddle_py310_yiqun +export PATH=${PYTHONPATH}/bin:${PATH} + +#export 
NVSHMEM_DIR=${WORK_ROOT}/Paddle/build_paddle/third_party_cuda12.3_gcc12.2.0_py3.10/install/nvshmem +#export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH" +#export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/liuyiqun/env/virtualenvs_cuda12.3/paddle_py310_yiqun/lib/python3.10/site-packages/paddle/libs:$LD_LIBRARY_PATH + +export PYTHONPATH=${WORK_ROOT}/PaPerf:$PYTHONPATH +#export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/liuyiqun/DeepEP/tests_paddle:$LD_LIBRARY_PATH +#export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/liuyiqun/Paddle/paddle/fluid/distributed/collective/deep_ep/kernels/build:$LD_LIBRARY_PATH + + +# Unset the environment variables preset by the platform; the framework uses a compatibility-style upgrade and falls back to the legacy launch mode when it detects these settings +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +START_RANK=54 +END_RANK=62 + +if [[ ${PADDLE_TRAINER_ID} -lt $START_RANK ]]; then + exit 0 +fi + +if [[ ${PADDLE_TRAINER_ID} -ge $END_RANK ]]; then + exit 0 +fi + +export WORLD_SIZE=$(($END_RANK - $START_RANK)) + +RANK=$(($PADDLE_TRAINER_ID - ${START_RANK})) +NNODES=${WORLD_SIZE} +echo "rank: ${RANK}, nnodes: ${NNODES}" + +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +python -c "import paddle; print(paddle.version.commit)" + +#MASTER_ADDR="10.95.238.87" # 46 +MASTER_ADDR="10.95.238.158" # 54 +MASTER_PORT=58978 + +export FLAGS_eager_communication_connection=1 + +export NCCL_DEBUG=WARN +#export NVSHMEM_DEBUG=DEBUG +#export NVSHMEM_DEBUG=TRACE + +# Settings that keep the cluster stable; they are unrelated to performance +export NCCL_IB_QPS_PER_CONNECTION=8 +#export NCCL_IB_TIMEOUT=22 +export NCCL_IB_GID_INDEX=3 +export NCCL_NVLS_ENABLE=0 +# Enable adaptive routing (AR) +#export NCCL_IB_ADAPTIVE_ROUTING=1 + +export NCCL_IB_GID_INDEX=3 +export NVSHMEM_IB_GID_INDEX=3 +export NVSHMEM_IB_TRAFFIC_CLASS=162 + +#export NVSHMEM_IB_ENABLE_IBGDA=true +#export NVSHMEM_DISABLE_P2P=0 +export NVSHMEM_BOOTSTRAP=UID +export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=xgbe0 +#export NVSHMEM_BOOTSTRAP_UID_SOCK_FAMILY=AF_INET +#export NVSHMEM_IB_ENABLE_RELAXED_ORDERING=0 + +#export NVSHMEM_CUMEM_GRANULARITY=2M + +#export NVSHMEM_NVTX=common +#export NVSHMEM_DEBUG=INFO +#export NVSHMEM_INFO=1 +#export GLOG_vmodule=deep_ep=4 + +#export FLAGS_use_nvml=False + +export PATH=/opt/nvidia/nsight-systems/2025.1.1/bin:/usr/local/NVIDIA-Nsight-Compute-2025.1/bin:$PATH +#nsys_args="nsys profile --stats true -w true -t cuda,nvshmem,nvtx -r um_cpu_page_faults_sum --nic-metrics=true --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_internode_${WORLD_SIZE}nodes_rank${RANK}.paddle" +#nsys_args="nsys profile --stats true -w true -t cuda,nvtx --gpu-metrics-devices=all --nic-metrics=true --gpuctxsw=true --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_internode_latency.bf16_${SUFFIX}_rank${rank}" + +rm -rf core.* +rm -rf log + +#${nsys_args} python -m paddle.distributed.launch --master=${MASTER_ADDR}:${MASTER_PORT} --nnodes=${NNODES} --rank ${RANK} test_internode.py +${nsys_args} python -m paddle.distributed.launch --master=${MASTER_ADDR}:${MASTER_PORT} --nnodes=${NNODES} --rank ${RANK} test_internode_latency.py diff --git a/tests_paddle/run_test_intranode.sh b/tests_paddle/run_test_intranode.sh new file mode 100644 index 00000000..4bcbc44f --- /dev/null +++ b/tests_paddle/run_test_intranode.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun +export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.3/paddle_py310_yiqun +export PATH=${PYTHONPATH}/bin:${PATH} + +# 
屏蔽平台预设的环境变量,因为框架采用兼容升级,检测到这些配置会使用原方式启动 +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +#nnodes=$PADDLE_TRAINERS_NUM +#rank=$PADDLE_TRAINER_ID +nnodes=1 +rank=0 + +python -m paddle.distributed.launch --nnodes=$nnodes --rank $rank test_intranode.py diff --git a/tests_paddle/test_internode.py b/tests_paddle/test_internode.py new file mode 100644 index 00000000..026dcc08 --- /dev/null +++ b/tests_paddle/test_internode.py @@ -0,0 +1,479 @@ +import os +import sys +import time + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +import paddle.distributed.communication.deep_ep as deep_ep +from paddle.distributed.communication.group import Group +from paddle.base.core import Config +import numpy as np + +# noinspection PyUnresolvedReferences +import utils +from utils import bench, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back + +# Test compatibility with low latency functions +#import test_low_latency +try: + from paperf import profile_paddle + has_paperf = True +except ImportError: + has_paperf = False + + +def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: Group, use_random_input: bool, dump_input: bool, dump_output: bool, tune_performance: bool): + # Settings + num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) + + if use_random_input: + # Random data + x = paddle.ones(shape=[num_tokens, hidden], dtype=paddle.bfloat16) * rank + x_pure_rand = paddle.randn(shape=[num_tokens, hidden], dtype=paddle.bfloat16) + x_e4m3 = per_token_cast_to_fp8(x) + scores = paddle.randn(shape=[num_tokens, num_experts], dtype=paddle.float32).abs() + 1 + + group_scores = scores.view([num_tokens, num_nodes, -1]).amax(axis=-1) + group_idx = paddle.topk(group_scores, num_topk_groups, axis=-1, sorted=False)[1] + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + topk_idx = paddle.topk(masked_scores, num_topk, axis=-1, largest=True, sorted=False)[1] + + topk_weights = paddle.ones(shape=[num_tokens, num_topk], dtype=paddle.float32) * rank + topk_weights_pure_rand = paddle.randn(shape=[num_tokens, num_topk], dtype=paddle.float32) + + if dump_input: + utils.dump(topk_idx, "topk_idx", local_rank) + else: + x = utils.load("x", local_rank) + x_pure_rand = utils.load("x_pure_rand", local_rank) + x_e4m3 = utils.load("x_e4m3", local_rank, "tuple") + + topk_idx = utils.load("topk_idx", local_rank) + topk_weights = utils.load("topk_weights", local_rank) + topk_weights_pure_rand = utils.load("topk_weights_pure_rand", local_rank) + + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + + rdma_rank_idx = rank_idx // num_local_ranks + rdma_rank_idx.masked_fill_(rank_idx == -1, -1) + inplace_unique(rdma_rank_idx, num_nodes) + + # RDMA dispatch counts + rdma_idx = topk_idx // (num_experts // num_nodes) + rdma_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rdma_idx, num_nodes) + num_rdma_token_sent = paddle.not_equal(rdma_idx, paddle.full_like(rdma_idx, -1)).sum().item() + + 
current_node = rank // num_local_ranks + mask_rdma_only = (rdma_idx != current_node) & (rdma_idx != -1) + num_rdma_only_token_sent = mask_rdma_only.sum().item() + print(f"-- [local_rank={local_rank}, rank={rank}] num_rdma_token_sent: {num_rdma_token_sent}, num_rdma_token_sent_rdma_only: {num_rdma_only_token_sent}") + + if use_random_input: + # Expert meta + num_tokens_per_expert = paddle.zeros(shape=[num_experts, ], dtype=paddle.int32) + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = paddle.empty([num_ranks, ], dtype=paddle.int32) + num_tokens_per_rdma_rank = paddle.empty([num_nodes, ], dtype=paddle.int32) + token_idx_in_rank = paddle.full([num_ranks, num_tokens], -1, dtype=paddle.int64) + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).cast(paddle.int32).max(axis=-1) + count = token_sel.sum().item() + tokens = paddle.argsort(token_sel.cast(paddle.int32), descending=True) + tokens[:count] = paddle.sort(tokens[:count]) + token_idx_in_rank[i][tokens[:count]] = paddle.arange(count, dtype=paddle.int64) + for i in range(num_nodes): + num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() + token_idx_in_rank = token_idx_in_rank.t().contiguous().cast(paddle.int32) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + else: + num_tokens_per_rank = utils.load('num_tokens_per_rank', local_rank) + num_tokens_per_rdma_rank = utils.load('num_tokens_per_rdma_rank', local_rank) + is_token_in_rank = utils.load('is_token_in_rank', local_rank) + num_tokens_per_expert = utils.load('num_tokens_per_expert', local_rank) + gbl_num_tokens_per_rank = utils.load('gbl_num_tokens_per_rank', local_rank) + gbl_num_tokens_per_expert = utils.load('gbl_num_tokens_per_expert', local_rank) + + ############################################################################################################ + # get_dispatch_layout + ############################################################################################################ + + ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ + buffer.get_dispatch_layout(topk_idx, num_experts) + + if dump_output: + utils.dump(ref_num_tokens_per_rank, 'ref_num_tokens_per_rank', local_rank) + utils.dump(ref_num_tokens_per_rdma_rank, 'ref_num_tokens_per_rdma_rank', local_rank) + utils.dump(ref_num_tokens_per_expert, 'ref_num_tokens_per_expert', local_rank) + utils.dump(ref_is_token_in_rank, 'ref_is_token_in_rank', local_rank) + + assert paddle.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert paddle.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) + assert paddle.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert paddle.allclose(ref_is_token_in_rank, is_token_in_rank) + + t = bench(group, lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) + print() + paddle.distributed.barrier(group) + time.sleep(1) + + ############################################################################################################ + + # Config + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + config = 
Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) + + # Test dispatch + # noinspection PyShadowingNames + def check_data(check_x, recv_gbl_rank_prefix_sum): + assert paddle.allclose(check_x.amin(axis=1), check_x.amax(axis=1)) + check_start = 0 + for i in range(num_ranks): + check_end = recv_gbl_rank_prefix_sum[i].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + dtype_str = "FP8" if isinstance(current_x, tuple) else "BF16" + dump_prefix = f'{dtype_str}_{"with" if with_topk else "without"}_top-k_async_{async_mode}_previous_{previous_mode}_' + if local_rank == 0: + print(f'[testing] Running with {dtype_str}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='\n') + + + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } + if with_topk: + dispatch_args.update({ + 'topk_idx': topk_idx, + 'topk_weights': topk_weights_pure_rand if not isinstance(current_x, tuple) else topk_weights + }) + + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + + if dump_output: + utils.dump(recv_x, f'{dump_prefix}recv_x', local_rank) + utils.dump(recv_topk_idx, f'{dump_prefix}recv_topk_idx', local_rank) + utils.dump(recv_topk_weights, f'{dump_prefix}recv_topk_weights', local_rank) + utils.dump(recv_num_tokens_per_expert_list, f'{dump_prefix}recv_num_tokens_per_expert_list', local_rank) + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + # Checks + recv_gbl_rank_prefix_sum = handle[-4] + + if dump_output: + utils.dump(recv_gbl_rank_prefix_sum, f"{dump_prefix}recv_gbl_rank_prefix_sum", local_rank) + + assert gbl_num_tokens_per_rank[rank].item() == recv_x.shape[0], f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.shape[0]}' + assert gbl_num_tokens_per_expert.view([num_ranks, -1])[rank].tolist() == recv_num_tokens_per_expert_list + if current_x is not x_pure_rand: + pass + # check_data(recv_x, recv_gbl_rank_prefix_sum) + if with_topk: + # Check `topk_idx` + assert (recv_topk_idx.equal(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.equal(i).sum().item() == count + + if use_random_input: + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.equal(-1)] = recv_topk_weights.amax(axis=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.equal(-1)] + # check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + + # Test cached dispatch (must without top-k staffs) + # NOTES: handle must be refreshed + if not with_topk: + dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() 
if async_mode else () + + if dump_output: + utils.dump(recv_x, f'{dump_prefix}recv_x_wo_topk', local_rank) + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + if use_random_input: + if current_x is not x_pure_rand: + pass + # check_data(recv_x, recv_gbl_rank_prefix_sum) + + # Test combine + if not use_random_input: + recv_x = utils.load(f"{dump_prefix}recv_x_combine_input", local_rank) + + combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} + + if with_topk: + if not use_random_input: + recv_topk_weights = utils.load(f"{dump_prefix}recv_topk_weights_input", local_rank) + combine_args.update({'topk_weights': recv_topk_weights}) + + if dump_input: + utils.dump(recv_x, f'{dump_prefix}recv_x_combine_input', local_rank) + if with_topk: + utils.dump(recv_topk_weights, f'{dump_prefix}recv_topk_weights_input', local_rank) + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + + combined_x, combined_topk_weights, event = buffer.combine(**combine_args) + event.current_stream_wait() if async_mode else () + + if dump_output: + utils.dump(combined_x, f"{dump_prefix}combined_x", local_rank) + utils.dump(combined_topk_weights, f"{dump_prefix}combined_topk_weights", local_rank) + + # check_x = combined_x.cast(paddle.float32) / is_token_in_rank.sum(axis=1).unsqueeze(1) + if use_random_input: + ref_x = x_pure_rand if current_x is x_pure_rand else x + # assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + pass + # check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) + # ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + # assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 + dispatch_bf16_rdma_only_send_bytes = num_rdma_only_token_sent * hidden * 2 + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes + combine_bf16_rdma_only_recv_bytes = dispatch_bf16_rdma_only_send_bytes + + if local_rank == 0: + print(' passed', flush=True) + + if local_rank == 0: + print() + + def print_tensor_info(t, name): + print(f"-- {name}: data_ptr={t.data_ptr()}, shape={t.shape}, dtype={t.dtype}") + + profile = False + profile = profile and has_paperf + + if profile: + profile_paddle.switch_profile(0, 0, 1) + + # Tune dispatch performance + best_dispatch_results = None + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + dtype_str = "FP8" if isinstance(current_x, tuple) else "BF16" + if profile: + profile_paddle.push_record_event(f"Tune_Dispatch_{dtype_str}") + + best_time, best_cpu_time, best_results = 1e10, 1e10, None + + rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes + rdma_only_send_bytes = (dispatch_bf16_rdma_only_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_only_send_bytes + nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes + + for nvl_chunk_size in range(4, 33, 4): + for rdma_chunk_size in range(4, 33, 4): + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + 
profile_paddle.push_record_event(f"Dispatch_{dtype_str}_Config({config_str})") + + config = Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) + tune_args = {'x': current_x, 'handle': handle, 'config': config} + result_times = bench(group, lambda: buffer.dispatch(**tune_args)) + t = result_times[0] + cpu_t = result_times[3] + + if profile: + profile_paddle.pop_record_event() + + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_cpu_time = cpu_t + + if local_rank == 0: + rdma_send_GBs = rdma_send_bytes / 1e9 / t + rdma_only_send_GBs = rdma_only_send_bytes / 1e9 / t + nvl_recv_GBs = nvl_recv_bytes / 1e9 / t + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_GBs:.2f} GB/s (RDMA + NVL), {rdma_only_send_GBs:.2f} GB/s (RDMA), {nvl_recv_GBs:.2f} GB/s (NVL) (time: {t:.5f} s, cpu_time: {cpu_t:.5f} s)') + + if profile: + profile_paddle.pop_record_event() + + if local_rank == 0: + rdma_send_GBs = rdma_send_bytes / 1e9 / best_time + rdma_only_send_GBs = rdma_only_send_bytes / 1e9 / best_time + nvl_recv_GBs = nvl_recv_bytes / 1e9 / best_time + print(f'[tuning] Best dispatch ({dtype_str}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_GBs:.2f} GB/s (RDMA + NVL), {rdma_only_send_GBs:.2f} (RDMA), {nvl_recv_GBs:.2f} GB/s (NVL) (time: {best_time:.5f} s, cpu_time: {best_cpu_time:.5f} s)') + print() + + if isinstance(current_x, tuple): + if profile: + profile_paddle.push_record_event("Gather_Best_Config") + + # Gather FP8 the best config from rank 0 + best_dispatch_results = paddle.to_tensor([best_results[0], best_results[1], best_results[2]], dtype=paddle.int32) + all_best_fp8_results_list = [paddle.zeros_like(best_dispatch_results) for _ in range(paddle.distributed.get_world_size(group))] + dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) + best_dispatch_results = all_best_fp8_results_list[0].tolist() + + if profile: + profile_paddle.pop_record_event() + + paddle.distributed.barrier(group) + print(f"========================================================================") + + config_str = f"sms={best_dispatch_results[0]},nvl={best_dispatch_results[1]},{nvl_buffer_size},rdma={best_dispatch_results[2]},{rdma_buffer_size}" + if profile: + profile_paddle.push_record_event(f"Best_Dispatch_BF16_Config({config_str})") + + dispatch_config = Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size) + #dispatch_config = Config(24, 20, 512, 32, 128) + + dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config} + + #if local_rank == 0: + # print_tensor_info(x, "x") + # print_tensor_info(num_tokens_per_rank, "num_tokens_per_rank") + # print_tensor_info(num_tokens_per_rdma_rank, "num_tokens_per_rdma_rank") + # print_tensor_info(is_token_in_rank, "is_token_in_rank") + # print_tensor_info(num_tokens_per_expert, "num_tokens_per_expert") + # print(f"-- dispatch_args: {dispatch_args}") + # print(f"-- dispatch_config: {best_dispatch_results[0]}, {best_dispatch_results[1]}, {nvl_buffer_size}, {best_dispatch_results[2]}, {rdma_buffer_size}") + + for i in range(1): + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + if 
profile: + profile_paddle.pop_record_event() + + #if local_rank == 0: + # print_tensor_info(recv_x, "recv_x") + + if profile: + profile_paddle.push_record_event(f"Tune_Combine_BF16") + + # Tune combine performance + best_time, best_cpu_time, best_results = 1e10, 1e10, None + for nvl_chunk_size in range(1, 5, 1): + for rdma_chunk_size in range(8, 33, 4): + config_str = f"sms={num_sms},nvl={nvl_chunk_size},{nvl_buffer_size},rdma={rdma_chunk_size},{rdma_buffer_size}" + if profile: + profile_paddle.push_record_event(f"Combine_BF16_Config({config_str})") + + config = Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) + tune_args = {'x': recv_x, 'handle': handle, 'config': config} + result_times = bench(group, lambda: buffer.combine(**tune_args)) + t = result_times[0] + cpu_t = result_times[3] + + if profile: + profile_paddle.pop_record_event() + + if local_rank == 0: + combine_bf16_rdma_recv_GBs = combine_bf16_rdma_recv_bytes / 1e9 / t + combine_bf16_rdma_only_recv_GBs = combine_bf16_rdma_only_recv_bytes / 1e9 / t + combine_bf16_nvl_send_GBs = combine_bf16_nvl_send_bytes / 1e9 / t + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_GBs:.2f} GB/s (RDMA + NVL), {combine_bf16_rdma_only_recv_GBs:.2f} GB/s (RDMA), {combine_bf16_nvl_send_GBs:.2f} GB/s (NVL) (time: {t:.5f} s, cpu_time: {cpu_t:.5f} s)') + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_cpu_time = cpu_t + + if profile: + profile_paddle.pop_record_event() + + if profile: + profile_paddle.switch_profile(1, 0, 1) + + if local_rank == 0: + combine_bf16_rdma_recv_GBs = combine_bf16_rdma_recv_bytes / 1e9 / best_time + combine_bf16_rdma_only_recv_GBs = combine_bf16_rdma_only_recv_bytes / 1e9 / best_time + combine_bf16_nvl_send_GBs = combine_bf16_nvl_send_bytes / 1e9 / best_time + print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_GBs:.2f} GB/s (RDMA + NVL), {combine_bf16_rdma_only_recv_GBs:.2f} GB/s (RDMA), {combine_bf16_nvl_send_GBs:.2f} GB/s (NVL) (time: {best_time:.5f} s, cpu_time: {best_cpu_time:.5f} s)') + print() + + +# noinspection PyUnboundLocalVariable +def test_loop(num_local_ranks): + # Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels, + # rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + test_ll_compatibility = False + if test_ll_compatibility: + ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + + hcg = fleet.get_hybrid_communicate_group() + ep_group = hcg.get_model_parallel_group() + + buffer = deep_ep.Buffer(ep_group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + + num_ranks = dist.get_world_size(ep_group) + rank = dist.get_rank(ep_group) + + num_nodes = int(num_ranks / 8) + local_rank = rank % 8 + print(f'local_rank:{local_rank}, num_local_ranks:{num_local_ranks}, num_ranks:{num_ranks}, rank:{rank}') + + assert num_local_ranks == 8 and num_ranks > 8 + paddle.seed(rank) + + use_random_input = True + dump_input = False + dump_output = False + tune_performance = True + + for i in (24, ): + test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, ep_group, use_random_input, dump_input, dump_output, tune_performance) + if local_rank == 0: + print() + + # Test compatibility with low latency functions + if 
test_ll_compatibility: + buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) + test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) + + +if __name__ == '__main__': + num_processes = 8 + #torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) + world_size = int(os.getenv('WORLD_SIZE', 1)) + mp_degree = world_size * num_processes + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "mp_degree": mp_degree, + } + fleet.init(is_collective=True, strategy=strategy) + test_loop(num_processes) diff --git a/tests_paddle/test_internode_latency.py b/tests_paddle/test_internode_latency.py new file mode 100644 index 00000000..29343bf0 --- /dev/null +++ b/tests_paddle/test_internode_latency.py @@ -0,0 +1,252 @@ +import os +import sys +import time +import numpy as np + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.distributed.communication.group import Group +import random + +import fused_a2a +import utils + +try: + from paperf import profile_paddle + has_paperf = True +except ImportError: + has_paperf = False + + +profile = False +profile = profile and has_paperf + + +def print_tensor_info(t, name): + #print(f"-- {name}: data_ptr={t.untyped_storage().data_ptr()}, shape={t.size()}, dtype={t.dtype}") + print(f"-- {name}: shape={t.size()}, dtype={t.dtype}") + + +def init_random_tensors(rank, num_nodes, num_tokens, hidden, num_topk_groups, num_topk, num_experts, dump_input=False): + # Random data + x = paddle.ones(shape=[num_tokens, hidden], dtype=paddle.bfloat16) * rank + x_pure_rand = paddle.randn(shape=[num_tokens, hidden], dtype=paddle.bfloat16) + x_e4m3 = utils.per_token_cast_to_fp8(x) + + scores = paddle.randn(shape=[num_tokens, num_experts], dtype=paddle.float32).abs() + 1 + group_scores = scores.view([num_tokens, num_nodes, -1]).amax(axis=-1) + group_idx = paddle.topk(group_scores, num_topk_groups, axis=-1, sorted=False)[1] + masked_scores = utils.create_grouped_scores(scores, group_idx, num_nodes) + + topk_idx = paddle.topk(masked_scores, num_topk, axis=-1, largest=True, sorted=False)[1] + topk_weights = paddle.ones(shape=[num_tokens, num_topk], dtype=paddle.float32) * rank + topk_weights_pure_rand = paddle.randn(shape=[num_tokens, num_topk], dtype=paddle.float32) + + if dump_input: + utils.dump(x, 'x', rank) + utils.dump(x_pure_rand, 'x_pure_rand', rank) + utils.dump(x_e4m3, 'x_e4m3', rank) + + utils.dump(topk_idx, 'topk_idx', rank) + utils.dump(topk_weights, 'topk_weights', rank) + utils.dump(topk_weights_pure_rand, 'topk_weights_pure_rand', rank) + + return x, x_pure_rand, x_e4m3, topk_idx, topk_weights, topk_weights_pure_rand + + +def load_dumped_tensors(rank, num_tokens, hidden, num_topk_groups, num_topk, num_experts): + # x = utils.load("x", local_rank) + # x_pure_rand = utils.load("x_pure_rand", local_rank) + # #x_e4m3 = utils.load("x_e4m3", local_rank, "tuple") + + # topk_idx = utils.load("topk_idx", local_rank) + # topk_weights = utils.load("topk_weights", local_rank) + # topk_weights_pure_rand = utils.load("topk_weights_pure_rand", local_rank) + + def _load_tensor(rank, name, idx, typehint="tensor"): + dump_dir = "/root/paddlejob/workspace/env_run/liuyiqun/outputs/ds_8nodes" + filename = f"{dump_dir}/{idx}_{name}_rank{rank}.npy" + if typehint == "tensor": + x_np = np.load(filename) + if x_np.dtype == np.uint16: + x = paddle.to_tensor(x_np).view(paddle.bfloat16) + elif x_np.dtype in 
[np.float32, np.int32, np.int64, np.int8]: + x = paddle.to_tensor(x_np) + else: + assert False, f'{name}: {x_np.dtype}' + return x + else: + assert False, f'invalid typehint: {typehint}' + + input_tensors = [] + for i in range(20): + x = _load_tensor(rank, "dispatch_x", i + 1) + topk_idx = _load_tensor(rank, "topk_idx", i + 1) + topk_weights = _load_tensor(rank, "topk_weights", i + 1) + input_tensors.append({"x": x, "topk_idx": topk_idx, "topk_weights": topk_weights}) + + return input_tensors + + +def test_main(local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, group: Group, use_random_input, dump_input): + # Settings + num_tokens = 4096 + hidden = 7168 + num_topk_groups = min(num_nodes, 4) + num_topk = 8 + num_experts = (256 // num_ranks) * num_ranks + + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) + + if profile: + profile_paddle.push_record_event("init_input_tensors") + if use_random_input: + x, x_pure_rand, x_e4m3, topk_idx, topk_weights, topk_weights_pure_rand = init_random_tensors(rank, num_nodes, num_tokens, hidden, num_topk_groups, num_topk, num_experts, dump_input) + else: + input_tensors = load_dumped_tensors(rank, num_tokens, hidden, num_topk_groups, num_topk, num_experts) + if profile: + profile_paddle.pop_record_event() + + #paddle.distributed.barrier() + + # test bfloat16 + buffer = fused_a2a.get_buffer(group, hidden * 2) + + #paddle.distributed.barrier() + + num_warmups = 100 + num_tests = 1000 + + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + + for i in range(num_warmups + num_tests + 1): + #if i == 1: + # paddle.distributed.barrier() + + if profile: + profile_paddle.push_record_event(f"test_{i}") + + if not use_random_input: + inputs_i = input_tensors[i % 20] + x = inputs_i["x"] + topk_idx = inputs_i["topk_idx"] + topk_weights = inputs_i["topk_weights"] + + if i == num_warmups: + paddle.distributed.barrier(group) + paddle.device.synchronize() + cpu_start = time.time() + + if i >= num_warmups and i < num_warmups + num_tests: + # Record + batch_start = time.time() + start_events[i - num_warmups].record() + + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, dispatch_event = fused_a2a.fused_dispatch_forward_func( + x=x, + token_indices=topk_idx, + token_probs=topk_weights, + num_experts=num_experts, + group=group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False + ) + + # combined_x, event = fused_a2a.fused_combine_forward_func( + # x=recv_x, + # handle=handle, + # previous_event=None, + # async_finish=False, + # allocate_on_comm_stream=False + # ) + + if i >= num_warmups and i < num_warmups + num_tests: + end_events[i - num_warmups].record() + + if i < num_warmups + 20: + random_time = random.uniform(0.02, 0.08) + time.sleep(random_time) + else: + time.sleep(0.02) + + if profile: + profile_paddle.pop_record_event() + if i >= num_warmups and i < num_warmups + num_tests: + batch_cpu_time = time.time() - batch_start + #if local_rank == 0: + # print(f"-- {i - num_warmups}-th running, cpu_time: {batch_cpu_time:.5f} s") + + paddle.distributed.barrier(group) + paddle.device.synchronize() + + cpu_runtime = time.time() - cpu_start + avg_cpu_time = cpu_runtime / num_tests + + gpu_times = 
np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] + avg_gpu_time = np.average(gpu_times) + max_gpu_time = np.max(gpu_times) + min_gpu_time = np.min(gpu_times) + + print(f"-- rank: {rank}, avg_cpu_time: {avg_cpu_time:.5f} s; gpu_time: avg={avg_gpu_time:.5f} s, max={max_gpu_time:.5f} s, min={min_gpu_time:.5f} s") + + paddle.distributed.barrier(group) + paddle.device.synchronize() + + avg_cpu_time_all_ranks = [] + avg_gpu_time_all_ranks = [] + max_gpu_time_all_ranks = [] + min_gpu_time_all_ranks = [] + dist.all_gather_object(avg_cpu_time_all_ranks, avg_cpu_time, group=group) + dist.all_gather_object(avg_gpu_time_all_ranks, avg_gpu_time, group=group) + dist.all_gather_object(max_gpu_time_all_ranks, max_gpu_time, group=group) + dist.all_gather_object(min_gpu_time_all_ranks, min_gpu_time, group=group) + if rank == 0: + avg_cpu_time = np.average(np.array(avg_cpu_time_all_ranks)) + avg_gpu_time = np.average(np.array(avg_gpu_time_all_ranks)) + max_gpu_time = np.average(np.array(max_gpu_time_all_ranks)) + min_gpu_time = np.average(np.array(min_gpu_time_all_ranks)) + print(f"-- avg_cpu_time_of_all_ranks: {avg_cpu_time:.5f} s; gpu_time_of_all_ranks: avg={avg_gpu_time:.5f} s, max={max_gpu_time:.5f} s, min={min_gpu_time:.5f} s") + + +def test_loop(num_local_ranks: int): + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + + num_ranks = dist.get_world_size(group) + rank = dist.get_rank(group) + + num_nodes = int(num_ranks / 8) + local_rank = rank % 8 + print(f'local_rank:{local_rank}, num_local_ranks:{num_local_ranks}, num_ranks:{num_ranks}, rank:{rank}') + + assert num_local_ranks == 8 and num_ranks > 8 + paddle.seed(rank) + + use_random_input = False + dump_input = False + + print(f"-- profile: {profile}") + if profile: + profile_paddle.switch_profile(0, 0, 1) + + test_main(local_rank, num_local_ranks, num_ranks, num_nodes, rank, group, use_random_input, dump_input) + + if profile: + profile_paddle.switch_profile(1, 0, 1) + +if __name__ == '__main__': + num_processes = 8 + #torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) + world_size = int(os.getenv('WORLD_SIZE', 1)) + mp_degree = world_size * num_processes + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "mp_degree": mp_degree, + } + fleet.init(is_collective=True, strategy=strategy) + test_loop(num_processes) diff --git a/tests_paddle/test_intranode.py b/tests_paddle/test_intranode.py new file mode 100644 index 00000000..6f34b378 --- /dev/null +++ b/tests_paddle/test_intranode.py @@ -0,0 +1,325 @@ +import os +import time +import paddle +import paddle.distributed.fleet as fleet +import paddle.distributed as dist +# from paddle.distributed import deep_ep +import paddle.distributed.communication.deep_ep as deep_ep +from paddle.distributed.communication.group import Group +from paddle.base.core import Config +import numpy as np + +# noinspection PyUnresolvedReferences +# import deep_ep +# from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back + +# Test compatibility with low latency functions +# import test_low_latency +try: + from paperf import profile_paddle + has_paperf = True +except ImportError: + has_paperf = False + + +def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): + # Flush L2 cache with 256 MB data + paddle.device.cuda.synchronize() + cache = paddle.empty([int(256e6 // 4)], dtype=paddle.int32) + + # Warmup + for _ in range(num_warmups): + fn() + 
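+    # NOTE: zeroing the 256 MB scratch buffer below overwrites the GPU L2 cache after the warmup calls, so the timed iterations do not start with data that the warmup left in cache.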
+ # Flush L2 + cache.zero_() + + # Testing + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + # Record + start_events[i].record() + fn() + end_events[i].record() + if post_fn is not None: + post_fn() + paddle.device.cuda.synchronize() + + times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] + return np.average(times), np.min(times), np.max(times) + +def per_token_cast_to_fp8(x: paddle.Tensor): + assert x.dim() == 2 and x.shape[1] % 128 == 0 + # m, n = x.shape + m = x.shape[0] + n = x.shape[1] + x_view = x.view([m, -1, 128]) + x_amax = x_view.abs().cast(paddle.float32).amax(axis=2).view([m, -1]).clip(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).cast(paddle.float8_e4m3fn).view([m, n]), (x_amax / 448.0).view([m, -1]) + + +def per_token_cast_back(x_fp8: paddle.Tensor, x_scales: paddle.Tensor): + x_fp32 = x_fp8.cast(paddle.float32).view([x_fp8.shape[0], -1, 128]) + x_scales = x_scales.view([x_fp8.shape[0], -1, 1]) + return (x_fp32 * x_scales).view(x_fp8.shape).cast(paddle.bfloat16) + +def inplace_unique(x: paddle.Tensor, num_slots: int): + assert x.dim() == 2 + mask = x < 0 + x_padded = x.masked_fill(mask, num_slots) + bin_count = paddle.zeros([x.shape[0], num_slots + 1], dtype=x.dtype).to(x.place) + # bin_count.scatter_add_(1, x_padded, paddle.ones_like(x_padded)) + bin_count.put_along_axis_(axis=1, indices=x_padded, values=paddle.ones_like(x_padded), reduce='add', include_self=True) + + bin_count = bin_count[:, :num_slots] + sorted_bin_count = paddle.sort(bin_count, axis=-1, descending=True) + sorted_bin_idx = paddle.argsort(bin_count, axis=-1, descending=True) + sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) + sorted_bin_idx = paddle.sort(sorted_bin_idx, descending=True, axis=-1) + x[:, :].fill_(-1) + valid_len = min(num_slots, x.shape[1]) + x[:, :valid_len] = sorted_bin_idx[:, :valid_len] + +def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: Group): + # Settings + num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) + + # Random data + x = paddle.ones(shape=[num_tokens, hidden], dtype=paddle.bfloat16) * rank + x_pure_rand = paddle.randn(shape=[num_tokens, hidden], dtype=paddle.bfloat16) + x_e4m3 = per_token_cast_to_fp8(x) + scores = paddle.randn([num_tokens, num_experts], dtype=paddle.float32).abs() + 1 + topk_idx = paddle.topk(scores, num_topk, axis=-1, largest=True, sorted=False)[1] + topk_weights = paddle.ones([num_tokens, num_topk], dtype=paddle.float32) * rank + topk_weights_pure_rand = paddle.randn([num_tokens, num_topk], dtype=paddle.float32) + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + + # Expert meta + num_tokens_per_expert = paddle.zeros([num_experts, ], dtype=paddle.int32) + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = paddle.empty([num_ranks, ], dtype=paddle.int32) + token_idx_in_rank = 
paddle.full((num_ranks, num_tokens), -1, dtype=paddle.int64) + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).cast(paddle.int32).max(axis=-1) + count = token_sel.sum().item() + tokens = paddle.argsort(token_sel.cast(paddle.int32), descending=True) + tokens[:count] = paddle.sort(tokens[:count]) + token_idx_in_rank[i][tokens[:count]] = paddle.arange(count, dtype=paddle.int64) + token_idx_in_rank = token_idx_in_rank.t().contiguous().cast(paddle.int32) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ + buffer.get_dispatch_layout(topk_idx, num_experts) + + assert paddle.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert paddle.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert paddle.allclose(ref_is_token_in_rank, is_token_in_rank) + + t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) + print() + paddle.distributed.barrier(group) + time.sleep(1) + + # Config + nvl_buffer_size = 256 + config = Config(num_sms, 8, nvl_buffer_size) + + # Test dispatch + # noinspection PyShadowingNames + def check_data(check_x, rank_prefix_matrix): + assert paddle.allclose(check_x.amin(axis=1), check_x.amax(axis=1)) + check_start = 0 + for i in range(num_ranks): + check_end = rank_prefix_matrix[i][rank].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + wsm_a = paddle.randn([num_tokens, hidden], dtype=paddle.bfloat16) + wsm_b = paddle.randn([hidden, num_tokens], dtype=paddle.bfloat16) + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x in (x_pure_rand, x, x_e4m3): + paddle.base.core.nvprof_nvtx_push("gemm") + wsm_out = paddle.matmul(wsm_a, wsm_b) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("wsm_all2all") + wsm_all2all_out = paddle.empty_like(wsm_out) + paddle.distributed.alltoall_single(wsm_all2all_out, wsm_out, group=group, sync_op=False) + paddle.base.core.nvprof_nvtx_pop() + + for with_topk in (False, True): + if local_rank == 0: + print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') + dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + if with_topk: + dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + # Checks + rank_prefix_matrix = handle[0] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.shape[0], f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.shape[0]}' + assert gbl_num_tokens_per_expert.view([num_ranks, -1])[rank].tolist() == recv_num_tokens_per_expert_list + if 
current_x is not x_pure_rand: + pass + # check_data(recv_x, rank_prefix_matrix) + if with_topk: + # Check `topk_idx` + assert (recv_topk_idx.equal(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.equal(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.equal(-1)] = recv_topk_weights.amax(axis=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.equal(-1)] + # check_data(recv_topk_weights, rank_prefix_matrix) + + # Test cached dispatch (must without top-k staffs) + # NOTES: handle must be refreshed + if not with_topk: + dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + if current_x is not x_pure_rand: + pass + # check_data(recv_x, rank_prefix_matrix) + + # Test combine + combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} + if with_topk: + combine_args.update({'topk_weights': recv_topk_weights}) + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine(**combine_args) + event.current_stream_wait() if async_mode else () + # check_x = combined_x.cast(paddle.float32) / is_token_in_rank.sum(axis=1).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + # assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + pass + # check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(axis=1).unsqueeze(1)) + # ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + # assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + + if local_rank == 0: + print(' passed', flush=True) + if local_rank == 0: + print() + + # Tune dispatch performance + best_dispatch_results = None + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + best_time, best_results = 1e10, None + nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes + for nvl_chunk_size in range(4, 33, 4): + config = Config(num_sms, nvl_chunk_size, nvl_buffer_size) + tune_args = {'x': current_x, 'handle': handle, 'config': config} + t = bench(lambda: buffer.dispatch(**tune_args))[0] + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size) + if local_rank == 0: + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ') + if local_rank == 0: + print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)') + print() + + if isinstance(current_x, tuple): + # Gather FP8 the best config from rank 0 + best_dispatch_results = paddle.to_tensor([best_results[0], best_results[1]], dtype=paddle.int32) + all_best_fp8_results_list = [paddle.zeros_like(best_dispatch_results) 
for _ in range(paddle.distributed.get_world_size(group))] + dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) + best_dispatch_results = all_best_fp8_results_list[0].tolist() + dispatch_config = Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size) + + dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, + 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config} + + for i in range(1): + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + # Tune combine performance + best_time, best_results = 1e10, None + for nvl_chunk_size in range(1, 5, 1): + config = Config(num_sms, nvl_chunk_size, nvl_buffer_size) + tune_args = {'x': recv_x, 'handle': handle, 'config': config} + t = bench(lambda: buffer.combine(**tune_args))[0] + if local_rank == 0: + print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ') + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size) + + if local_rank == 0: + print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)') + print() + + +# noinspection PyUnboundLocalVariable +def test_loop(): + # rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + test_ll_compatibility, num_rdma_bytes = False, 0 + if test_ll_compatibility: + ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) + + hcg = fleet.get_hybrid_communicate_group() + ep_group = hcg.get_model_parallel_group() + buffer = deep_ep.Buffer(ep_group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + + local_rank = dist.get_rank(ep_group) + num_local_ranks = dist.get_world_size(ep_group) + rank = local_rank + num_ranks = num_local_ranks + print(f'local_rank:{local_rank}, num_local_ranks:{num_local_ranks}, num_ranks:{num_ranks}, rank:{rank}') + + paddle.seed(rank) + + for i in (24, ): + test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, ep_group) + if local_rank == 0: + print() + + # Test compatibility with low latency functions + if test_ll_compatibility: + buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) + test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) + + +if __name__ == '__main__': + mp_degree = 8 + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "mp_degree": mp_degree, + } + fleet.init(is_collective=True, strategy=strategy) + test_loop() diff --git a/tests_paddle/test_low_latency.py b/tests_paddle/test_low_latency.py new file mode 100644 index 00000000..0b50c163 --- /dev/null +++ b/tests_paddle/test_low_latency.py @@ -0,0 +1,160 @@ +import random +import torch +import torch.distributed as dist +from functools import partial + +import deep_ep +from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back + + +def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, + rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0): + torch.manual_seed(seed + rank) + random.seed(seed + rank) + + assert num_experts % 
num_ranks == 0 + num_local_experts = num_experts // num_ranks + + # NOTES: the integers greater than 256 exceeds the BF16 precision limit + rank_offset = 128 + assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' + + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) + x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] + topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() + + # Randomly mask some positions + for i in range(10): + topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 + + # Check dispatch correctness + do_check = True + hash_value, num_times = 0, 0 + for return_recv_hook in (False, True): + num_times += 1 + for i in range((num_times % 2) + 1): + packed_recv_x, packed_recv_count, handle, event, hook = \ + buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, + async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) + hook() if return_recv_hook else event.current_stream_wait() + packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) + simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) + all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') + dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) + for i in range(num_local_experts if do_check else 0): + expert_id = rank * num_local_experts + i + recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) + recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] + + # Check expert indices + int_mask = (2 ** 32) - 1 + num_valid_tokens = recv_count.item() + assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' + assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' + + # Check received data + recv_x = recv_x[:num_valid_tokens] + recv_x_amin = recv_x[:, :-128].amin(dim=-1) + recv_src_info = recv_src_info[:num_valid_tokens] + assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) + assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 + for j in range(num_ranks): + begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() + assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() + assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 + hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) + hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) + + # Check combine correctness + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, + async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) + hook() if return_recv_hook else event.current_stream_wait() + if do_check: + diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) + assert torch.isnan(combined_x).sum().item() == 0 + assert diff < 1e-5, f'Error: diff={diff}' 
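+        # Fold the combined output into the running hash so that repeated runs with the same seed can be checked for bit-wise reproducibility by the pressure test in `test_loop`.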
+
+    def create_test_cast_with_outliers(num_outliers):
+        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
+        assert tmp.abs().amax().item() <= 1
+
+        # Create some amax outliers
+        for i in range(num_outliers):
+            tmp[random.randint(0, num_tokens - 1)] *= 1e3
+        return tmp
+
+    # noinspection PyShadowingNames
+    def large_gemm_with_hook(hook):
+        # NOTES: the matrices must live on the GPU, otherwise the matmul runs on the CPU
+        # and there is no device work to overlap with the pending receive hook
+        mat_0 = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+        mat_1 = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+        mat_0 @ mat_1
+        hook()
+
+    # noinspection PyShadowingNames
+    def test_func(return_recv_hook):
+        recv_x, recv_count, handle, event, hook = \
+            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                        async_finish=False, return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+                                                             return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+
+    # Calculate bandwidth
+    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
+    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
+    for i in range(num_tokens):
+        num_selections = (topk_idx[i] != -1).sum().item()
+        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
+        num_combine_comm_bytes += num_bf16_bytes * num_selections
+
+    # Dispatch + combine testing
+    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
+    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
+          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
+
+    # Separate profiling
+    for return_recv_hook in (False, True):
+        group.barrier()
+        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
+                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
+                                             suppress_kineto_output=True)
+        if not return_recv_hook:
+            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
+                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
+        else:
+            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
+                  f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')
+
+    return hash_value
+
+
+# noinspection PyUnboundLocalVariable
+def test_loop(local_rank: int, num_local_ranks: int):
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
+
+    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
+    if local_rank == 0:
+        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
+    buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
+                            num_qps_per_rank=num_experts // num_ranks)
+    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
+
+    do_pressure_test = False
+    for seed in range(int(1e9) if do_pressure_test else 0):
+        if local_rank == 0:
+            print(f'Testing with seed {seed} ...', flush=True)
+        ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
+        for i in range(20):
+            assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks,
+                             group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
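The bandwidth numbers printed above come from a simple per-token byte count: a dispatched token carries `hidden` FP8 bytes plus one FP32 scale per 128 channels plus a small fixed overhead (16 bytes in this test), while a combined token carries `hidden` BF16 bytes. A standalone sketch of that accounting (the function name and the sample timing are illustrative only):

    def estimated_gbps(hidden: int, num_valid_selections: int, elapsed_s: float) -> float:
        # Mirrors the byte accounting in `test_main` above
        num_fp8_bytes = hidden + hidden // 128 * 4 + 16   # FP8 payload + FP32 scales + fixed overhead
        num_bf16_bytes = hidden * 2                       # BF16 payload on the combine path
        return (num_fp8_bytes + num_bf16_bytes) * num_valid_selections / 1e9 / elapsed_s

    # e.g. 128 tokens, 8 experts selected per token, measured at 50 us end-to-end
    print(f'{estimated_gbps(7168, 128 * 8, 50e-6):.2f} GB/s')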
+
+
+if __name__ == '__main__':
+    # TODO: you may modify NUMA binding for less CPU overhead
+    num_processes = 8
+    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
diff --git a/tests_paddle/utils.py b/tests_paddle/utils.py
new file mode 100644
index 00000000..ceccd26c
--- /dev/null
+++ b/tests_paddle/utils.py
@@ -0,0 +1,293 @@
+import os
+import sys
+import time
+import numpy as np
+import paddle
+import paddle.distributed as dist
+import torch  # NOTES: `bench_kineto` below still relies on `torch.profiler` for CUDA kernel timing
+from typing import Optional
+
+
+def init_dist(num_local_ranks: int):
+    # NOTES: you may rewrite this function with your own cluster settings
+    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
+    port = int(os.getenv('MASTER_PORT', '8361'))
+    num_nodes = int(os.getenv('WORLD_SIZE', 1))
+    node_rank = int(os.getenv('RANK', 0))
+    print(f"num_nodes: {num_nodes}, node_rank: {node_rank}")
+    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
+
+    dist.init_parallel_env()
+    # dist.init_process_group(
+    #     backend='nccl',
+    #     init_method=f'tcp://{ip}:{port}',
+    #     world_size=num_nodes * num_local_ranks,
+    #     rank=node_rank * num_local_ranks + local_rank
+    # )
+
+    rank = dist.get_rank()
+    local_rank = rank % num_local_ranks
+
+    paddle.set_default_dtype(paddle.bfloat16)
+    # paddle.set_default_device('cuda')
+    # paddle.set_device(f"cuda:{local_rank}")
+
+    return local_rank, rank, dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
+
+
+# def calc_diff(x: torch.Tensor, y: torch.Tensor):
+#     x, y = x.double() + 1, y.double() + 1
+#     denominator = (x * x + y * y).sum()
+#     sim = 2 * (x * y).sum() / denominator
+#     return (1 - sim).item()
+
+
+def per_token_cast_to_fp8(x: paddle.Tensor):
+    assert x.dim() == 2 and x.shape[1] % 128 == 0
+    # m, n = x.shape
+    m = x.shape[0]
+    n = x.shape[1]
+    x_view = x.view([m, -1, 128])
+    x_amax = x_view.abs().cast(paddle.float32).amax(axis=2).view([m, -1]).clip(1e-4)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).cast(paddle.float8_e4m3fn).view([m, n]), (x_amax / 448.0).view([m, -1])
+
+
+def per_token_cast_back(x_fp8: paddle.Tensor, x_scales: paddle.Tensor):
+    x_fp32 = x_fp8.cast(paddle.float32).view([x_fp8.shape[0], -1, 128])
+    x_scales = x_scales.view([x_fp8.shape[0], -1, 1])
+    return (x_fp32 * x_scales).view(x_fp8.shape).cast(paddle.bfloat16)
+
+
+def inplace_unique(x: paddle.Tensor, num_slots: int):
+    assert x.dim() == 2
+    mask = x < 0
+    x_padded = x.masked_fill(mask, num_slots)
+    bin_count = paddle.zeros([x.shape[0], num_slots + 1], dtype=x.dtype).to(x.place)
+    # bin_count.scatter_add_(1, x_padded, paddle.ones_like(x_padded))
+    bin_count.put_along_axis_(axis=1, indices=x_padded, values=paddle.ones_like(x_padded), reduce='add', include_self=True)
+
+    bin_count = bin_count[:, :num_slots]
+    sorted_bin_count = paddle.sort(bin_count, axis=-1, descending=True)
+    sorted_bin_idx = paddle.argsort(bin_count, axis=-1, descending=True)
+    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
+    sorted_bin_idx = paddle.sort(sorted_bin_idx, descending=True, axis=-1)
+    x[:, :].fill_(-1)
+    valid_len = min(num_slots, x.shape[1])
+    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
+
+
+def create_grouped_scores(scores: paddle.Tensor, group_idx: paddle.Tensor, num_groups: int):
+    num_tokens, num_experts = scores.shape
+    scores = scores.view([num_tokens, num_groups, -1])
+    mask = paddle.zeros([num_tokens, num_groups], dtype=paddle.int64)
+    # mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
+    mask = mask.put_along_axis_(axis=1, indices=group_idx, values=1).cast(paddle.float32)
+    mask = mask.unsqueeze(-1).expand_as(scores)
+    return (scores * mask).view([num_tokens, num_experts])
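The two casting helpers above can be sanity-checked with a quick round trip; the snippet below is illustrative only and assumes a CUDA build of Paddle with `float8_e4m3fn` support:

    x = paddle.randn([4, 256], dtype=paddle.float32).cast(paddle.bfloat16)
    x_fp8, x_scales = per_token_cast_to_fp8(x)      # FP8 values plus one FP32 scale per 128 channels
    x_back = per_token_cast_back(x_fp8, x_scales)   # dequantized back to BF16
    abs_err = (x_back.cast(paddle.float32) - x.cast(paddle.float32)).abs().mean().item()
    print(f'FP8 round-trip mean abs error: {abs_err:.4f}')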
+
+
+def bench(group, fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
+    # Flush L2 cache with 256 MB data
+    paddle.device.synchronize()
+    cache = paddle.empty([int(256e6 // 4)], dtype=paddle.int32)
+
+    # Warmup
+    for _ in range(num_warmups):
+        fn()
+
+    # Flush L2
+    cache.zero_()
+
+    # Testing
+    start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
+    end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
+
+    paddle.distributed.barrier(group)
+    paddle.device.synchronize()
+
+    cpu_start = time.time()
+    for i in range(num_tests):
+        # Record
+        start_events[i].record()
+        fn()
+        end_events[i].record()
+        if post_fn is not None:
+            post_fn()
+    paddle.device.synchronize()
+    cpu_runtime = time.time() - cpu_start
+
+    times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
+    return np.average(times), np.min(times), np.max(times), cpu_runtime / num_tests
+
+
+class empty_suppress:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *_):
+        pass
+
+
+class suppress_stdout_stderr:
+    def __enter__(self):
+        self.outnull_file = open(os.devnull, 'w')
+        self.errnull_file = open(os.devnull, 'w')
+
+        self.old_stdout_fileno_undup = sys.stdout.fileno()
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+
+        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+        self.old_stdout = sys.stdout
+        self.old_stderr = sys.stderr
+
+        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+
+        sys.stdout = self.outnull_file
+        sys.stderr = self.errnull_file
+        return self
+
+    def __exit__(self, *_):
+        sys.stdout = self.old_stdout
+        sys.stderr = self.old_stderr
+
+        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+        os.close(self.old_stdout_fileno)
+        os.close(self.old_stderr_fileno)
+
+        self.outnull_file.close()
+        self.errnull_file.close()
+
+
+def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
+                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
+    # Profile
+    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
+    with suppress():
+        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
+        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
+            for i in range(2):
+                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
+                if barrier_comm_profiling:
+                    # NOTES: `dist` is `paddle.distributed` in this file, so the barrier tensors must be Paddle tensors
+                    lhs = paddle.randn([8192, 8192], dtype=paddle.float32)
+                    rhs = paddle.randn([8192, 8192], dtype=paddle.float32)
+                    lhs @ rhs
+                    dist.all_reduce(paddle.ones([1], dtype=paddle.float32))
+                for _ in range(num_tests):
+                    fn()
+                prof.step()
+
+    # Parse the profiling table
+    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
+    is_tupled = isinstance(kernel_names, tuple)
+    prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
+    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
+    assert all([isinstance(name, str) for name in kernel_names])
+    for name in kernel_names:
+        assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table'
+
+    # Save chrome traces
+    if trace_path is not None:
+        prof.export_chrome_trace(trace_path)
+
+    # Return average kernel times
+    units = {'ms': 1e3, 'us': 1e6}
+    kernel_times = []
+    for name in kernel_names:
+        for line in prof_lines:
+            if name in line:
+                time_str = line.split()[-2]
+                for unit, scale in units.items():
+                    if unit in time_str:
+                        kernel_times.append(float(time_str.replace(unit, '')) / scale)
+                        break
+                break
+    return tuple(kernel_times) if is_tupled else kernel_times[0]
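A usage sketch for `bench` above; it assumes the process was launched with `paddle.distributed.launch` across 8 local ranks so that `init_dist` can build a group, and the matrix size is arbitrary:

    local_rank, rank, num_ranks, group = init_dist(num_local_ranks=8)
    a = paddle.randn([4096, 4096], dtype=paddle.float32)
    avg_t, min_t, max_t, cpu_t = bench(group, lambda: a @ a)
    print(f'[rank {rank}] matmul: avg {avg_t * 1e6:.1f} us, min {min_t * 1e6:.1f} us, max {max_t * 1e6:.1f} us')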
+
+
+dtype2str = {
+    paddle.float32: "_orgi_fp32",
+    paddle.int32: "_orgi_int",
+    paddle.int64: "_orgi_int64",
+    paddle.bool: "_orgi_bool",
+    paddle.bfloat16: "_orgi_bf16",
+    paddle.float8_e4m3fn: "_orgi_fp8"
+}
+
+
+def dump(x, name, local_rank):
+    dump_dir = "/root/paddlejob/workspace/env_run/liuyiqun/outputs/paddle_dump"
+    name = dump_dir + "/" + name
+    if isinstance(x, paddle.Tensor):
+        if x.dtype == paddle.float32 or x.dtype == paddle.int32 or x.dtype == paddle.int64 or x.dtype == paddle.bool:
+            y = x.numpy()
+        elif x.dtype == paddle.bfloat16:
+            y = x.view('uint16').numpy()
+        elif x.dtype == paddle.float8_e4m3fn:
+            y = x.view(paddle.uint8).numpy()
+        else:
+            assert False, f'{name}: {x.dtype} {x}'
+        # name += dtype2str[x.dtype]
+        np.save(f"{name}_rank{local_rank}.npy", y)
+    elif isinstance(x, tuple):
+        y, y_scale = x
+        assert y.dtype == paddle.float8_e4m3fn
+        assert y_scale.dtype == paddle.float32
+        y_dump = y.view(paddle.uint8).numpy()
+        y_scale_dump = y_scale.numpy()
+        # np.save(f"{name}{dtype2str[y.dtype]}_value_rank{local_rank}.npy", y_dump)
+        # np.save(f"{name}{dtype2str[y_scale.dtype]}_scale_rank{local_rank}.npy", y_scale_dump)
+        np.save(f"{name}_value_rank{local_rank}.npy", y_dump)
+        np.save(f"{name}_scale_rank{local_rank}.npy", y_scale_dump)
+    elif isinstance(x, list):
+        y = np.asarray(x)
+        np.save(f"{name}_rank{local_rank}.npy", y)
+    elif x is None:
+        np.save(f"{name}_rank{local_rank}.npy", np.zeros(5))
+    else:
+        assert False, f'{name}: {x}'
+
+
+def retrive_dtype(name):
+    # NOTES: check the longer tag first, otherwise "_orgi_int64" would also match "_orgi_int"
+    if "_orgi_fp32" in name:
+        return paddle.float32
+    elif "_orgi_int64" in name:
+        return paddle.int64
+    elif "_orgi_int" in name:
+        return paddle.int32
+    elif "_orgi_bool" in name:
+        return paddle.bool
+    elif "_orgi_bf16" in name:
+        return paddle.bfloat16
+    elif "_orgi_fp8" in name:
+        return paddle.float8_e4m3fn
+    else:
+        assert False, f"{name} with wrong dtype"
+
+
+def load(name, local_rank, typehint="tensor"):
+    dump_dir = '/root/paddlejob/workspace/env_run/liuyiqun/outputs/torch_dump'
+    name = dump_dir + "/" + name
+    print(f"[local_rank={local_rank}] load {name}")
+    # orig_dtype = retrive_dtype(name)
+    # name += dtype2str[x.dtype]
+    if typehint == "tensor":
+        x_np = np.load(f'{name}_rank{local_rank}.npy')
+        if x_np.dtype == np.uint16:
+            x = paddle.to_tensor(x_np).view(paddle.bfloat16)
+        elif x_np.dtype == np.uint8:
+            x = paddle.to_tensor(x_np).view(paddle.float8_e4m3fn)
+        else:
+            x = paddle.to_tensor(x_np)
+        return x
+    elif typehint == "tuple":
+        y_np = np.load(f'{name}_value_rank{local_rank}.npy')
+        y_scale_np = np.load(f'{name}_scale_rank{local_rank}.npy')
+        y = paddle.to_tensor(y_np).view(paddle.float8_e4m3fn)
+        y_scale = paddle.to_tensor(y_scale_np)
+        return (y, y_scale)
+    else:
+        assert False, f'invalid typehint: {typehint}'
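Note that `dump` writes under the hard-coded `paddle_dump` directory while `load` reads from the separate `torch_dump` directory, so the pair is meant for comparing Paddle tensors against previously dumped PyTorch references rather than for a same-process round trip. A usage sketch (the tensor name is illustrative, and the torch-side file must already exist for the load to succeed):

    y, y_scales = per_token_cast_to_fp8(paddle.randn([4, 256], dtype=paddle.float32).cast(paddle.bfloat16))
    dump((y, y_scales), 'dispatch_x', local_rank=0)   # -> <paddle_dump>/dispatch_x_value_rank0.npy and ..._scale_rank0.npy
    ref = load('dispatch_x', 0, typehint='tuple')     # <- <torch_dump>/dispatch_x_value_rank0.npy and ..._scale_rank0.npy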