diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp
index 99f52b08..566a6340 100644
--- a/src/kernels/rotary_embedding_kernels.cpp
+++ b/src/kernels/rotary_embedding_kernels.cpp
@@ -235,12 +235,13 @@ static inline void chatglm2ApplyRotaryPosEmbeding(T *query, T *key, int qStride,
     const int head_num = qk_shape[2] + qk_shape[4];
     const int half = inv_freq_size;
 
-#pragma omp parallel for
+#pragma omp parallel for collapse(3)
     for (int head = 0; head < head_num; ++head) {
-        int off = head * dim;
         for (int bs = 0; bs < batch_size; ++bs) {
             for (int seq = 0; seq < seq_len; ++seq) {
-                T *pF = query + off;
+                // Row index is bs * seq_len + seq, matching the removed running `off += qStride`;
+                // without the bs term every batch would rewrite (and race on) batch 0's rows.
+                T *pF = query + (bs * seq_len + seq) * qStride + head * dim;
 
                 int pos = position_ids[seq];
                 float *pcos = emb_cos + pos * dim;
@@ -271,7 +272,6 @@ static inline void chatglm2ApplyRotaryPosEmbeding(T *query, T *key, int qStride,
                     xft::store_avx512(&pF[i], mask, tmp0);
                     xft::store_avx512(&pF[i + 16], mask, tmp1);
                 }
-                off += qStride;
             }
         }
     }
@@ -295,6 +295,80 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid
             query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
 }
 
+
+// For ChatGLM2/3 continuous batching.
+// Rotates, in place, totSeqLen rows starting at query: row seq begins at query + seq * qStride and
+// uses the cos/sin table row selected by positionIds[seq].
+// Each row is walked as (qHeads + kHeads) heads of width dim = 2 * inv_freq_size; the inner loop
+// covers element indices [0, inv_freq_size) of every head in steps of 32 floats.
+// NOTE(review): key and kStride are never read here; presumably the K heads sit right after the
+// Q heads inside the same row -- confirm with callers.
+// NOTE(review): loads/stores use a full 16-lane mask, so inv_freq_size is assumed to be a
+// multiple of 32 -- confirm.
+
+template <typename T>
+static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
+        int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+    int dim = inv_freq_size * 2;
+    const int head_num = qHeads + kHeads;
+    const int half = inv_freq_size;
+
+#pragma omp parallel for collapse(2)
+    for (int head = 0; head < head_num; ++head) {
+        for (int seq = 0; seq < totSeqLen; ++seq) {
+            T *pF = query + seq * qStride + head * dim;
+
+            int pos = positionIds[seq];
+            float *pcos = emb_cos + pos * dim;
+            float *psin = emb_sin + pos * dim;
+
+            for (int i = 0; i < half; i += 32) {
+                __mmask16 mask = 0xffff;
+                __m512 tmp0, tmp1, pCosVec, pSinVec, qVec0, qVec1;
+                //TODO: can directly load/save with shuffle??
+                tmp0 = _mm512_maskz_loadu_ps(mask, &pcos[i]);
+                tmp1 = _mm512_maskz_loadu_ps(mask, &pcos[i + 16]);
+                chatglm2PrepareSinCos(tmp0, tmp1, &pCosVec);
+
+                tmp0 = _mm512_maskz_loadu_ps(mask, &psin[i]);
+                tmp1 = _mm512_maskz_loadu_ps(mask, &psin[i + 16]);
+                chatglm2PrepareSinCos(tmp0, tmp1, &pSinVec);
+
+                tmp0 = xft::load_avx512(mask, &pF[i]);
+                tmp1 = xft::load_avx512(mask, &pF[i + 16]);
+
+                chatglm2InterleaveQK(tmp0, tmp1, &qVec0, &qVec1);
+
+                __m512 qNew0 = _mm512_fmsub_ps(qVec0, pCosVec, _mm512_mul_ps(qVec1, pSinVec));
+                __m512 qNew1 = _mm512_fmadd_ps(qVec0, pSinVec, _mm512_mul_ps(qVec1, pCosVec));
+
+                chatglm2DeinterleaveQK(qNew0, qNew1, &tmp0, &tmp1);
+
+                xft::store_avx512(&pF[i], mask, tmp0);
+                xft::store_avx512(&pF[i + 16], mask, tmp1);
+            }
+        }
+    }
+}
+
+void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim,
+        int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+    chatglm2ApplyRotaryPosEmbed<float>(
+            query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
+}
+
+void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+    chatglm2ApplyRotaryPosEmbed<bfloat16_t>(
+            query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
+}
+
+void chatglm2ApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+    chatglm2ApplyRotaryPosEmbed<float16_t>(
+            query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
+}
+
 template <typename T>
 static inline void qwenApplyRotaryPosEmbeding(T *query, T *key, int qStride, int kStride, float *cur_emb_cos,
         float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h
index bf150351..21d13475 100644
--- a/src/kernels/rotary_embedding_kernels.h
+++ b/src/kernels/rotary_embedding_kernels.h
@@ -52,6 +52,16 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr
 void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
         float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);
 
+// For ChatGLM2/3 continuous batching
+void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
+
+void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
+
+void chatglm2ApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
+
 // For Qwen1.0
 void qwenApplyRotaryPosEmbeding(float *query, float *key, int qStride, int kStride, float *cur_emb_cos,
         float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
diff --git a/src/layers/rotary_embedding_chatglm2.cpp b/src/layers/rotary_embedding_chatglm2.cpp
index f8b38cce..a7e9cc24 100644
--- a/src/layers/rotary_embedding_chatglm2.cpp
+++ b/src/layers/rotary_embedding_chatglm2.cpp
@@ -134,18 +134,18 @@ void ChatGLM2RotaryEmbedding::forward(
 // For continuous batching
 void ChatGLM2RotaryEmbedding::forward(
         float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) {
-    printf("Unsupported ChatGLM2RotaryEmbedding in cb mode !\n");
-    exit(1);
+    xft::chatglm2ApplyRotaryPosEmbed(
+            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
 }
 
 void ChatGLM2RotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride,
         int qHeads, int kHeads, int *positionIds) {
-    printf("Unsupported ChatGLM2RotaryEmbedding in cb mode !\n");
-    exit(1);
+    xft::chatglm2ApplyRotaryPosEmbed(
+            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
 }
 
 void ChatGLM2RotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride,
         int qHeads, int kHeads, int *positionIds) {
-    printf("Unsupported ChatGLM2RotaryEmbedding in cb mode !\n");
-    exit(1);
+    xft::chatglm2ApplyRotaryPosEmbed(
+            query, key, emb_cos, emb_sin, qStride, kStride, inv_freq_size, totSeqLen, qHeads, kHeads, positionIds);
 }
\ No newline at end of file