@@ -21,73 +21,74 @@ void rotary_embedding_impl(
2121 constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num ();
2222
2323 const int embed_dim = rot_dim / 2 ;
24- TORCH_CHECK (embed_dim % VEC_ELEM_NUM == 0 );
24+ bool flag = (embed_dim % VEC_ELEM_NUM == 0 );
25+ const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
2526
26- #pragma omp parallel for
27- for (int token_idx = 0 ; token_idx < num_tokens; ++token_idx) {
28- int64_t pos = positions[token_idx];
29- const scalar_t * cache_ptr = cos_sin_cache + pos * rot_dim;
27+ auto compute_loop = [&](const int64_t token_head, const scalar_t * cache_ptr,
28+ scalar_t * qk) {
29+ int j = 0 ;
30+ for (; j < loop_upper; j += VEC_ELEM_NUM) {
31+ const int rot_offset = j;
32+ const int x_index = rot_offset;
33+ const int y_index = embed_dim + rot_offset;
3034
31- for (int i = 0 ; i < num_heads; ++i) {
32- const int head_idx = i;
33- const int64_t token_head =
34- token_idx * query_stride + head_idx * head_size;
35- for (int j = 0 ; j < embed_dim; j += VEC_ELEM_NUM) {
36- const int rot_offset = j;
37- const int x_index = rot_offset;
38- const int y_index = embed_dim + rot_offset;
35+ const int64_t out_x = token_head + x_index;
36+ const int64_t out_y = token_head + y_index;
3937
40- const int64_t out_x = token_head + x_index;
41- const int64_t out_y = token_head + y_index;
38+ const scalar_vec_t cos (cache_ptr + x_index) ;
39+ const scalar_vec_t sin (cache_ptr + y_index) ;
4240
43- const scalar_vec_t cos (cache_ptr + x_index );
44- const scalar_vec_t sin (cache_ptr + y_index );
41+ const scalar_vec_t q_x (qk + out_x );
42+ const scalar_vec_t q_y (qk + out_y );
4543
46- const scalar_vec_t q_x (query + out_x );
47- const scalar_vec_t q_y (query + out_y );
44+ vec_op::FP32Vec8 fp32_cos (cos );
45+ vec_op::FP32Vec8 fp32_sin (sin );
4846
49- vec_op::FP32Vec8 fp32_cos (cos );
50- vec_op::FP32Vec8 fp32_sin (sin );
47+ vec_op::FP32Vec8 fp32_q_x (q_x );
48+ vec_op::FP32Vec8 fp32_q_y (q_y );
5149
52- vec_op::FP32Vec8 fp32_q_x (q_x) ;
53- vec_op::FP32Vec8 fp32_q_y (q_y );
50+ auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin ;
51+ scalar_vec_t (out1). save (qk + out_x );
5452
55- auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
56- scalar_vec_t (out1).save (query + out_x);
57-
58- auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
59- scalar_vec_t (out2).save (query + out_y);
60- }
53+ auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
54+ scalar_vec_t (out2).save (qk + out_y);
6155 }
62-
63- for (int i = 0 ; i < num_kv_heads; ++i) {
64- const int head_idx = i;
65- const int64_t token_head = token_idx * key_stride + head_idx * head_size;
66- for (int j = 0 ; j < embed_dim; j += VEC_ELEM_NUM) {
67- const int rot_offset = j;
68- const int x_index = rot_offset;
69- const int y_index = embed_dim + rot_offset;
56+ if (!flag) {
57+ for (; j < embed_dim; ++j) {
58+ const int x_index = j;
59+ const int y_index = embed_dim + j;
7060
7161 const int64_t out_x = token_head + x_index;
7262 const int64_t out_y = token_head + y_index;
7363
74- const scalar_vec_t cos (cache_ptr + x_index) ;
75- const scalar_vec_t sin (cache_ptr + y_index) ;
64+ const float fp32_cos = cache_ptr[ x_index] ;
65+ const float fp32_sin = cache_ptr[ y_index] ;
7666
77- const scalar_vec_t k_x (key + out_x) ;
78- const scalar_vec_t k_y (key + out_y) ;
67+ const float fp32_q_x = qk[ out_x] ;
68+ const float fp32_q_y = qk[ out_y] ;
7969
80- vec_op::FP32Vec8 fp32_cos (cos);
81- vec_op::FP32Vec8 fp32_sin (sin);
70+ qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
71+ qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
72+ }
73+ }
74+ };
8275
83- vec_op::FP32Vec8 fp32_k_x (k_x);
84- vec_op::FP32Vec8 fp32_k_y (k_y);
76+ #pragma omp parallel for
77+ for (int token_idx = 0 ; token_idx < num_tokens; ++token_idx) {
78+ int64_t pos = positions[token_idx];
79+ const scalar_t * cache_ptr = cos_sin_cache + pos * rot_dim;
8580
86- auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
87- scalar_vec_t (out1).save (key + out_x);
88- auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
89- scalar_vec_t (out2).save (key + out_y);
90- }
81+ for (int i = 0 ; i < num_heads; ++i) {
82+ const int head_idx = i;
83+ const int64_t token_head =
84+ token_idx * query_stride + head_idx * head_size;
85+ compute_loop (token_head, cache_ptr, query);
86+ }
87+
88+ for (int i = 0 ; i < num_kv_heads; ++i) {
89+ const int head_idx = i;
90+ const int64_t token_head = token_idx * key_stride + head_idx * head_size;
91+ compute_loop (token_head, cache_ptr, key);
9192 }
9293 }
9394}
0 commit comments