diff --git a/runq.c b/runq.c index 01d8897e..f16d002f 100644 --- a/runq.c +++ b/runq.c @@ -650,11 +650,35 @@ float* forward(Transformer* transformer, int token, int pos) { matmul(s->k, &s->xq, w->wk + l, dim, kv_dim); matmul(s->v, &s->xq, w->wv + l, dim, kv_dim); + // Llama 3.1 RoPE Scaling + const float scale_factor = 8; + const float low_freq_factor = 1; + const float high_freq_factor = 4; + const float old_context_len = 8192; + + const float low_freq_wavelen = old_context_len / low_freq_factor; + const float high_freq_wavelen = old_context_len / high_freq_factor; + // END Llama 3.1 RoPE Scaling + // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = 0; i < dim; i+=2) { int head_dim = i % head_size; // L2E Addition float freq = 1.0f / powf(rope_tf, head_dim / (float)head_size); + +// Llama 3.1 + // Check if M_PI is defined? Sorry haven't coded in C in ages lol + float wavelen = 2.0f * M_PI / freq; + if (wavelen < high_freq_wavelen) { freq = freq; } + else if (wavelen > low_freq_wavelen) { freq = freq / scale_factor; } + else { + float smooth = (old_context_len / wavelen - low_freq_factor) / \ + (high_freq_factor - low_freq_factor); + + freq = (1.0 - smooth) * freq / scale_factor + smooth * freq; + } +// END Llama 3.1 + // END L2E Addition float val = pos * freq; float fcr = cosf(val);