@@ -214,6 +214,7 @@ enum llm_arch {
214214 LLM_ARCH_GEMMA,
215215 LLM_ARCH_STARCODER2,
216216 LLM_ARCH_MAMBA,
217+ LLM_ARCH_COMMAND_R,
217218 LLM_ARCH_UNKNOWN,
218219};
219220
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
243244 { LLM_ARCH_GEMMA, "gemma" },
244245 { LLM_ARCH_STARCODER2, "starcoder2" },
245246 { LLM_ARCH_MAMBA, "mamba" },
247+ { LLM_ARCH_COMMAND_R, "command-r" },
246248 { LLM_ARCH_UNKNOWN, "(unknown)" },
247249};
248250
@@ -268,6 +270,7 @@ enum llm_kv {
268270 LLM_KV_EXPERT_COUNT,
269271 LLM_KV_EXPERT_USED_COUNT,
270272 LLM_KV_POOLING_TYPE,
273+ LLM_KV_LOGIT_SCALE,
271274
272275 LLM_KV_ATTENTION_HEAD_COUNT,
273276 LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -332,6 +335,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
332335 { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
333336 { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
334337 { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
338+ { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
335339
336340 { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
337341 { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -838,6 +842,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
838842 { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
839843 },
840844 },
845+ {
846+ LLM_ARCH_COMMAND_R,
847+ {
848+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
849+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
850+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
851+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
852+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
853+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
854+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
855+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
856+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
857+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
858+ },
859+ },
841860 {
842861 LLM_ARCH_UNKNOWN,
843862 {
@@ -1597,6 +1616,7 @@ enum e_model {
15971616 MODEL_20B,
15981617 MODEL_30B,
15991618 MODEL_34B,
1619+ MODEL_35B,
16001620 MODEL_40B,
16011621 MODEL_65B,
16021622 MODEL_70B,
@@ -1643,6 +1663,7 @@ struct llama_hparams {
16431663
16441664 float f_clamp_kqv = 0.0f;
16451665 float f_max_alibi_bias = 0.0f;
1666+ float f_logit_scale = 0.0f;
16461667
16471668 bool causal_attn = true;
16481669 bool need_kq_pos = false;
@@ -3231,6 +3252,7 @@ static const char * llama_model_type_name(e_model type) {
32313252 case MODEL_20B: return "20B";
32323253 case MODEL_30B: return "30B";
32333254 case MODEL_34B: return "34B";
3255+ case MODEL_35B: return "35B";
32343256 case MODEL_40B: return "40B";
32353257 case MODEL_65B: return "65B";
32363258 case MODEL_70B: return "70B";
@@ -3623,6 +3645,15 @@ static void llm_load_hparams(
36233645 default: model.type = e_model::MODEL_UNKNOWN;
36243646 }
36253647 } break;
3648+ case LLM_ARCH_COMMAND_R:
3649+ {
3650+ ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
3651+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
3652+ switch (hparams.n_layer) {
3653+ case 40: model.type = e_model::MODEL_35B; break;
3654+ default: model.type = e_model::MODEL_UNKNOWN;
3655+ }
3656+ } break;
36263657 default: (void)0;
36273658 }
36283659
@@ -3944,6 +3975,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
39443975 LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
39453976 LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
39463977 LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
3978+ LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
39473979 LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
39483980 LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
39493981 LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@@ -4918,6 +4950,37 @@ static bool llm_load_tensors(
49184950 layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
49194951 }
49204952 } break;
4953+ case LLM_ARCH_COMMAND_R:
4954+ {
4955+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4956+
4957+ // output
4958+ {
4959+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
4960+ // init output from the input tok embed
4961+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4962+ ml.n_created--; // artificial tensor
4963+ ml.size_data += ggml_nbytes(model.output);
4964+ }
4965+
4966+ for (int i = 0; i < n_layer; ++i) {
4967+ ggml_context * ctx_layer = ctx_for_layer(i);
4968+ ggml_context * ctx_split = ctx_for_layer_split(i);
4969+
4970+ auto & layer = model.layers[i];
4971+
4972+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
4973+
4974+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
4975+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
4976+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
4977+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
4978+
4979+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
4980+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
4981+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
4982+ }
4983+ } break;
49214984 default:
49224985 throw std::runtime_error("unknown architecture");
49234986 }
@@ -8315,6 +8378,121 @@ struct llm_build_context {
83158378
83168379 return gf;
83178380 }
8381+
8382+ struct ggml_cgraph * build_command_r() {
8383+
8384+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
8385+
8386+ const int64_t n_embd_head = hparams.n_embd_head_v;
8387+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8388+ const float f_logit_scale = hparams.f_logit_scale;
8389+
8390+ struct ggml_tensor * cur;
8391+ struct ggml_tensor * inpL;
8392+
8393+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
8394+
8395+ // inp_pos - contains the positions
8396+ struct ggml_tensor * inp_pos = build_inp_pos();
8397+
8398+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
8399+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
8400+
8401+ for (int il = 0; il < n_layer; ++il) {
8402+
8403+ // norm
8404+ cur = llm_build_norm(ctx0, inpL, hparams,
8405+ model.layers[il].attn_norm, NULL,
8406+ LLM_NORM, cb, il);
8407+ cb(cur, "attn_norm", il);
8408+ struct ggml_tensor * ffn_inp = cur;
8409+
8410+ // self-attention
8411+ {
8412+ // compute Q and K and RoPE them
8413+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
8414+ cb(Qcur, "Qcur", il);
8415+ if (model.layers[il].bq) {
8416+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8417+ cb(Qcur, "Qcur", il);
8418+ }
8419+
8420+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
8421+ cb(Kcur, "Kcur", il);
8422+ if (model.layers[il].bk) {
8423+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8424+ cb(Kcur, "Kcur", il);
8425+ }
8426+
8427+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
8428+ cb(Vcur, "Vcur", il);
8429+ if (model.layers[il].bv) {
8430+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8431+ cb(Vcur, "Vcur", il);
8432+ }
8433+
8434+ Qcur = ggml_rope_custom(
8435+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
8436+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8437+ ext_factor, attn_factor, beta_fast, beta_slow
8438+ );
8439+ cb(Qcur, "Qcur", il);
8440+
8441+ Kcur = ggml_rope_custom(
8442+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
8443+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8444+ ext_factor, attn_factor, beta_fast, beta_slow
8445+ );
8446+ cb(Kcur, "Kcur", il);
8447+
8448+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
8449+ model.layers[il].wo, model.layers[il].bo,
8450+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
8451+ }
8452+
8453+ struct ggml_tensor * attn_out = cur;
8454+
8455+ // feed-forward network
8456+ {
8457+ cur = llm_build_ffn(ctx0, ffn_inp,
8458+ model.layers[il].ffn_up, NULL,
8459+ model.layers[il].ffn_gate, NULL,
8460+ model.layers[il].ffn_down, NULL,
8461+ NULL,
8462+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
8463+ cb(cur, "ffn_out", il);
8464+ }
8465+
8466+ // add together residual + FFN + self-attention
8467+ cur = ggml_add(ctx0, cur, inpL);
8468+ cur = ggml_add(ctx0, cur, attn_out);
8469+ cb(cur, "l_out", il);
8470+
8471+ // input for next layer
8472+ inpL = cur;
8473+ }
8474+
8475+ cur = inpL;
8476+
8477+ cur = llm_build_norm(ctx0, cur, hparams,
8478+ model.output_norm, NULL,
8479+ LLM_NORM, cb, -1);
8480+ cb(cur, "result_norm", -1);
8481+
8482+ // lm_head
8483+ cur = ggml_mul_mat(ctx0, model.output, cur);
8484+
8485+ if (f_logit_scale) {
8486+ cur = ggml_scale(ctx0, cur, f_logit_scale);
8487+ }
8488+
8489+ cb(cur, "result_output", -1);
8490+
8491+ ggml_build_forward_expand(gf, cur);
8492+
8493+ return gf;
8494+
8495+ }
83188496};
83198497
83208498static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -8497,6 +8675,10 @@ static struct ggml_cgraph * llama_build_graph(
84978675 {
84988676 result = llm.build_mamba();
84998677 } break;
8678+ case LLM_ARCH_COMMAND_R:
8679+ {
8680+ result = llm.build_command_r();
8681+ } break;
85008682 default:
85018683 GGML_ASSERT(false);
85028684 }
@@ -13147,6 +13329,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1314713329 case LLM_ARCH_ORION:
1314813330 case LLM_ARCH_INTERNLM2:
1314913331 case LLM_ARCH_MINICPM:
13332+ case LLM_ARCH_COMMAND_R:
1315013333 return LLAMA_ROPE_TYPE_NORM;
1315113334
1315213335 // the pairs of head values are offset by n_rot/2
0 commit comments