Skip to content

Commit 0ed0c8c

Browse files
committed
update switch channel
1 parent 25435ac commit 0ed0c8c

15 files changed: +212 additions, −114 deletions

docs/guide/mscclpp-torch-integration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class CustomizedComm:
129129
self._algo_large = [
130130
algo for algo in algorithms
131131
if algo.collective == "allreduce"
132-
and algo.name == "default_allreduce_nvls_with_copy"
132+
and algo.name == "default_allreduce_nvls_warp_pipeline"
133133
][0]
134134

135135
def all_reduce(self, tensor: torch.Tensor, stream=None):

examples/torch-integration/customized_comm_with_default_algo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, comm: mscclpp.CommGroup):
6161
self._algorithm_nvls_nonzero_copy = [
6262
algo
6363
for algo in algorithms
64-
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_with_copy"
64+
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_warp_pipeline"
6565
][0]
6666

6767
def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):

src/core/gpu_ipc_mem.cc

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,13 @@ UniqueGpuIpcMemHandle GpuIpcMemHandle::createMulticast([[maybe_unused]] size_t b
249249
}
250250

251251
if (handle->typeFlags == GpuIpcMemHandle::Type::None) {
252+
cuMemRelease(allocHandle);
252253
THROW(GPU, Error, ErrorCode::SystemError, "createMulticast failed: neither POSIX FD nor FABRIC handle was created");
253254
}
255+
256+
// Release the local allocation handle. The exported POSIX FD / Fabric handle keeps the
257+
// multicast object alive. Each importer will get its own handle via cuMemImportFromShareableHandle.
258+
MSCCLPP_CUTHROW(cuMemRelease(allocHandle));
254259
return handle;
255260
#else // !(CUDA_NVLS_API_AVAILABLE)
256261
THROW(GPU, Error, ErrorCode::InvalidUsage,
@@ -418,41 +423,45 @@ std::shared_ptr<void> GpuIpcMem::mapMulticast([[maybe_unused]] int numDevices, [
418423
// This will block until all devices call cuMulticastAddDevice()
419424
MSCCLPP_CUTHROW(cuMulticastBindAddr(allocHandle_, mcOffset, bufferAddr, bufferSize, 0));
420425

426+
// cuMemMap requires offset to be 0 for multicast handles, so we map the entire range
427+
// [0, mcOffset + bufferSize) and return a pointer at mcPtr + mcOffset. This only consumes
428+
// extra virtual address space for the mcOffset region; no additional physical memory is used.
429+
size_t mapSize = mcOffset + bufferSize;
421430
CUdeviceptr mcPtr;
422-
MSCCLPP_CUTHROW(cuMemAddressReserve(&mcPtr, bufferSize, minMcGran, 0U, 0));
423-
MSCCLPP_CUTHROW(cuMemMap(mcPtr, bufferSize, 0, allocHandle_, 0));
431+
MSCCLPP_CUTHROW(cuMemAddressReserve(&mcPtr, mapSize, minMcGran, 0U, 0));
432+
MSCCLPP_CUTHROW(cuMemMap(mcPtr, mapSize, 0, allocHandle_, 0));
424433

425434
CUmemAccessDesc accessDesc = {};
426435
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
427436
accessDesc.location.id = deviceId;
428437
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
429-
MSCCLPP_CUTHROW(cuMemSetAccess(mcPtr, bufferSize, &accessDesc, 1));
438+
MSCCLPP_CUTHROW(cuMemSetAccess(mcPtr, mapSize, &accessDesc, 1));
430439

431440
// Return shared_ptr with custom deleter that unmaps and unbinds
432441
CUmemGenericAllocationHandle allocHandle = allocHandle_;
433-
return std::shared_ptr<void>(
434-
reinterpret_cast<void*>(mcPtr), [self = shared_from_this(), mcOffset, bufferSize, allocHandle](void* ptr) {
435-
CUresult res;
436-
const char* errStr;
437-
438-
res = cuMemUnmap((CUdeviceptr)ptr, bufferSize);
439-
if (res != CUDA_SUCCESS) {
440-
(void)cuGetErrorString(res, &errStr);
441-
WARN(GPU, "Failed to unmap CUDA memory at pointer ", (void*)ptr, ": ", errStr);
442-
}
443-
444-
res = cuMemAddressFree((CUdeviceptr)ptr, bufferSize);
445-
if (res != CUDA_SUCCESS) {
446-
(void)cuGetErrorString(res, &errStr);
447-
WARN(GPU, "Failed to free CUDA memory at pointer ", (void*)ptr, ": ", errStr);
448-
}
449-
450-
int deviceId;
451-
CUdevice device;
452-
if (cudaGetDevice(&deviceId) == cudaSuccess && cuDeviceGet(&device, deviceId) == CUDA_SUCCESS) {
453-
(void)cuMulticastUnbind(allocHandle, device, mcOffset, bufferSize);
454-
}
455-
});
442+
return std::shared_ptr<void>(reinterpret_cast<void*>(mcPtr + mcOffset), [self = shared_from_this(), mcPtr, mapSize,
443+
mcOffset, bufferSize, allocHandle](void*) {
444+
CUresult res;
445+
const char* errStr;
446+
447+
res = cuMemUnmap(mcPtr, mapSize);
448+
if (res != CUDA_SUCCESS) {
449+
(void)cuGetErrorString(res, &errStr);
450+
WARN(GPU, "Failed to unmap CUDA memory at pointer ", (void*)mcPtr, ": ", errStr);
451+
}
452+
453+
res = cuMemAddressFree(mcPtr, mapSize);
454+
if (res != CUDA_SUCCESS) {
455+
(void)cuGetErrorString(res, &errStr);
456+
WARN(GPU, "Failed to free CUDA memory at pointer ", (void*)mcPtr, ": ", errStr);
457+
}
458+
459+
int deviceId;
460+
CUdevice device;
461+
if (cudaGetDevice(&deviceId) == cudaSuccess && cuDeviceGet(&device, deviceId) == CUDA_SUCCESS) {
462+
(void)cuMulticastUnbind(allocHandle, device, mcOffset, bufferSize);
463+
}
464+
});
456465
#else // !(CUDA_NVLS_API_AVAILABLE)
457466
THROW(GPU, Error, ErrorCode::InvalidUsage,
458467
"NVLS is not supported on this device (requires CUDA version >= 12.3 and Linux kernel version >= 5.6.0)");

src/ext/collectives/algorithm_collection_builder.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#include "allgather/allgather_fullmesh_2.hpp"
99
#include "allreduce/allreduce_allpair_packet.hpp"
1010
#include "allreduce/allreduce_fullmesh.hpp"
11-
#include "allreduce/allreduce_nvls.hpp"
11+
#include "allreduce/allreduce_nvls_zero_copy.hpp"
1212
#include "allreduce/allreduce_nvls_packet.hpp"
13-
#include "allreduce/allreduce_nvls_with_copy.hpp"
14-
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
13+
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
14+
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
1515
#include "allreduce/allreduce_packet.hpp"
1616
#include "allreduce/allreduce_rsag.hpp"
1717
#include "allreduce/allreduce_rsag_pipeline.hpp"
@@ -72,12 +72,14 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin
7272
auto allreduceNvlsPacket =
7373
std::make_shared<AllreduceNvlsPacket>(scratchBuffer, scratchBufferSize, flagBuffer, flagBufferSize)->build();
7474
collection.registerAlgorithm(allreduceNvlsPacket->collective(), allreduceNvlsPacket->name(), allreduceNvlsPacket);
75-
auto allreduceNvlsWithCopy = std::make_shared<AllreduceNvlsWithCopy>(scratchBuffer, scratchBufferSize)->build();
76-
collection.registerAlgorithm(allreduceNvlsWithCopy->collective(), allreduceNvlsWithCopy->name(),
77-
allreduceNvlsWithCopy);
78-
auto allreduceNvlsWithCopy2 = std::make_shared<AllreduceNvlsWithCopy2>(scratchBuffer, scratchBufferSize)->build();
79-
collection.registerAlgorithm(allreduceNvlsWithCopy2->collective(), allreduceNvlsWithCopy2->name(),
80-
allreduceNvlsWithCopy2);
75+
auto allreduceNvlsWarpPipeline =
76+
std::make_shared<AllreduceNvlsWarpPipeline>(scratchBuffer, scratchBufferSize)->build();
77+
collection.registerAlgorithm(allreduceNvlsWarpPipeline->collective(), allreduceNvlsWarpPipeline->name(),
78+
allreduceNvlsWarpPipeline);
79+
auto allreduceNvlsBlockPipeline =
80+
std::make_shared<AllreduceNvlsBlockPipeline>(scratchBuffer, scratchBufferSize)->build();
81+
collection.registerAlgorithm(allreduceNvlsBlockPipeline->collective(), allreduceNvlsBlockPipeline->name(),
82+
allreduceNvlsBlockPipeline);
8183
auto allreducePkt =
8284
std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize, flagBuffer, flagBufferSize)->build();
8385
collection.registerAlgorithm(allreducePkt->collective(), allreducePkt->name(), allreducePkt);

src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu renamed to src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <mscclpp/algorithm.hpp>
55

6-
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
6+
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
77
#include "allreduce/common.hpp"
88
#include "collective_utils.hpp"
99
#include "debug.h"
@@ -15,11 +15,12 @@ __device__ DeviceSemaphore deviceSemaphore[NUM_SEMAPHORES];
1515

1616
template <typename T>
1717
__global__ void __launch_bounds__(1024, 1)
18-
allreduceNvlsWithCopy2([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch, [[maybe_unused]] void* dst,
19-
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
20-
[[maybe_unused]] DeviceHandle<SwitchChannel>* switchChannels, [[maybe_unused]] size_t size,
21-
[[maybe_unused]] size_t scratchBufferSize, [[maybe_unused]] int rank,
22-
[[maybe_unused]] int nRanksPerNode) {
18+
allreduceNvlsBlockPipeline([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch,
19+
[[maybe_unused]] void* dst,
20+
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
21+
[[maybe_unused]] DeviceHandle<SwitchChannel>* switchChannels,
22+
[[maybe_unused]] size_t size, [[maybe_unused]] size_t scratchBufferSize,
23+
[[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerNode) {
2324
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
2425
constexpr int alignment = 16;
2526
int nPeers = nRanksPerNode - 1;
@@ -146,7 +147,7 @@ __global__ void __launch_bounds__(1024, 1)
146147
}
147148

148149
template <ReduceOp OpType, typename T>
149-
struct NvlsWithCopy2Adapter {
150+
struct NvlsBlockPipelineAdapter {
150151
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
151152
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
152153
size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize,
@@ -162,15 +163,15 @@ struct NvlsWithCopy2Adapter {
162163
#endif
163164
{
164165
using ChannelType = DeviceHandle<BaseMemoryChannel>;
165-
allreduceNvlsWithCopy2<T>
166-
<<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
167-
nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode);
166+
allreduceNvlsBlockPipeline<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
167+
input, scratch, output, (ChannelType*)memoryChannels, nvlsChannels, inputSize, scratchBufferSize, rank,
168+
nRanksPerNode);
168169
return cudaGetLastError();
169170
}
170171
}
171172
};
172173

173-
void AllreduceNvlsWithCopy2::initialize(std::shared_ptr<Communicator> comm) {
174+
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
174175
nSwitchChannels_ = 8;
175176
int nBaseChannels = 64;
176177
this->conns_ = setupConnections(comm);
@@ -180,14 +181,15 @@ void AllreduceNvlsWithCopy2::initialize(std::shared_ptr<Communicator> comm) {
180181
// setup base memory channels
181182
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
182183
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
184+
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
183185
}
184186

185-
CommResult AllreduceNvlsWithCopy2::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
186-
void* output, size_t inputSize, DataType dtype, ReduceOp op,
187-
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
188-
const std::unordered_map<std::string, uintptr_t>&) {
187+
CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
188+
void* output, size_t inputSize, DataType dtype, ReduceOp op,
189+
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
190+
const std::unordered_map<std::string, uintptr_t>&) {
189191
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
190-
AllreduceFunc allreduce = dispatch<NvlsWithCopy2Adapter>(op, dtype);
192+
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype);
191193
if (!allreduce) {
192194
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
193195
return CommResult::CommInvalidArgument;
@@ -201,35 +203,35 @@ CommResult AllreduceNvlsWithCopy2::allreduceKernelFunc(const std::shared_ptr<voi
201203
ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, nullptr, 0, 0,
202204
blockAndThreadNum.first, blockAndThreadNum.second);
203205
if (error != cudaSuccess) {
204-
WARN("AllreduceNvlsWithCopy failed with error: %s", cudaGetErrorString(error));
206+
WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error));
205207
return CommResult::CommUnhandledCudaError;
206208
}
207209
return CommResult::CommSuccess;
208210
}
209211

210-
AlgorithmCtxKey AllreduceNvlsWithCopy2::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
212+
AlgorithmCtxKey AllreduceNvlsBlockPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
211213
return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
212214
}
213215

214-
std::shared_ptr<void> AllreduceNvlsWithCopy2::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
215-
void*, size_t, DataType) {
216+
std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm,
217+
const void*, void*, size_t, DataType) {
216218
auto ctx = std::make_shared<AlgorithmCtx>();
217219
ctx->rank = comm->bootstrap()->getRank();
218220
ctx->workSize = comm->bootstrap()->getNranks();
219221
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
220222

221223
// setup channels
222-
ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
223224
ctx->switchChannels =
224-
setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
225+
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
225226
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
226227
return ctx;
227228
}
228229

229-
std::shared_ptr<Algorithm> AllreduceNvlsWithCopy2::build() {
230-
auto self = std::make_shared<AllreduceNvlsWithCopy2>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
230+
std::shared_ptr<Algorithm> AllreduceNvlsBlockPipeline::build() {
231+
auto self =
232+
std::make_shared<AllreduceNvlsBlockPipeline>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
231233
return std::make_shared<NativeAlgorithm>(
232-
"default_allreduce_nvls_with_copy2", "allreduce",
234+
"default_allreduce_nvls_block_pipeline", "allreduce",
233235
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
234236
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
235237
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
@@ -247,4 +249,4 @@ std::shared_ptr<Algorithm> AllreduceNvlsWithCopy2::build() {
247249
}
248250

249251
} // namespace collective
250-
} // namespace mscclpp
252+
} // namespace mscclpp

src/ext/collectives/allreduce/allreduce_nvls_packet.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ struct AllreduceNvlsPacketAdapter {
7575
}
7676
};
7777

78-
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator>) {}
78+
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
79+
int nSwitchChannels = 1;
80+
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
81+
}
7982

8083
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
8184
return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
@@ -90,9 +93,8 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
9093

9194
// setup channels
9295
int nSwitchChannels = 1;
93-
ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
9496
ctx->switchChannels =
95-
setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
97+
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
9698
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
9799
return ctx;
98100
}

0 commit comments

Comments (0)