Skip to content
24 changes: 11 additions & 13 deletions src/executor/execution_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace mscclpp {

template <typename PacketType>
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag) {
switch (dataType) {
case DataType::INT32:
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::UINT32:
executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT16:
executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT32:
executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
}

template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
} // namespace mscclpp
#endif
71 changes: 43 additions & 28 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ auto getOpType = [](const std::string& str) {
return mscclpp::OperationType::WAIT;
} else if (str == "flush") {
return mscclpp::OperationType::FLUSH;
} else if (str == "re") {
} else if (str == "reduce") {
return mscclpp::OperationType::REDUCE;
} else if (str == "rs") {
return mscclpp::OperationType::REDUCE_SEND;
Expand Down Expand Up @@ -176,37 +176,27 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
}

size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {
void ExecutionPlan::Impl::calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag) {
size_t sizePerRank = 0;
if (this->inputChunks.at(rank) != 0)
sizePerRank = inputSize / this->inputChunks.at(rank);
sizePerRank = std::min(inputSize, this->maxMessageSize) / this->inputChunks.at(rank);
else if (this->outputChunks.at(rank) != 0)
sizePerRank = outputSize / this->outputChunks.at(rank);
sizePerRank = std::min(outputSize, this->maxMessageSize) / this->outputChunks.at(rank);
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

this->scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
this->scratchBufferOffset = (this->isUsingDoubleScratchBuffer && (flag % 2) == 0) ? this->scratchBufferSize : 0;
if (this->isUsingPacket) {
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
this->scratchBufferSize *= 2; /* data + flag */
}
return sizePerRank * this->scratchChunks.at(rank);
}

size_t ExecutionPlan::Impl::getMaxScratchBufferSize(int rank) const {
if (this->maxMessageSize == std::numeric_limits<uint64_t>::max()) {
return std::numeric_limits<size_t>::max();
if (this->isUsingDoubleScratchBuffer) {
this->scratchBufferSize *= 2; /* double buffer */
}
size_t sizePerChunk = 0;
if (this->inputChunks.at(rank) != 0)
sizePerChunk = maxMessageSize / this->inputChunks.at(rank);
else if (this->outputChunks.at(rank) != 0)
sizePerChunk = maxMessageSize / this->outputChunks.at(rank);
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

return this->getScratchBufferSize(rank, sizePerChunk * this->inputChunks.at(rank),
sizePerChunk * this->outputChunks.at(rank));
}

size_t ExecutionPlan::Impl::getScratchBufferSize() const { return this->scratchBufferSize; }

std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
return this->operations.at(rank)[threadblock];
}
Expand All @@ -215,8 +205,9 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper

int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -230,6 +221,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
this->inputSize = inputSize;
this->outputSize = outputSize;
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
this->minMessageSize = obj.value("min_message_size", 0);
this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<uint64_t>::max());
this->isInPlace = obj["inplace"];
Expand All @@ -243,11 +235,13 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
this->chunkGroups[rank] = gpu["chunkGroups"];
}
this->setupChannels(gpus);
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -257,6 +251,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
if (protocol == "LL") {
this->isUsingPacket = true;
}
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand All @@ -269,7 +264,8 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::parseChannels(
Expand Down Expand Up @@ -373,6 +369,15 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
}
}

void ExecutionPlan::Impl::checkChannelsPerOperation(int channels) {
if (channels > MAX_CHANNEL_PER_OPERATION) {
throw Error("Executor plan has " + std::to_string(channels) +
" channels per operation, exceeding executor support (" +
std::to_string(MAX_CHANNEL_PER_OPERATION) + ")",
ErrorCode::ExecutorError);
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) {
auto getConstOffset = [&](BufferType type) -> size_t {
switch (type) {
Expand All @@ -381,7 +386,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
case BufferType::OUTPUT:
return constDstOffset;
case BufferType::SCRATCH:
return 0;
return this->scratchBufferOffset;
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
Expand Down Expand Up @@ -424,6 +429,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
chunkIndexes.push_back((uint32_t)op["srcoff"]);
} else {
operation.nInputs = op["i_cids"].size();
checkChannelsPerOperation(operation.nInputs);
for (int i = 0; i < operation.nInputs; i++) {
BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
Expand All @@ -440,6 +446,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
// will have either srcs or i_cids
if (op.contains("srcs")) {
operation.nInputs = op["srcs"].size();
checkChannelsPerOperation(operation.nInputs);
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] =
Expand All @@ -450,6 +457,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
}
if (op.contains("o_cids")) {
operation.nOutputs = op["o_cids"].size();
checkChannelsPerOperation(operation.nOutputs);
for (int i = 0; i < operation.nOutputs; i++) {
if (operation.channelType == mscclpp::ChannelType::NVLS) {
BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
Expand All @@ -471,6 +479,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
// will have either dsts or o_cids
if (op.contains("dsts")) {
operation.nOutputs = op["dsts"].size();
checkChannelsPerOperation(operation.nOutputs);
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
for (int i = 0; i < operation.nOutputs; i++) {
operation.outputOffsets[i] =
Expand All @@ -484,13 +493,19 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
}
if (op.contains("srcoff")) {
operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
if (operation.srcBufferType == BufferType::SCRATCH) {
operation.srcOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["srcoff"]);
}
if (op.contains("dstbuff")) {
operation.dstBufferType = convertToBufferType(op["dstbuff"]);
}
if (op.contains("dstoff")) {
operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
if (operation.dstBufferType == BufferType::SCRATCH) {
operation.dstOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["dstoff"]);
}
if (op.contains("cnt")) {
Expand Down
Loading