Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions src/kernels/rotary_embedding_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,11 @@ static inline void chatglm2ApplyRotaryPosEmbeding(T *query, T *key, int qStride,
const int head_num = qk_shape[2] + qk_shape[4];
const int half = inv_freq_size;

#pragma omp parallel for
#pragma omp parallel for collapse(3)
for (int head = 0; head < head_num; ++head) {
int off = head * dim;
for (int bs = 0; bs < batch_size; ++bs) {
for (int seq = 0; seq < seq_len; ++seq) {
T *pF = query + off;
T *pF = query + seq * qStride + head * dim;

int pos = position_ids[seq];
float *pcos = emb_cos + pos * dim;
Expand Down Expand Up @@ -271,7 +270,6 @@ static inline void chatglm2ApplyRotaryPosEmbeding(T *query, T *key, int qStride,
xft::store_avx512(&pF[i], mask, tmp0);
xft::store_avx512(&pF[i + 16], mask, tmp1);
}
off += qStride;
}
}
}
Expand All @@ -295,6 +293,72 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}


// For ChatGLM2/3 continuous batching

// Applies the ChatGLM2/3 rotary position embedding in place (continuous-batching variant).
//
// Every (head, token) pair is addressed as query + token * qStride + head * dim, where
// dim = 2 * inv_freq_size and head runs over [0, qHeads + kHeads).
// NOTE(review): key and kStride are never read in this function; the K heads are reached
// through the query pointer, which presumably relies on K being laid out right after Q
// inside each token row -- confirm against the caller.
//
// positionIds[token] selects the row (dim floats wide) of emb_cos / emb_sin used for that
// token. The inner loop consumes 32 elements per step with all-ones masks, so inv_freq_size
// is assumed to be a multiple of 32; nothing here checks it.
template <typename T>
static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
        int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
    const int rotHalf = inv_freq_size;
    const int headDim = inv_freq_size * 2;
    const int totalHeads = qHeads + kHeads;

#pragma omp parallel for collapse(2)
    for (int h = 0; h < totalHeads; ++h) {
        for (int tok = 0; tok < totSeqLen; ++tok) {
            // Row of the cos/sin tables for this token's position.
            int tablePos = positionIds[tok];
            float *cosRow = emb_cos + tablePos * headDim;
            float *sinRow = emb_sin + tablePos * headDim;

            // Start of the current head inside the current token row.
            T *headPtr = query + tok * qStride + h * headDim;

            for (int col = 0; col < rotHalf; col += 32) {
                __mmask16 fullMask = 0xffff;

                // TODO: could the load/store be fused with the shuffle done by the helpers?
                // cos: two 16-float loads, combined into one vector by the helper.
                __m512 cosLo = _mm512_maskz_loadu_ps(fullMask, &cosRow[col]);
                __m512 cosHi = _mm512_maskz_loadu_ps(fullMask, &cosRow[col + 16]);
                __m512 cosVec;
                chatglm2PrepareSinCos(cosLo, cosHi, &cosVec);

                // sin: same treatment as cos.
                __m512 sinLo = _mm512_maskz_loadu_ps(fullMask, &sinRow[col]);
                __m512 sinHi = _mm512_maskz_loadu_ps(fullMask, &sinRow[col + 16]);
                __m512 sinVec;
                chatglm2PrepareSinCos(sinLo, sinHi, &sinVec);

                // lo/hi first carry the 32 raw input elements and later the rotated output.
                __m512 lo = xft::load_avx512(fullMask, &headPtr[col]);
                __m512 hi = xft::load_avx512(fullMask, &headPtr[col + 16]);

                __m512 partA, partB;
                chatglm2InterleaveQK(lo, hi, &partA, &partB);

                // rotA = partA * cos - partB * sin
                // rotB = partA * sin + partB * cos
                __m512 bTimesSin = _mm512_mul_ps(partB, sinVec);
                __m512 bTimesCos = _mm512_mul_ps(partB, cosVec);
                __m512 rotA = _mm512_fmsub_ps(partA, cosVec, bTimesSin);
                __m512 rotB = _mm512_fmadd_ps(partA, sinVec, bTimesCos);

                chatglm2DeinterleaveQK(rotA, rotB, &lo, &hi);

                xft::store_avx512(&headPtr[col], fullMask, lo);
                xft::store_avx512(&headPtr[col + 16], fullMask, hi);
            }
        }
    }
}

// float overload of the continuous-batching ChatGLM2/3 rotary embedding.
// Pure forwarder to the templated kernel; note that `dim` lands in the kernel's
// inv_freq_size parameter slot.
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim,
        int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
    chatglm2ApplyRotaryPosEmbed<float>(query, key, // tensors
            emb_cos, emb_sin, // cos/sin tables
            qStride, kStride, // strides
            dim, // forwarded as inv_freq_size
            totSeqLen, qHeads, kHeads, positionIds);
}

// bfloat16_t overload of the continuous-batching ChatGLM2/3 rotary embedding.
// Pure forwarder to the templated kernel; note that `dim` lands in the kernel's
// inv_freq_size parameter slot.
void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
    chatglm2ApplyRotaryPosEmbed<bfloat16_t>(query, key, // tensors
            emb_cos, emb_sin, // cos/sin tables
            qStride, kStride, // strides
            dim, // forwarded as inv_freq_size
            totSeqLen, qHeads, kHeads, positionIds);
}

// float16_t overload of the continuous-batching ChatGLM2/3 rotary embedding.
// Pure forwarder to the templated kernel; note that `dim` lands in the kernel's
// inv_freq_size parameter slot.
void chatglm2ApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride,
        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
    chatglm2ApplyRotaryPosEmbed<float16_t>(query, key, // tensors
            emb_cos, emb_sin, // cos/sin tables
            qStride, kStride, // strides
            dim, // forwarded as inv_freq_size
            totSeqLen, qHeads, kHeads, positionIds);
}

template <typename T>
static inline void qwenApplyRotaryPosEmbeding(T *query, T *key, int qStride, int kStride, float *cur_emb_cos,
float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
Expand Down
9 changes: 9 additions & 0 deletions src/kernels/rotary_embedding_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr
void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

// For ChatGLM2/3 continuous batching: one overload per element type (float, bfloat16_t, float16_t).
// positionIds is indexed per token (totSeqLen entries are read by the kernel).
// NOTE(review): the parameter named `dim` is forwarded by the definitions as the kernel's
// inv_freq_size, and ChatGLM2RotaryEmbedding::forward passes inv_freq_size for it -- the
// name looks misleading; confirm before relying on it as the head dimension.
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

// For Qwen1.0
void qwenApplyRotaryPosEmbeding(float *query, float *key, int qStride, int kStride, float *cur_emb_cos,
float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
Expand Down
12 changes: 6 additions & 6 deletions src/layers/rotary_embedding_chatglm2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,18 @@ void ChatGLM2RotaryEmbedding::forward(
// For continuous batching
// Continuous-batching forward pass (float): delegates to the xft rotary kernel, passing
// emb_cos, emb_sin and inv_freq_size together with the caller's strides, head counts and
// per-token position ids.
// The former "Unsupported ... in cb mode" printf + exit(1) stub is removed: it terminated
// the process before the kernel call could ever run.
void ChatGLM2RotaryEmbedding::forward(
        float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) {
    xft::chatglm2ApplyRotaryPosEmbed(
            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
}

// Continuous-batching forward pass (bfloat16_t): delegates to the xft rotary kernel, passing
// emb_cos, emb_sin and inv_freq_size together with the caller's strides, head counts and
// per-token position ids.
// The former "Unsupported ... in cb mode" printf + exit(1) stub is removed: it terminated
// the process before the kernel call could ever run.
void ChatGLM2RotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride,
        int qHeads, int kHeads, int *positionIds) {
    xft::chatglm2ApplyRotaryPosEmbed(
            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
}

// Continuous-batching forward pass (float16_t): delegates to the xft rotary kernel, passing
// emb_cos, emb_sin and inv_freq_size together with the caller's strides, head counts and
// per-token position ids.
// The former "Unsupported ... in cb mode" printf + exit(1) stub is removed: it terminated
// the process before the kernel call could ever run.
void ChatGLM2RotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride,
        int qHeads, int kHeads, int *positionIds) {
    xft::chatglm2ApplyRotaryPosEmbed(
            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
}