@@ -1350,6 +1350,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(4);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
+
                 switch (hparams.n_layer) {
                     case 16: type = LLM_TYPE_1B; break;
                     case 32: type = LLM_TYPE_7B; break;
@@ -12233,6 +12241,7 @@ struct llm_build_olmo : public llm_graph_context {
     }
 };
 
+template <bool iswa>
 struct llm_build_olmo2 : public llm_graph_context {
     llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -12248,7 +12257,14 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv();
+        }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12281,17 +12297,36 @@ struct llm_build_olmo2 : public llm_graph_context {
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            Qcur = ggml_rope_ext(
+            const bool is_swa = hparams.is_swa(il);
+
+            if (is_swa) {
+                // For sliding-window layers, Olmo3 uses regular RoPE with no YaRN rope scaling.
+                // This is achieved here by setting freq_scale and attn_factor to 1.
+                // We also set ext_factor to 0 to avoid a few unnecessary computations.
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+            } else {
+                Qcur = ggml_rope_ext(
                 ctx0, Qcur, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
-            Kcur = ggml_rope_ext(
+                Kcur = ggml_rope_ext(
                 ctx0, Kcur, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
+            }
 
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
@@ -19131,7 +19166,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_OLMO2:
             {
-                llm = std::make_unique<llm_build_olmo2>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_olmo2<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_olmo2<false>>(*this, params);
+                }
             } break;
         case LLM_ARCH_OLMOE:
             {
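
Taken together, the hunks above do three things: the hparams loader switches on standard sliding-window attention with a pattern of 4 whenever the GGUF metadata provides a window size, the olmo2 graph builder is templated on iswa so the SWA variant picks the iswa KV-cache input, and sliding-window layers drop back to plain RoPE by neutralizing the YaRN parameters. The standalone sketch below is not part of the patch; it only illustrates the resulting per-layer layout, under the assumption that set_swa_pattern(4) marks three out of every four layers as sliding-window and leaves each fourth layer on full attention (the concrete rope_params values are likewise made up for illustration).

// Standalone sketch (not part of the patch): per-layer attention type and RoPE
// parameters implied by set_swa_pattern(4), assuming the pattern marks every
// fourth layer as full attention and the rest as sliding-window.
#include <cstdio>

struct rope_params {
    float freq_scale;  // 1.0f disables YaRN-style frequency scaling
    float ext_factor;  // 0.0f skips the YaRN extrapolation mix
    float attn_factor; // 1.0f leaves attention magnitudes untouched
};

int main() {
    const int n_layer   = 16; // e.g. the LLM_TYPE_1B config from the first hunk
    const int n_pattern = 4;  // assumed meaning of set_swa_pattern(4)

    // hypothetical YaRN-scaled values a long-context full-attention layer might use
    const rope_params full = { 0.25f, 1.0f, 1.0f };
    // plain RoPE, as in the is_swa branch of the patch
    const rope_params swa  = { 1.0f,  0.0f, 1.0f };

    for (int il = 0; il < n_layer; ++il) {
        const bool is_swa = (il % n_pattern) < (n_pattern - 1);
        const rope_params & p = is_swa ? swa : full;
        std::printf("layer %2d: %-14s freq_scale=%.2f ext_factor=%.2f attn_factor=%.2f\n",
                il, is_swa ? "sliding-window" : "full-attention",
                p.freq_scale, p.ext_factor, p.attn_factor);
    }
    return 0;
}

Under those assumptions, layers 3, 7, 11 and 15 of a 16-layer model keep full attention with YaRN scaling applied, while every other layer attends through the sliding window with unscaled RoPE.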