Commit 97e0907

loading LM
testing Vision model loading
1 parent 2aab52e commit 97e0907

File tree: 6 files changed, +119 -40 lines
convert_hf_to_gguf.py

Lines changed: 21 additions & 18 deletions
@@ -1494,12 +1494,9 @@ def __init__(self, *args, **kwargs):
         # FIXME: DeepseekOCRVisionModel specific hack
         if self.block_count is None:
             if isinstance(self, DeepseekOCRVisionModel):
-                print(self.hparams)
                 clip_block_count = self.hparams['layers']
                 if clip_block_count is not None:
                     self.block_count = clip_block_count
-                if sam_block_count is not None:
-                    self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count
         if self.block_count is None:
             raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
@@ -5793,16 +5790,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
 @ModelBase.register("DeepseekOCRForCausalLM")
 class DeepseekOCRVisionModel(MmprojModel):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
+
         proc_fname = self.dir_model / "processor_config.json"
-
+
         if proc_fname.is_file():
             with open(proc_fname, "r") as f:
                 self.preprocessor_config = json.load(f)
-
-
+
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -5860,7 +5857,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
 
         return [(self.map_tensor_name(name), data_torch)]
-
+
 
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
@@ -7095,9 +7092,14 @@ def set_vocab(self):
             raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
 
     def set_gguf_parameters(self):
+        is_ocr = (self.hparams["num_hidden_layers"] == 12)
 
-        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
-        self.hparams["num_key_value_heads"] = 1
+        if is_ocr:
+            self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
+            self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)
+        else:
+            # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+            self.hparams["num_key_value_heads"] = 1
 
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -7110,13 +7112,16 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
-        self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
+        if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None:
+            self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
 
         # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-        self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(kv_lora_rank)
-        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+        if not is_ocr:
+            self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length(kv_lora_rank)
+            self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
@@ -7131,8 +7136,6 @@ def set_gguf_parameters(self):
         else:
             raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
 
-        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
-
         rope_scaling = self.hparams.get("rope_scaling") or {}
         if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
@@ -1446,6 +1446,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_ATTN_Q_A_NORM,   "blk.%d.attn_q_a_norm" },
            { LLM_TENSOR_ATTN_KV_A_NORM,  "blk.%d.attn_kv_a_norm" },
            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+           { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+           { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_Q_A,        "blk.%d.attn_q_a" },
            { LLM_TENSOR_ATTN_Q_B,        "blk.%d.attn_q_b" },
            { LLM_TENSOR_ATTN_KV_A_MQA,   "blk.%d.attn_kv_a_mqa" },
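
For reference, the LLM_TENSOR_NAMES entries added here are printf-style templates that get expanded with the layer index when tensors are created (the tn(LLM_TENSOR_ATTN_K, "weight", i) calls in src/llama-model.cpp below). A minimal stand-alone sketch of that expansion, using plain snprintf rather than the real tn() helper:

// Illustration only -- not the llama.cpp API. Expands a "blk.%d.attn_k"
// template into the concrete GGUF tensor name looked up for layer `il`.
#include <cstdio>
#include <string>

static std::string tensor_name(const char * tmpl, int il, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), tmpl, il);  // "blk.%d.attn_k" -> "blk.3.attn_k"
    return std::string(buf) + "." + suffix;     // -> "blk.3.attn_k.weight"
}

int main() {
    std::printf("%s\n", tensor_name("blk.%d.attn_k", 3, "weight").c_str());
    std::printf("%s\n", tensor_name("blk.%d.attn_v", 3, "weight").c_str());
}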

src/llama-model.cpp

Lines changed: 37 additions & 2 deletions
@@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_DEEPSEEK2:
             {
                 bool is_lite = (hparams.n_layer == 27);
+                bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
-                if (!is_lite) {
+                if (!is_lite && !is_ocr) {
                     ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
                 }
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                if (!is_ocr) {
+                    ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                }
                 ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
                 ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
@@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
 
                 switch (hparams.n_layer) {
+                    case 12: type = LLM_TYPE_3B; break;
                     case 27: type = LLM_TYPE_16B; break;
                     case 60: type = LLM_TYPE_236B; break;
                     case 61: type = LLM_TYPE_671B; break;
@@ -4550,6 +4555,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_DEEPSEEK2:
             {
                 const bool is_lite = (hparams.n_layer == 27);
+                const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos);
 
                 const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
 
@@ -4575,6 +4581,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];
 
+                    if (is_ocr) {
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                            layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        }
+                        else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+
+                        continue;
+                    }
+
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                     if (!is_lite) {
                         layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);

src/models/deepseek2.cpp

Lines changed: 31 additions & 1 deletion
@@ -5,6 +5,7 @@
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     bool is_lite = (hparams.n_layer == 27);
+    bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos);
 
     const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
 
@@ -44,7 +45,36 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         cb(cur, "attn_norm", il);
 
         // self_attention
-        {
+        if (is_ocr) {
+            const int n_embed_head = hparams.n_embd / hparams.n_head();
+            GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v);
+
+            ggml_tensor * Qcur = NULL;
+            ggml_tensor * Kcur = NULL;
+            ggml_tensor * Vcur = NULL;
+
+            Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+            cb(Qcur, "q", il);
+            cb(Kcur, "k", il);
+            cb(Vcur, "v", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens);
+
+            GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4);
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0);
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0);
+            cb(Qcur, "q_pe", il);
+            cb(Kcur, "k_pe", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+        }
+        else {
             ggml_tensor * q = NULL;
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
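
For intuition, the is_ocr branch above treats the 12-layer OCR decoder as plain multi-head attention: the per-head width is n_embd / n_head(), and that same value is passed to ggml_reshape_3d and used as the RoPE dimension. A tiny arithmetic sketch with hypothetical numbers (not taken from the DeepSeek-OCR config):

// Hypothetical sizes, for illustration only -- not read from any model file.
#include <cassert>
#include <cstdio>

int main() {
    const int n_embd = 1280;                  // assumed decoder width
    const int n_head = 10;                    // assumed attention head count
    assert(n_embd % n_head == 0);
    const int n_embed_head = n_embd / n_head; // per-head dim, also the RoPE dimension
    // Q/K/V projections of shape [n_embd, n_tokens] are viewed as
    // [n_embed_head, n_head, n_tokens] before RoPE and attention.
    std::printf("n_embed_head = %d\n", n_embed_head);
}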

tools/mtmd/clip-impl.h

Lines changed: 14 additions & 14 deletions
@@ -130,18 +130,18 @@
 #define TN_TOK_EOI "v.eoi"
 
 // deepseek-ocr
-#define TN_SAM_POS_EMBD    "sam.pos_embd"
-#define TN_SAM_PATCH_EMBD  "sam.patch_embd.%s"
-#define TN_SAM_PRE_NORM    "sam.blk.%d.pre_ln.%s"
-#define TN_SAM_POST_NORM   "sam.blk.%d.post_ln"
-#define TN_SAM_ATTN_POS_H  "sam.blk.%d.attn.pos_h"
-#define TN_SAM_ATTN_POS_W  "sam.blk.%d.attn.pos_w"
-#define TN_SAM_ATTN_QKV    "sam.blk.%d.attn.qkv.%s"
-#define TN_SAM_ATTN_OUT    "sam.blk.%d.attn.out.%s"
-#define TN_SAM_FFN_UP      "sam.blk.%d.mlp.lin1.%s"
-#define TN_SAM_FFN_DOWN    "sam.blk.%d.mlp.lin2.%s"
-#define TN_SAM_NECK        "sam.neck.%d.%s"
-#define TN_SAM_NET         "sam.net_%d.%s"
+#define TN_SAM_POS_EMBD    "v.sam.pos_embd"
+#define TN_SAM_PATCH_EMBD  "v.sam.patch_embd.%s"
+#define TN_SAM_PRE_NORM    "v.sam.blk.%d.pre_ln.%s"
+#define TN_SAM_POST_NORM   "v.sam.blk.%d.post_ln"
+#define TN_SAM_ATTN_POS_H  "v.sam.blk.%d.attn.pos_h"
+#define TN_SAM_ATTN_POS_W  "v.sam.blk.%d.attn.pos_w"
+#define TN_SAM_ATTN_QKV    "v.sam.blk.%d.attn.qkv.%s"
+#define TN_SAM_ATTN_OUT    "v.sam.blk.%d.attn.out.%s"
+#define TN_SAM_FFN_UP      "v.sam.blk.%d.mlp.lin1.%s"
+#define TN_SAM_FFN_DOWN    "v.sam.blk.%d.mlp.lin2.%s"
+#define TN_SAM_NECK        "v.sam.neck.%d.%s"
+#define TN_SAM_NET         "v.sam.net_%d.%s"
 
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@@ -170,7 +170,7 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
-    PROJECTOR_TYPE_DEEPSEEK_OCR,
+    PROJECTOR_TYPE_DEEPSEEKOCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -197,7 +197,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
-    { PROJECTOR_TYPE_DEEPSEEK_OCR,"deepseek_orc"},
+    { PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
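
For context, the projector string stored in the GGUF metadata is matched against PROJECTOR_TYPE_NAMES (via clip_projector_type_from_string declared above), so the old misspelled value "deepseek_orc" could never be found by a model converted with "deepseekocr". A simplified stand-alone sketch of that lookup (names here are illustrative, not the actual mtmd code):

// Simplified sketch of the string -> enum lookup; not the real clip-impl.h code.
#include <cstdio>
#include <map>
#include <string>

enum projector_type { PROJECTOR_TYPE_DEEPSEEKOCR, PROJECTOR_TYPE_UNKNOWN };

static const std::map<projector_type, std::string> projector_names = {
    { PROJECTOR_TYPE_DEEPSEEKOCR, "deepseekocr" },
};

static projector_type projector_type_from_string(const std::string & s) {
    for (const auto & kv : projector_names) {
        if (kv.second == s) {
            return kv.first; // exact string match: spelling must agree with the converter
        }
    }
    return PROJECTOR_TYPE_UNKNOWN;
}

int main() {
    std::printf("%d\n", (int) projector_type_from_string("deepseekocr"));  // 0 = found
    std::printf("%d\n", (int) projector_type_from_string("deepseek_orc")); // 1 = unknown
}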

tools/mtmd/clip.cpp

Lines changed: 14 additions & 5 deletions
@@ -682,8 +682,8 @@ struct clip_graph {
 
         const int enc_n_patches = enc_image_size / enc_patch_size; // 64
 
-        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_image_size, enc_n_embd);
-        ggml_tensor * cur = ggml_add(ctx0, inpL, model.position_embeddings);
+        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
+        ggml_tensor * cur = ggml_add(ctx0, inpL, model.pos_embed);
 
         // loop over layers
         for (int il = 0; il < _depth; il++) {
@@ -842,7 +842,7 @@ struct clip_graph {
         ggml_tensor * inp_raw = build_inp_raw();
 
 
-        ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
+        ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
 
         ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1);
 
@@ -2862,6 +2862,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
            {
                res = graph.build_cogvlm();
            } break;
+       case PROJECTOR_TYPE_DEEPSEEKOCR:
+           {
+               res = graph.build_deepseek_ocr();
+           } break;
        default:
            {
                res = graph.build_llava();
@@ -3187,6 +3191,11 @@ struct clip_model_loader {
                    hparams.ffn_op = FFN_GELU_ERF;
                    log_ffn_op = "gelu_erf"; // temporary solution for logging
                } break;
+           case PROJECTOR_TYPE_DEEPSEEKOCR:
+               {
+                   hparams.set_limit_image_tokens(8, 1024);
+                   hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
+               } break;
            default:
                break;
        }
@@ -3574,7 +3583,7 @@ struct clip_model_loader {
                model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
            } break;
-       case PROJECTOR_TYPE_DEEPSEEK_OCR:
+       case PROJECTOR_TYPE_DEEPSEEKOCR:
            {
                model.pos_embed = get_tensor(TN_SAM_POS_EMBD);
                model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight"));
@@ -4830,7 +4839,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                }
            }
        } break;
-   case PROJECTOR_TYPE_DEEPSEEK_OCR:
+   case PROJECTOR_TYPE_DEEPSEEKOCR:
        {
            // configurable, or read from params
            const int min_num = 2;
