146 changes: 146 additions & 0 deletions tests/alltoall.py
@@ -0,0 +1,146 @@
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union

from deep_ep import Buffer, EventOverlap

# Communication buffer (allocated lazily at runtime)
_buffer: Optional[Buffer] = None

# Set the number of SMs to use
# NOTES: this is a static variable
# Buffer.set_num_sms(24)


# You may call this function during framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
global _buffer

    # NOTES: you may also replace `get_*_config` with your own auto-tuned results obtained from the tests
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)

    # Allocate a buffer if none exists yet, or if the existing one is bound to a different group or is too small
if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer
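
# Illustrative usage (added, not part of the original test): a framework would typically
# call `get_buffer` once the EP process group and the hidden size (in bytes) are known, e.g.
#
#   buffer = get_buffer(group, get_hidden_bytes(x))
#
# where `group` is the MoE process group and `x` is a local activation tensor
# (hypothetical names).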


def get_hidden_bytes(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> int:
t = x[0] if isinstance(x, tuple) else x
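    # NOTES (added comment, assumption): the element size is clamped to at least 2 bytes so the
    # size hint still covers BF16-sized buffers (e.g. on the combine side) even when the
    # dispatched tensor is FP8 (1 byte per element)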
return t.size(1) * max(t.element_size(), 2)


def dispatch_forward(
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event: Optional[EventOverlap] = None,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
    # NOTES: the optional `previous_event` is a captured CUDA event that the dispatch kernel
    # will wait on as a dependency; this is useful for communication-computation overlap.
    # For more information, please refer to the docs of `Buffer.dispatch`
global _buffer

# Calculate layout before actual dispatch
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = _buffer.get_dispatch_layout(
topk_idx,
num_experts,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream
)

# Do MoE dispatch
    # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graphs
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = _buffer.dispatch(
x,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream
)

# For event management, please refer to the docs of the `EventOverlap` class
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
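
# Added note (illustration): when `async_finish=True`, the caller should consume the returned
# `EventOverlap` before reading the received tensors, e.g. (assuming the
# `EventOverlap.current_stream_wait()` helper):
#
#   recv_x, _, _, _, handle, event = dispatch_forward(x, topk_idx, topk_weights,
#                                                     num_experts, async_finish=True)
#   event.current_stream_wait()  # the current stream waits for the dispatch kernels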


def dispatch_backward(
grad_recv_x: torch.Tensor,
grad_recv_topk_weights: torch.Tensor,
handle: Tuple,
previous_event: Optional[EventOverlap] = None,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
global _buffer

# The backward process of MoE dispatch is actually a combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_grad_x, combined_grad_recv_topk_weights, event = _buffer.combine(
grad_recv_x,
handle,
topk_weights=grad_recv_topk_weights,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream
)

# For event management, please refer to the docs of the `EventOverlap` class
return combined_grad_x, combined_grad_recv_topk_weights, event


def combine_forward(
x: torch.Tensor,
handle: Tuple,
previous_event: Optional[EventOverlap] = None,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> Tuple[torch.Tensor, EventOverlap]:
global _buffer

# Do MoE combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_x, _, event = _buffer.combine(
x,
handle,
async_finish=async_finish,
previous_event=previous_event,
allocate_on_comm_stream=allocate_on_comm_stream)

# For event management, please refer to the docs of the `EventOverlap` class
return combined_x, event


def combine_backward(
grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Tuple,
previous_event: Optional[EventOverlap] = None,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
global _buffer

# The backward process of MoE combine is actually a dispatch
# For more advanced usages, please refer to the docs of the `dispatch` function
grad_x, _, _, _, _, event = _buffer.dispatch(
grad_combined_x,
handle=handle,
async_finish=async_finish,
previous_event=previous_event,
allocate_on_comm_stream=allocate_on_comm_stream
)

# For event management, please refer to the docs of the `EventOverlap` class
return grad_x, event
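

# Added sketch (not part of the original test): one way to wire the helpers above into an
# autograd-aware MoE dispatch. The class below is made up for illustration; tuple (FP8)
# inputs, async events and the per-expert token counts are intentionally omitted.
class ExampleDispatchFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int):
        recv_x, recv_topk_idx, recv_topk_weights, _, handle, _ = dispatch_forward(
            x, topk_idx, topk_weights, num_experts)
        # Keep the communication handle for the backward pass
        ctx.handle = handle
        return recv_x, recv_topk_idx, recv_topk_weights

    @staticmethod
    def backward(ctx, grad_recv_x: torch.Tensor, grad_recv_topk_idx, grad_recv_topk_weights: torch.Tensor):
        # The backward of a dispatch is a combine over the handle saved in the forward pass
        grad_x, grad_topk_weights, _ = dispatch_backward(grad_recv_x, grad_recv_topk_weights, ctx.handle)
        # No gradients flow to `topk_idx` (integer indices) or `num_experts`
        return grad_x, None, grad_topk_weights, None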
68 changes: 68 additions & 0 deletions tests/run_test_internode.sh
@@ -0,0 +1,68 @@
#!/bin/bash

WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun
export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/torch_py310_yiqun
export PATH=${PYTHONPATH}/bin:${PATH}

export PYTHONPATH=${WORK_ROOT}/PaPerf:$PYTHONPATH

#export NVSHMEM_DIR=$ROOT_DIR/third-party/nvshmem
#export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"

START_RANK=46
END_RANK=54

if [[ ${PADDLE_TRAINER_ID} -lt $START_RANK ]]; then
exit 0
fi

if [[ ${PADDLE_TRAINER_ID} -ge $END_RANK ]]; then
exit 0
fi
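# Only nodes whose PADDLE_TRAINER_ID falls in [START_RANK, END_RANK) take part in the test;
# they are renumbered below as ranks 0..nnodes-1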

rank=$(($PADDLE_TRAINER_ID - $START_RANK))
nnodes=$(($END_RANK - $START_RANK))
echo "rank: ${rank}, nnodes: ${nnodes}"

python -c "import torch; print(torch.__version__)"

#master=`cat /root/paddlejob/workspace/hostfile | head -n 1 | awk '{print $1}'`
export MASTER_ADDR="10.95.238.87" # 46
#master="10.95.238.99" # 48
#master="10.95.237.154" # 32
#master="10.95.244.212" # 8
export MASTER_PORT=8367
export WORLD_SIZE=$nnodes
export RANK=$rank

export NCCL_DEBUG=WARN
#export NVSHMEM_DEBUG=DEBUG
#export NVSHMEM_DEBUG=TRACE

# Settings for cluster stability; not performance-related
export NCCL_IB_QPS_PER_CONNECTION=8
#export NCCL_IB_TIMEOUT=22
export NCCL_IB_GID_INDEX=3
export NCCL_NVLS_ENABLE=0
# Enable adaptive routing (AR)
#export NCCL_IB_ADAPTIVE_ROUTING=1

export NCCL_IB_GID_INDEX=3
export NVSHMEM_IB_GID_INDEX=3
export NVSHMEM_IB_TRAFFIC_CLASS=162

#export NVSHMEM_IB_ENABLE_IBGDA=true
#export NVSHMEM_DISABLE_P2P=1
export NVSHMEM_BOOTSTRAP=UID
export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=xgbe0
#export NVSHMEM_BOOTSTRAP_UID_SOCK_FAMILY=AF_INET

#export NVSHMEM_DEBUG=INFO

export PATH=/opt/nvidia/nsight-systems/2025.1.1/bin:$PATH
#nsys_args="nsys profile --stats true -w true -t cuda,nvtx,cudnn,cublas --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_simple_kernel_${WORLD_SIZE}.torch"
#nsys_args="nsys profile --stats true -w true -t cuda,nvtx --nic-metrics=true --capture-range=cudaProfilerApi -x true --force-overwrite true -o test_internode_${WORLD_SIZE}nodes_rank${RANK}.torch"

rm -rf core.*

${nsys_args} python test_internode_latency.py
12 changes: 12 additions & 0 deletions tests/run_test_intranode.sh
@@ -0,0 +1,12 @@
#!/bin/bash

WORK_ROOT=/root/paddlejob/workspace/env_run/liuyiqun
export PYTHONPATH=${WORK_ROOT}/env/virtualenvs_cuda12.8/torch_py310_yiqun
export PATH=${PYTHONPATH}/bin:${PATH}

export MASTER_ADDR=10.54.98.83
export MASTER_PORT=8364
export WORLD_SIZE=1
export RANK=$PADDLE_TRAINER_ID

python test_intranode.py
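
# NOTE (assumption): test_intranode.py is expected to read MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK
# from the environment, e.g. via
#   torch.distributed.init_process_group(backend="nccl", init_method="env://")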