Skip to content

Commit 80df391

Browse files
authored
[Kernel] Add GPU kernels and enable LLaMA model. (#372)
1 parent 24242ff commit 80df391

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+1336
-302
lines changed

CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ project(xfastertransformer LANGUAGES C CXX)
2020
option(WITH_GPU "Build with GPU" OFF)
2121
if(WITH_GPU)
2222
message(STATUS "Notice: Building with GPU.")
23-
add_definitions(-DGPU=true)
23+
add_definitions(-DXFT_GPU=true)
2424
# Get compiler version
2525
execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
2626
OUTPUT_VARIABLE ICPX_VERSION
@@ -35,10 +35,6 @@ else()
3535
message(STATUS "Notice: GCC version: ${GCC_VERSION}")
3636
endif()
3737

38-
if(NOT CMAKE_BUILD_TYPE)
39-
set(CMAKE_BUILD_TYPE Release)
40-
endif()
41-
4238
set(CMAKE_CXX_STANDARD 17)
4339
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512bw -mavx512vl -fPIC")
4440
if(WITH_GPU)
@@ -73,11 +69,15 @@ if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")
7369
endif()
7470
endif()
7571

72+
if(NOT CMAKE_BUILD_TYPE)
73+
set(CMAKE_BUILD_TYPE Release)
74+
endif()
75+
7676
if(CMAKE_BUILD_TYPE MATCHES "Debug")
7777
message(STATUS "Notice: Using Debug mode.")
7878
set(CMAKE_C_FLAGS "-O0 -g")
7979
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g")
80-
add_definitions(-DDEBUG=true)
80+
add_definitions(-DXFT_DEBUG=true)
8181
add_definitions(-DSTEP_BY_STEP_ATTN=true)
8282
else()
8383
message(STATUS "Notice: Using Release mode.")

requirements-gpu.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-f https://download.pytorch.org/whl/torch_stable.html
2+
cmake==3.26.1
3+
sentencepiece==0.1.99
4+
torch==2.3.0+cpu.cxx11.abi
5+
transformers==4.40.0
6+
accelerate==0.23.0
7+
protobuf
8+
tiktoken

src/common/allocator.h

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
#pragma once
1616
#include <cstdio>
1717
#include <cstdlib>
18-
#include <sys/mman.h>
18+
#include <cstring>
1919
#include "environment.h"
20+
#include <sys/mman.h>
21+
22+
#ifdef XFT_GPU
23+
#include <CL/sycl.hpp>
24+
#endif
2025

2126
namespace xft {
2227

@@ -26,10 +31,22 @@ static inline bool is_thp_alloc(size_t nbytes) {
2631
return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold));
2732
}
2833

29-
static inline void *alloc(size_t nbytes, size_t alignment = 64) {
34+
static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignment = 64) {
3035
if (nbytes == 0) { return nullptr; }
3136

32-
void *data;
37+
void *data = nullptr;
38+
39+
#ifdef XFT_GPU
40+
if (device != nullptr) {
41+
sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
42+
data = sycl::malloc_device<char>(nbytes, *gpu_queue);
43+
if (data == nullptr) {
44+
printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes);
45+
exit(-1);
46+
}
47+
return data;
48+
}
49+
#endif
3350

3451
int err = posix_memalign(&data, alignment, nbytes);
3552
if (err != 0) {
@@ -47,4 +64,40 @@ static inline void *alloc(size_t nbytes, size_t alignment = 64) {
4764

4865
return data;
4966
}
67+
68+
// Release a buffer such as one returned by alloc().
//
// data:   pointer to release.
// device: in an XFT_GPU build, a non-null value is treated as a `sycl::queue *` and the
//         buffer is released with sycl::free on that queue. In every other case the
//         buffer is released with free().
static inline void dealloc(void *data, void *device = nullptr) {
#ifdef XFT_GPU
    if (auto *gpuQueue = static_cast<sycl::queue *>(device); gpuQueue != nullptr) {
        sycl::free(data, *gpuQueue);
        return;
    }
#endif

    free(data);
}
78+
79+
// Copy `size` bytes from `src` to `dst`.
//
// dst, src: destination and source buffers.
// size:     number of bytes to copy; a size of 0 is a no-op.
// device:   in an XFT_GPU build, a non-null value is treated as a `sycl::queue *`; the copy
//           is submitted to that queue and waited on before returning. Otherwise the copy
//           is a plain memcpy.
static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) {
    // Nothing to copy. Returning early also keeps null pointers (alloc() returns nullptr for
    // a zero-byte request) away from memcpy, which requires valid pointers even when the
    // count is 0.
    if (size == 0) { return; }

#ifdef XFT_GPU
    if (device != nullptr) {
        sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
        gpu_queue->memcpy(dst, src, size).wait(); // block until the copy has finished
        return;
    }
#endif

    memcpy(dst, src, size);
}
90+
91+
// Fill `size` bytes starting at `dst` with the byte value `ch`.
//
// dst:    buffer to fill.
// ch:     fill value, passed through unchanged to the underlying memset.
// size:   number of bytes to fill; a size of 0 is a no-op.
// device: in an XFT_GPU build, a non-null value is treated as a `sycl::queue *`; the fill is
//         submitted to that queue and waited on before returning. Otherwise the fill is a
//         plain memset.
static inline void memsetv(void *dst, int ch, size_t size, void *device = nullptr) {
    // Nothing to fill. Returning early also keeps a null `dst` (alloc() returns nullptr for
    // a zero-byte request) away from memset, which requires a valid pointer even when the
    // count is 0.
    if (size == 0) { return; }

#ifdef XFT_GPU
    if (device != nullptr) {
        sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
        gpu_queue->memset(dst, ch, size).wait(); // block until the fill has finished
        return;
    }
#endif

    memset(dst, ch, size);
}
102+
50103
} // namespace xft

src/common/sequence.h

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <queue>
2020
#include <unordered_map>
2121

22+
#include "allocator.h"
2223
#include "environment.h"
2324
#include "sampling_params.h"
2425

@@ -67,7 +68,7 @@ class SequenceIDManager {
6768
// The SequenceMeta is one sequence of batch inputs and includes the generated tokens.
6869
class SequenceMeta {
6970
public:
70-
SequenceMeta(std::vector<int32_t> &_promptTokens)
71+
SequenceMeta(const std::vector<int32_t> &_promptTokens)
7172
: sequenceID(SequenceIDManager::getInstance().createSequenceID())
7273
, inputSeqLen(_promptTokens.size())
7374
, pastSeqLen(0)
@@ -81,6 +82,16 @@ class SequenceMeta {
8182
, promptTokens(_inputSeqLen, 0)
8283
, step(0) {}
8384

85+
SequenceMeta(int32_t _sequenceID, const std::vector<int32_t> &_promptTokens)
86+
: sequenceID(_sequenceID)
87+
, inputSeqLen(_promptTokens.size())
88+
, pastSeqLen(0)
89+
, promptTokens(_promptTokens)
90+
, step(0) {}
91+
92+
SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
93+
: sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), promptTokens(_inputSeqLen, 0), step(0) {}
94+
8495
~SequenceMeta() {}
8596

8697
int32_t getSequenceID() const { return sequenceID; }
@@ -175,7 +186,8 @@ class SequenceGroupMeta {
175186
groupID = sequences[0].getSequenceID();
176187
}
177188

178-
SequenceGroupMeta(std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
189+
SequenceGroupMeta(const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
190+
: samplingMeta(samplingMeta_) {
179191
sequences.reserve(samplingMeta.config.numBeams);
180192
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
181193
sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -191,7 +203,7 @@ class SequenceGroupMeta {
191203
groupID = sequences[0].getSequenceID();
192204
}
193205

194-
SequenceGroupMeta(std::vector<int32_t> &_inputTokens) {
206+
SequenceGroupMeta(const std::vector<int32_t> &_inputTokens) {
195207
sequences.reserve(samplingMeta.config.numBeams);
196208
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
197209
sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -207,6 +219,40 @@ class SequenceGroupMeta {
207219
groupID = sequences[0].getSequenceID();
208220
}
209221

222+
SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
223+
: samplingMeta(samplingMeta_) {
224+
sequences.reserve(samplingMeta.config.numBeams);
225+
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
226+
sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
227+
}
228+
groupID = sequences[0].getSequenceID();
229+
}
230+
231+
SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_)
232+
: samplingMeta(samplingMeta_) {
233+
sequences.reserve(samplingMeta.config.numBeams);
234+
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
235+
sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
236+
}
237+
groupID = sequences[0].getSequenceID();
238+
}
239+
240+
// Create a group of `samplingMeta.config.numBeams` sequences using the default-initialized
// samplingMeta member. Every sequence is built from the same caller-supplied sequence ID
// and the same prompt tokens; the group ID is the ID reported by the first sequence.
// NOTE(review): sequences[0] is read unconditionally, so this assumes numBeams >= 1.
SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens) {
    sequences.reserve(samplingMeta.config.numBeams);
    for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
        // Construct in place instead of building a temporary SequenceMeta (and its token
        // vector) that then has to be copied or moved into the container.
        sequences.emplace_back(_sequenceID, _inputTokens);
    }
    groupID = sequences[0].getSequenceID();
}
247+
248+
// Create a group of `samplingMeta.config.numBeams` sequences using the default-initialized
// samplingMeta member, when only the prompt length is known. Every sequence is built from
// the same caller-supplied sequence ID and input length; the group ID is the ID reported
// by the first sequence.
// NOTE(review): sequences[0] is read unconditionally, so this assumes numBeams >= 1.
SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen) {
    sequences.reserve(samplingMeta.config.numBeams);
    for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
        // Construct in place instead of building a temporary SequenceMeta that then has to
        // be copied or moved into the container.
        sequences.emplace_back(_sequenceID, _inputSeqLen);
    }
    groupID = sequences[0].getSequenceID();
}
255+
210256
int32_t getGroupID() { return groupID; }
211257

212258
int32_t getGroupSize() { return samplingMeta.config.numBeams; }
@@ -272,6 +318,31 @@ class SequencePool {
272318
return group;
273319
}
274320

321+
// Allocate a SequenceGroupMeta for the given sequence ID, prompt tokens and sampling
// settings, hand it to add(), and return the pointer.
// `inputTokens` is taken by const reference, matching the SequenceGroupMeta constructor it
// is forwarded to, so const vectors and temporaries are accepted as well.
// NOTE(review): the bool returned by add() is ignored here; confirm what should happen to
// the new object if add() does not accept it.
SequenceGroupMeta *newGroupMeta(
        int32_t sequenceID, const std::vector<int32_t> &inputTokens, SamplingMeta &samplingMeta_) {
    auto *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_);
    this->add(group);
    return group;
}
327+
328+
// Allocate a SequenceGroupMeta for the given sequence ID, input length and sampling
// settings, hand it to add(), and return the pointer.
// NOTE(review): the bool returned by add() is ignored here; confirm what should happen to
// the new object if add() does not accept it.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) {
    SequenceGroupMeta *const created = new SequenceGroupMeta(sequenceID, inputSeqLen, samplingMeta_);
    add(created);
    return created;
}
333+
334+
// Allocate a SequenceGroupMeta for the given sequence ID and prompt tokens, hand it to
// add(), and return the pointer.
// `inputTokens` is taken by const reference, matching the SequenceGroupMeta constructor it
// is forwarded to, so const vectors and temporaries are accepted as well.
// NOTE(review): the bool returned by add() is ignored here; confirm what should happen to
// the new object if add() does not accept it.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, const std::vector<int32_t> &inputTokens) {
    auto *group = new SequenceGroupMeta(sequenceID, inputTokens);
    this->add(group);
    return group;
}
339+
340+
// Allocate a SequenceGroupMeta for the given sequence ID and input length, hand it to
// add(), and return the pointer.
// NOTE(review): the bool returned by add() is ignored here; confirm what should happen to
// the new object if add() does not accept it.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen) {
    SequenceGroupMeta *const created = new SequenceGroupMeta(sequenceID, inputSeqLen);
    add(created);
    return created;
}
345+
275346
bool add(SequenceGroupMeta *sequenceGroup, bool force = false) {
276347
int32_t groupID = sequenceGroup->getGroupID();
277348
bool isSuccess = false;

src/common/transformer_ctx.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ struct DecoderContext {
112112
xft::Matrix<float> qkvMatMul; // query, key, value
113113
xft::Matrix<float> imOut; // intermediate output
114114

115-
MMHelper *mmHelper;
115+
MMHelper *mmHelper = nullptr;
116+
void *device = nullptr;
116117

117118
std::string configPath;
118119
INIReader configReader;
@@ -130,7 +131,7 @@ struct DecoderContext {
130131
public:
131132
DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
132133
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
133-
int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
134+
int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
134135
bool _useLogN = true, bool _useNTK = true, int numThreads = 0)
135136
: layers(_layers)
136137
, hiddenSize(_hiddenSize)
@@ -170,9 +171,12 @@ struct DecoderContext {
170171
}
171172
}
172173

174+
this->mmHelper = mmHelper;
175+
this->device = device;
176+
173177
this->rawBufSize = 4 * 32 * intermediateSize + 4 * attHeadNum * 32 * 32; // assume bs=4, seq=32
174-
this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
175-
memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
178+
this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
179+
xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);
176180

177181
if (act == "relu") {
178182
this->actType = RELU;
@@ -240,8 +244,12 @@ struct DecoderContext {
240244
bool cached(const std::string &name) { return SimpleMemPool::instance().cached(name); }
241245

242246
template <typename T>
243-
T *getBuffer(const std::string &name, size_t size, size_t alignment = 64) {
244-
return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, alignment);
247+
T *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) {
248+
return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment);
249+
}
250+
251+
// Forward a free request for the buffer called `name` to SimpleMemPool.
void freeBuffer(const std::string &name) {
    auto &&pool = SimpleMemPool::instance();
    pool.freeBuffer(name);
}
246254

247255
void dump() {
@@ -286,10 +294,10 @@ struct DecoderContext {
286294
uint64_t total = size1 + size2 + size3;
287295
if (total > this->rawBufSize) {
288296
this->rawBufSize = total;
289-
free(this->rawBuffer);
297+
if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
290298

291-
this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
292-
memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
299+
this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
300+
xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);
293301
}
294302

295303
// Assign the buffer
@@ -312,5 +320,9 @@ struct DecoderContext {
312320
return rawBufSize - size1 - size2;
313321
}
314322

315-
~DecoderContext() { free(this->rawBuffer); }
323+
// Release rawBuffer.
// Host memory (device == nullptr) is released in every build. Previously the whole release
// was compiled out under XFT_GPU, which leaked a host-allocated rawBuffer in GPU builds.
~DecoderContext() {
    if (this->rawBuffer == nullptr) { return; }

#ifdef XFT_GPU
    // Device memory is still left untouched in GPU builds, exactly as before.
    // NOTE(review): presumably because the queue behind `device` may no longer be usable
    // when the context is destroyed - confirm, and release it here if that is safe.
    if (this->device != nullptr) { return; }
#endif

    xft::dealloc(this->rawBuffer, this->device);
}
316328
};

src/kernels/attention_kernels.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
6666
small_sgemm_bf16bf16f32_b(true, m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)baseB, ldb, C, ldc, blkIndices,
6767
cacheBlkStride, cacheBlkSize);
6868

69-
#ifdef DEBUG
69+
#ifdef XFT_DEBUG
7070
if (b == 0 && i == 0) {
7171
printf("Q * K, first head:\n");
7272
auto p = C;
@@ -78,7 +78,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
7878
// Softmax(Q * K)
7979
small_softmax_f32(C, scale, n);
8080

81-
#ifdef DEBUG
81+
#ifdef XFT_DEBUG
8282
if (b == 0 && i == 0) {
8383
printf("Softmax(Q * K), first head:\n");
8484
auto p = C;
@@ -100,7 +100,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
100100
small_sgemm_f32bf16bf16_b(false, m, n, k, C, lda, (XDNN_BF16 *)baseB, ldb, (XDNN_BF16 *)baseC, ldc,
101101
blkIndices, cacheBlkStride, cacheBlkSize);
102102

103-
#ifdef DEBUG
103+
#ifdef XFT_DEBUG
104104
if (b == 0 && i == 0) {
105105
printf("Softmax(Q * K) * V, first head:\n");
106106
auto p = C;

0 commit comments

Comments
 (0)