Skip to content

Commit ead8aec

Browse files
authored
[kernel] Add GPU compiler. (#228)
1 parent 084dc20 commit ead8aec

File tree

12 files changed: 75 additions and 54 deletions.

CMakeLists.txt

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,11 +16,24 @@
1616
cmake_minimum_required(VERSION 3.15.1)
1717
project(xfastertransformer LANGUAGES C CXX)
1818

19-
# Get gcc version
20-
execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
21-
OUTPUT_VARIABLE GCC_VERSION
22-
OUTPUT_STRIP_TRAILING_WHITESPACE)
23-
message(STATUS "Notice: GCC version: ${GCC_VERSION}")
19+
# Enable GPU
20+
option(WITH_GPU "Build with GPU" OFF)
21+
if(WITH_GPU)
22+
message(STATUS "Notice: Building with GPU.")
23+
add_definitions(-DGPU=true)
24+
# Get compiler version
25+
execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
26+
OUTPUT_VARIABLE ICPX_VERSION
27+
OUTPUT_STRIP_TRAILING_WHITESPACE)
28+
message(STATUS "Notice: ICPX version: ${ICPX_VERSION}")
29+
else()
30+
message(STATUS "Notice: Building with CPU.")
31+
# Get compiler version
32+
execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
33+
OUTPUT_VARIABLE GCC_VERSION
34+
OUTPUT_STRIP_TRAILING_WHITESPACE)
35+
message(STATUS "Notice: GCC version: ${GCC_VERSION}")
36+
endif()
2437

2538
if(NOT CMAKE_BUILD_TYPE)
2639
set(CMAKE_BUILD_TYPE Release)
@@ -29,6 +42,9 @@ endif()
2942
set(CMAKE_CXX_STANDARD 17)
3043
set(CMAKE_CXX_FLAGS
3144
"${CMAKE_CXX_FLAGS} -fopenmp -mavx512f -mavx512bw -mavx512vl -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
45+
if(WITH_GPU)
46+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -fsycl-device-code-split=per_kernel -lOpenCL")
47+
endif()
3248

3349
# GCC>=10.1 should support avx512bf16, but need to double check as some versions have issues
3450
if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")

cmake/onednn.cmake

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,13 +24,18 @@ project(dependency NONE)
2424

2525
include(ExternalProject)
2626

27+
set(ONEDNN_BUILD_OPTIONS -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF)
28+
if(WITH_GPU)
29+
set(ONEDNN_BUILD_OPTIONS "${ONEDNN_BUILD_OPTIONS} -DONEDNN_GPU_RUNTIME=SYCL")
30+
endif()
31+
2732
# cmake-format: off
2833
ExternalProject_Add(onednn
2934
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
30-
GIT_TAG v3.2
35+
GIT_TAG v3.3.3
3136
SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/onednn
3237
BINARY_DIR ${CMAKE_SOURCE_DIR}/3rdparty/onednn
33-
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF ..
38+
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} ${ONEDNN_BUILD_OPTIONS} ..
3439
BUILD_COMMAND ${CMAKE_COMMAND} -E chdir "build" make -j all
3540
INSTALL_COMMAND ""
3641
TEST_COMMAND ""

examples/cpp/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,5 +32,8 @@ else()
3232
target_link_libraries(example PRIVATE xfastertransformer_static)
3333
endif()
3434
target_link_libraries(example PRIVATE sentencepiece -lstdc++fs)
35+
if(WITH_GPU)
36+
target_link_libraries(example PRIVATE -fsycl -fsycl-device-code-split=per_kernel -lOpenCL)
37+
endif()
3538

3639
add_dependencies(example cmdline sentencepiece_lib)

include/dtype.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ enum DataType {
3434
};
3535

3636
enum DeviceKind {
37-
CPU = 0,
38-
GPU,
37+
iCPU = 0,
38+
iGPU,
3939
};
4040
} // namespace xft

src/common/my_types.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
typedef int8_t s8;
2323
typedef uint8_t u8;
2424

25-
typedef struct {
25+
typedef struct w8a8 {
2626
int8_t s8;
2727
operator int8_t() { return s8; }
2828
} w8a8_t;

src/models/chatglm2.cpp

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,10 @@
1818
#include "INIReader.h"
1919
#include "chatglm2.h"
2020

21-
template <typename WeiT, typename NormT>
22-
ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
23-
: CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT, float, float, float, true>,
24-
ChatGLM2MLP<WeiT, float, float, float, NormT, true>>(modelPath, modelType) {
21+
template <typename WeiT>
22+
ChatGLM2<WeiT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
23+
: CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm, float, float, float, true>,
24+
ChatGLM2MLP<WeiT, float, float, float, RmsNorm, true>>(modelPath, modelType) {
2525
this->positionIds = nullptr;
2626
this->posBufSize = 0;
2727

@@ -36,15 +36,15 @@ ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string
3636
setFinalLnWeight(modelPath);
3737
}
3838

39-
template <typename WeiT, typename NormT>
40-
ChatGLM2<WeiT, NormT>::~ChatGLM2() {
39+
template <typename WeiT>
40+
ChatGLM2<WeiT>::~ChatGLM2() {
4141
delete embedding;
4242

4343
if (positionIds) { free(positionIds); }
4444
}
4545

46-
template <typename WeiT, typename NormT>
47-
void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
46+
template <typename WeiT>
47+
void ChatGLM2<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
4848
int vocabSize = embedding->getVocabSize();
4949
int hiddenSize = embedding->getHiddenSize();
5050

@@ -57,8 +57,8 @@ void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
5757
free(tokenEmb);
5858
}
5959

60-
template <typename WeiT, typename NormT>
61-
void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
60+
template <typename WeiT>
61+
void ChatGLM2<WeiT>::setFinalLnWeight(const std::string &modelPath) {
6262
int hiddenSize = embedding->getHiddenSize();
6363

6464
float *gamma = (float *)malloc(hiddenSize * sizeof(float));
@@ -85,8 +85,8 @@ void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
8585
// attention_mask = (attention_mask < 0.5).bool()
8686
//
8787
// return attention_mask
88-
template <typename WeiT, typename NormT>
89-
void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
88+
template <typename WeiT>
89+
void ChatGLM2<WeiT>::prepareAttnMask(int *ids, int step) {
9090
DecoderContext *ctx = this->getContext();
9191
int seqLen = ctx->inputSeqLen;
9292
int sizeRequired = ctx->batchSize * seqLen * seqLen;
@@ -127,13 +127,13 @@ void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
127127
}
128128
}
129129

130-
template <typename WeiT, typename NormT>
131-
void ChatGLM2<WeiT, NormT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
130+
template <typename WeiT>
131+
void ChatGLM2<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
132132
embedding->forward(ids, output, batchSize, seqLen);
133133
}
134134

135-
template <typename WeiT, typename NormT>
136-
void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, int rows) {
135+
template <typename WeiT>
136+
void ChatGLM2<WeiT>::lastLayerNormForward(float *input, float *output, int rows) {
137137
finalLN.forward(input, output, rows);
138138
}
139139

@@ -147,8 +147,8 @@ void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, in
147147
// batch_size, seq_length = input_ids.shape
148148
// position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
149149
// return position_ids
150-
template <typename WeiT, typename NormT>
151-
int *ChatGLM2<WeiT, NormT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
150+
template <typename WeiT>
151+
int *ChatGLM2<WeiT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
152152
// Prepare buffer
153153
int sizeNeeded = (batchSize * seqLen + 63) / 64 * 64; // position_ids + block_position_ids
154154
if (posBufSize < sizeNeeded) {

src/models/chatglm2.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@
2222
#include "rotary_embedding_chatglm2.h"
2323
#include "token_embedding.h"
2424

25-
template <typename WeiT, typename NormT = RmsNorm>
26-
class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT, float, float, float, true>,
27-
ChatGLM2MLP<WeiT, float, float, float, NormT, true>> {
25+
template <typename WeiT>
26+
class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm, float, float, float, true>,
27+
ChatGLM2MLP<WeiT, float, float, float, RmsNorm, true>> {
2828
public:
2929
ChatGLM2(const std::string &modelPath, const std::string &modelType = "chatglm2");
3030
~ChatGLM2();
@@ -40,7 +40,7 @@ class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, N
4040

4141
private:
4242
TokenEmbedding<float16_t> *embedding;
43-
NormT finalLN;
43+
RmsNorm finalLN;
4444

4545
// Record last block positions
4646
std::vector<int> lastBlockPositions;

src/models/chatglm3.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
#include "chatglm2.h"
1818

1919
// ChatGLM3 and ChatGLM2 have the same structure, so ChatGLM3 utilizes the implementation of ChatGLM2.
20-
template <typename WeiT, typename NormT = RmsNorm>
21-
class ChatGLM3 : public ChatGLM2<WeiT, NormT> {
20+
template <typename WeiT>
21+
class ChatGLM3 : public ChatGLM2<WeiT> {
2222
public:
23-
ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT, NormT>(modelPath, "chatglm3") {}
23+
ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT>(modelPath, "chatglm3") {}
2424
};
2525

2626
template class ChatGLM3<float>;

src/models/common_decoder.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -614,7 +614,7 @@ class CommonDecoder : public AbstractDecoder {
614614
this->context.reset(new DecoderContext(layers, hiddenSize, attHeadNum, kvHeadNum, imSize, act, epsilon,
615615
vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, ppRank,
616616
ropeParamsPtr));
617-
this->context->mmHelper = new MMHelper(xft::DeviceKind::CPU, 0);
617+
this->context->mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0);
618618
}
619619

620620
return this->context.get();

src/models/models.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -377,12 +377,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
377377
}
378378
} else if (modeltype == "chatglm2") {
379379
switch (datatype) {
380-
case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t, RmsNorm>(modelPath)); break;
381-
case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t, RmsNorm>(modelPath)); break;
382-
case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t, RmsNorm>(modelPath)); break;
383-
case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t, RmsNorm>(modelPath)); break;
384-
case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t, RmsNorm>(modelPath)); break;
385-
case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t, RmsNorm>(modelPath)); break;
380+
case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t>(modelPath)); break;
381+
case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t>(modelPath)); break;
382+
case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t>(modelPath)); break;
383+
case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t>(modelPath)); break;
384+
case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t>(modelPath)); break;
385+
case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t>(modelPath)); break;
386386
case xft::DataType::bf16_fp16:
387387
setDecoder(new HybridModel<ChatGLM2, bfloat16_t, float16_t>(modelPath));
388388
break;
@@ -399,12 +399,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
399399
}
400400
} else if (modeltype == "chatglm3") {
401401
switch (datatype) {
402-
case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t, RmsNorm>(modelPath)); break;
403-
case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t, RmsNorm>(modelPath)); break;
404-
case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t, RmsNorm>(modelPath)); break;
405-
case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t, RmsNorm>(modelPath)); break;
406-
case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t, RmsNorm>(modelPath)); break;
407-
case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t, RmsNorm>(modelPath)); break;
402+
case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t>(modelPath)); break;
403+
case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t>(modelPath)); break;
404+
case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t>(modelPath)); break;
405+
case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t>(modelPath)); break;
406+
case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t>(modelPath)); break;
407+
case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t>(modelPath)); break;
408408
case xft::DataType::bf16_fp16:
409409
setDecoder(new HybridModel<ChatGLM3, bfloat16_t, float16_t>(modelPath));
410410
break;

0 commit comments

Comments
 (0)