Skip to content

Commit bd18228

Browse files
committed
calculate the freq_cis online, no need to write/read them to/from checkpoints
1 parent b68a6d2 commit bd18228

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

run.c

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ typedef struct {
4343
float* w3; // (layer, hidden_dim, dim)
4444
// final rmsnorm
4545
float* rms_final_weight; // (dim,)
46-
// freq_cis for RoPE relatively positional embeddings
46+
// freq_cis for RoPE relative positional embeddings (not used anymore)
4747
float* freq_cis_real; // (seq_len, head_size/2)
4848
float* freq_cis_imag; // (seq_len, head_size/2)
4949
// (optional) classifier weights for the logits, on the last layer
@@ -214,10 +214,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
214214
float* content_row = &(w->token_embedding_table[token * dim]);
215215
memcpy(x, content_row, dim*sizeof(*x));
216216

217-
// pluck out the "pos" row of freq_cis_real and freq_cis_imag
218-
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
219-
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
220-
221217
// forward all the layers
222218
for(int l = 0; l < p->n_layers; l++) {
223219

@@ -229,15 +225,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
229225
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
230226
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
231227

232-
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
233-
for (int v = 0; v < 2; v++) {
234-
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
235-
int vec_size = v == 0 ? dim : kv_dim; // the size of the vector
236-
for (int i = 0; i < vec_size; i+=2) {
228+
// RoPE relative positional encoding: complex-valued rotate q and k in each head
229+
for (int i = 0; i < dim; i+=2) {
230+
int head_dim = i % head_size;
231+
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
232+
float val = pos * freq;
233+
float fcr = cosf(val);
234+
float fci = sinf(val);
235+
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
236+
for (int v = 0; v < rotn; v++) {
237+
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
237238
float v0 = vec[i];
238239
float v1 = vec[i+1];
239-
float fcr = freq_cis_real_row[(i % head_size) / 2];
240-
float fci = freq_cis_imag_row[(i % head_size) / 2];
241240
vec[i] = v0 * fcr - v1 * fci;
242241
vec[i+1] = v0 * fci + v1 * fcr;
243242
}

0 commit comments

Comments
 (0)