@@ -43,7 +43,7 @@ typedef struct {
4343 float * w3 ; // (layer, hidden_dim, dim)
4444 // final rmsnorm
4545 float * rms_final_weight ; // (dim,)
46- // freq_cis for RoPE relatively positional embeddings
46+ // freq_cis for RoPE relatively positional embeddings (not used anymore)
4747 float * freq_cis_real ; // (seq_len, head_size/2)
4848 float * freq_cis_imag ; // (seq_len, head_size/2)
4949 // (optional) classifier weights for the logits, on the last layer
@@ -214,10 +214,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
214214 float * content_row = & (w -> token_embedding_table [token * dim ]);
215215 memcpy (x , content_row , dim * sizeof (* x ));
216216
217- // pluck out the "pos" row of freq_cis_real and freq_cis_imag
218- float * freq_cis_real_row = w -> freq_cis_real + pos * head_size / 2 ;
219- float * freq_cis_imag_row = w -> freq_cis_imag + pos * head_size / 2 ;
220-
221217 // forward all the layers
222218 for (int l = 0 ; l < p -> n_layers ; l ++ ) {
223219
@@ -229,15 +225,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
229225 matmul (s -> k , s -> xb , w -> wk + l * dim * kv_dim , dim , kv_dim );
230226 matmul (s -> v , s -> xb , w -> wv + l * dim * kv_dim , dim , kv_dim );
231227
232- // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
233- for (int v = 0 ; v < 2 ; v ++ ) {
234- float * vec = v == 0 ? s -> q : s -> k ; // the vector to rotate (query or key)
235- int vec_size = v == 0 ? dim : kv_dim ; // the size of the vector
236- for (int i = 0 ; i < vec_size ; i += 2 ) {
228+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
229+ for (int i = 0 ; i < dim ; i += 2 ) {
230+ int head_dim = i % head_size ;
231+ float freq = 1.0f / powf (10000.0f , head_dim / (float )head_size );
232+ float val = pos * freq ;
233+ float fcr = cosf (val );
234+ float fci = sinf (val );
235+ int rotn = i < kv_dim ? 2 : 1 ; // how many vectors? 2 = q & k, 1 = q only
236+ for (int v = 0 ; v < rotn ; v ++ ) {
237+ float * vec = v == 0 ? s -> q : s -> k ; // the vector to rotate (query or key)
237238 float v0 = vec [i ];
238239 float v1 = vec [i + 1 ];
239- float fcr = freq_cis_real_row [(i % head_size ) / 2 ];
240- float fci = freq_cis_imag_row [(i % head_size ) / 2 ];
241240 vec [i ] = v0 * fcr - v1 * fci ;
242241 vec [i + 1 ] = v0 * fci + v1 * fcr ;
243242 }
0 commit comments