Merged
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -145,7 +145,7 @@ add_definitions(-DAVX512_FP32_WEIGHT_ONLY_NF4=true)

 # add_definitions(-DDEBUG=true)
 # add_definitions(-DSTEP_BY_STEP_ATTN=true)
-add_definitions(-DUSE_SHM=true)
+# add_definitions(-DUSE_SHM=true)
 option(XFT_BUILD_TESTS "Build xfastertransformer unit tests" OFF)

 # timeline event
36 changes: 26 additions & 10 deletions src/comm_helper/comm_helper.cpp
@@ -18,29 +18,45 @@

 static ccl::communicator *pcomm;

-extern "C" int init(int *rank, int *size) {
+// world_color is initialized to pipeline_parallel_stages_num (ppSize)
+// and will be reassigned to the MPI world_color below
+extern "C" int init(int *world_size, int *world_rank, int *world_color) {
     ccl::init();

     MPI_Init(NULL, NULL);
-    MPI_Comm_size(MPI_COMM_WORLD, size);
-    MPI_Comm_rank(MPI_COMM_WORLD, rank);
+    MPI_Comm_size(MPI_COMM_WORLD, world_size);
+    MPI_Comm_rank(MPI_COMM_WORLD, world_rank);
+
+    // world_color = world_rank / tpSize = world_rank / (world_size / ppSize)
+    // e.g. world_color = 0~7 / (8 / 4), where XFT_PIPELINE_STAGES = ppSize = 4 and tpSize = 2:
+    // world_rank = 0, 1, -> world_color = ppRank = 0, 0, -> tpRank = 0, 1;
+    //              2, 3,                            1, 1,             0, 1;
+    //              4, 5,                            2, 2,             0, 1;
+    //              6, 7;                            3, 3;             0, 1;
+    *world_color = *world_rank / (*world_size / *world_color);
+    MPI_Comm row_comm;
+    MPI_Comm_split(MPI_COMM_WORLD, *world_color, *world_rank, &row_comm);
+
+    int row_size, row_rank;
+    MPI_Comm_size(row_comm, &row_size);
+    MPI_Comm_rank(row_comm, &row_rank);

     ccl::shared_ptr_class<ccl::kvs> kvs;
     ccl::kvs::address_type mainAddr;

-    if (*rank == 0) {
+    if (row_rank == 0) {
         kvs = ccl::create_main_kvs();
         mainAddr = kvs->get_address();
-        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
     } else {
-        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
         kvs = ccl::create_kvs(mainAddr);
     }

-    pcomm = new ccl::communicator(ccl::create_communicator(*size, *rank, kvs));
+    pcomm = new ccl::communicator(ccl::create_communicator(row_size, row_rank, kvs));

-    *rank = pcomm->rank();
-    *size = pcomm->size();
+    *world_size = pcomm->size();
+    *world_rank = pcomm->rank();

 #ifdef USE_SHM
     char myHostname[MPI_MAX_PROCESSOR_NAME];
@@ -53,7 +69,7 @@ extern "C" int init(int *rank, int *size) {
             MPI_COMM_WORLD);

     int sameHostnames = 1;
-    for (int i = 1; i < *size; i++) {
+    for (int i = 1; i < *world_size; i++) {
         if (strcmp(myHostname, &all_hostnames[i * MPI_MAX_PROCESSOR_NAME]) != 0) {
             sameHostnames = 0;
             break;
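The split above assigns each global MPI rank to a pipeline stage (world_color, which becomes ppRank) and to a rank inside that stage's tensor-parallel communicator (tpRank). A minimal self-contained sketch of the same mapping, assuming ppSize = 4, world_size divisible by ppSize, and an 8-rank launch (the values from the comment's example, not anything mandated by the PR):

#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);
    int world_size, world_rank;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

    const int ppSize = 4;                   // assumed XFT_PIPELINE_STAGES; must divide world_size
    const int tpSize = world_size / ppSize; // ranks per pipeline stage
    const int color = world_rank / tpSize;  // pipeline stage index (ppRank)

    // Ranks with the same color land in the same tensor-parallel group.
    MPI_Comm row_comm;
    MPI_Comm_split(MPI_COMM_WORLD, color, world_rank, &row_comm);

    int tpRank;
    MPI_Comm_rank(row_comm, &tpRank);       // rank within the tensor-parallel group
    printf("world_rank %d -> ppRank %d, tpRank %d\n", world_rank, color, tpRank);

    MPI_Comm_free(&row_comm);
    MPI_Finalize();
    return 0;
}

Run with mpirun -n 8, this prints exactly the rank table laid out in the comment above.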
12 changes: 11 additions & 1 deletion src/common/transformer_ctx.h
@@ -84,6 +84,12 @@ struct DecoderContext {
     // # of splits (the same as NUMA node number in the system)
     const int numSplit;

+    // For pipeline parallel and tensor parallel config
+    int ppSize = 1; // pipeline parallel stage size
+    int ppRank = 0; // pipeline parallel stage rank
+    int tpSize = 1; // tensor parallel size
+    int tpRank = 0; // tensor parallel rank
+
     enum ActivationType { RELU, GELU, SWIGLU, SILU };
     ActivationType actType;

@@ -105,7 +111,7 @@
 public:
     DecoderContext(int _layers, int _hiddenSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
             float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
-            int _splitIdx, int _splits, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
+            int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
         : layers(_layers)
         , hiddenSize(_hiddenSize)
         , intermediateSize(_imSize)
@@ -119,6 +125,10 @@
         , ropeParamsPtr(_ropeParamsPtr)
         , splitIdx(_splitIdx)
         , numSplit(_splits)
+        , ppSize(_ppSize)
+        , ppRank(_ppRank)
+        , tpSize(_splits)
+        , tpRank(_splitIdx)
         , epsilon(epsilon) {
         if (attHeadNum != 0) {
             this->attHeadSize = hiddenSize / attHeadNum;
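The new parameters are appended with defaults of 1 and 0, and tpSize/tpRank are mirrored from the existing _splits/_splitIdx arguments in the initializer list, so pure tensor-parallel callers compile unchanged. A hypothetical construction, with every model dimension illustrative rather than taken from this PR (assumes the xfastertransformer headers are on the include path):

#include "transformer_ctx.h"

// A 32-layer, Llama-7B-sized config: tpSize = 2 NUMA splits, ppSize = 4
// pipeline stages; this process is tensor rank 0 of pipeline stage 1.
DecoderContext ctx(/*_layers*/ 32, /*_hiddenSize*/ 4096, /*_attHeadNum*/ 32,
        /*_kvHeadNum*/ 32, /*_imSize*/ 11008, /*act*/ "silu", /*epsilon*/ 1e-6f,
        /*_vocabSize*/ 32000, /*_embeddingSize*/ 4096, /*_maxPositions*/ 4096,
        /*_maxPosEmbed*/ 4096, /*_maxSeqLength*/ 4096,
        /*_splitIdx*/ 0, /*_splits*/ 2,
        /*_ppSize*/ 4, /*_ppRank*/ 1);
// Afterwards: ctx.tpSize == 2, ctx.tpRank == 0, ctx.ppSize == 4, ctx.ppRank == 1.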
6 changes: 4 additions & 2 deletions src/layers/attention.h
@@ -273,6 +273,7 @@ class Attention {
         imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
         inputBuffer.Assign(tmp, rows, cols, stride);
     }
+
     // TODO: refine the logic (and support large inputSeqLen when pastSeqLen > 0)
     if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
         if (pastSeqLen == 0) {
@@ -284,8 +285,9 @@
             if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
                 flashAttention(
                         ctx, qkvGroupMatMul, outBuffer, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
-            else
+            else {
                 fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
+            }
         }
         t4.release();

@@ -375,7 +377,7 @@
         // to make sure it works better (the logic here is trying to make sure each head of BMM result [seq * seq] in cache)
         // WARN: reserve field in context is used to make it effective for all layers, do not change it in other places
         int &mBlockSize = ctx->reserved1;
-        if (layerId == 0) {
+        if (layerId % (ctx->layers / ctx->ppSize) == 0) {
            // TODO: if pastSeqLen > 0 and inputSeqLen large.
            if (pastSeqLen == 0) {
                const int l2CacheSize = 2 * 1024 * 1024; // TODO: get it dynamically
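Previously mBlockSize (cached in ctx->reserved1) was recomputed only at layerId == 0, but under pipeline parallelism only stage 0 ever runs layer 0; each stage owns a contiguous block of ctx->layers / ctx->ppSize layers, so the new condition fires on the first layer of every stage's block. A small sketch of which layers trigger the recompute, with the layer and stage counts assumed rather than taken from the PR:

#include <cstdio>

int main() {
    const int layers = 32, ppSize = 4;          // assumed model/pipeline config
    const int layersPerStage = layers / ppSize; // 8 layers per pipeline stage
    for (int layerId = 0; layerId < layers; ++layerId) {
        // Same test as the new condition in attention.h.
        if (layerId % layersPerStage == 0)
            printf("layerId %d recomputes mBlockSize\n", layerId); // 0, 8, 16, 24
    }
    return 0;
}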
5 changes: 5 additions & 0 deletions src/models/CMakeLists.txt
@@ -14,7 +14,12 @@
 # ============================================================================
 cmake_minimum_required(VERSION 3.15.1)

+find_package(MPI REQUIRED)
Contributor: What if oneCCL is not present in the user's environment?

Contributor (author): That's not it; oneCCL is already present in my environment, but building the models source reports that the MPI library is missing.

Contributor: This should be decoupled now; the code in src/models no longer depends on oneCCL or MPI.

+include_directories(${MPI_INCLUDE_PATH})
+add_definitions(${MPI_CXX_COMPILE_FLAGS})
+
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} MODEL_SRCS)

 add_library(models OBJECT ${MODEL_SRCS})
 add_dependencies(models utils)
+target_link_libraries(models ${MPI_CXX_LIBRARIES})