From 8a0ad84b9ee94fad175e5687fb8774503efbd23b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 23 Jul 2024 23:20:52 -0700 Subject: [PATCH] Llama 3.1 RoPE Beware - my C is very rusty (haven't done C in like ages lol) - I might have transcribed it incorrectly from https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L1116 --- runq.c | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/runq.c b/runq.c index 01d8897e..f16d002f 100644 --- a/runq.c +++ b/runq.c @@ -650,11 +650,35 @@ float* forward(Transformer* transformer, int token, int pos) { matmul(s->k, &s->xq, w->wk + l, dim, kv_dim); matmul(s->v, &s->xq, w->wv + l, dim, kv_dim); + // Llama 3.1 RoPE Scaling + const float scale_factor = 8; + const float low_freq_factor = 1; + const float high_freq_factor = 4; + const float old_context_len = 8192; + + const float low_freq_wavelen = old_context_len / low_freq_factor; + const float high_freq_wavelen = old_context_len / high_freq_factor; + // END Llama 3.1 RoPE Scaling + // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = 0; i < dim; i+=2) { int head_dim = i % head_size; // L2E Addition float freq = 1.0f / powf(rope_tf, head_dim / (float)head_size); + +// Llama 3.1 + // Check if M_PI is defined? Sorry haven't coded in C in ages lol + float wavelen = 2.0f * M_PI / freq; + if (wavelen < high_freq_wavelen) { freq = freq; } + else if (wavelen > low_freq_wavelen) { freq = freq / scale_factor; } + else { + float smooth = (old_context_len / wavelen - low_freq_factor) / \ + (high_freq_factor - low_freq_factor); + + freq = (1.0 - smooth) * freq / scale_factor + smooth * freq; + } +// END Llama 3.1 + // END L2E Addition float val = pos * freq; float fcr = cosf(val);