
Commit 677540c

wqerrewetw, FMayran, JJJYmmm, Thireus, yairpatch authored
Qwen3vl tmp (#13)
* simple fix proposal for Qwen2.5 VL's cache causal masking issues. This is just a quick and dirty demonstration.
* replace the map with a vector for performance
* fix compiler warning in llama_ubatch struct construction
* adapting the previous fix to the syntax used by other fields of the ubatch
* support qwen3vl series.
* bugfix: fix the arch check for qwen3vl-moe.

Co-authored-by: FMayran <[email protected]>
Co-authored-by: JJJYmmm <[email protected]>
Co-authored-by: Thireus ☠ <[email protected]>
Co-authored-by: yairpatch <[email protected]>
Co-authored-by: LETS-BEE <[email protected]>
1 parent 8a3b149 commit 677540c

File tree

20 files changed, +1127 -126 lines changed

convert_hf_to_gguf.py

Lines changed: 232 additions & 1 deletion
@@ -3852,7 +3852,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "") # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
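Note: the permute bookkeeping in this hunk is easy to get backwards. Here is a minimal standalone sketch (not part of the commit; assumes torch is available) that reproduces just the shape transformations the comments describe, using the example sizes n_expert=128, n_ff_exp=768, n_embd=2048:

import torch

n_expert, n_ff_exp, n_embd = 128, 768, 2048

# down_proj: HF ships (n_expert, n_ff_exp, n_embd); GGML wants ne = {n_ff_exp, n_embd, n_expert},
# which is PyTorch shape (n_expert, n_embd, n_ff_exp) because GGUF reverses dimension order.
down = torch.empty(n_expert, n_ff_exp, n_embd)
assert tuple(down.permute(0, 2, 1).shape) == (n_expert, n_embd, n_ff_exp)

# gate_up_proj: HF ships the fused tensor (n_expert, n_embd, 2 * n_ff_exp); split the last
# axis into gate and up halves, then permute each to (n_expert, n_ff_exp, n_embd).
gate_up = torch.empty(n_expert, n_embd, 2 * n_ff_exp)
split_dim = gate_up.shape[-1] // 2
gate, up = gate_up[..., :split_dim], gate_up[..., split_dim:]
assert tuple(gate.permute(0, 2, 1).shape) == (n_expert, n_ff_exp, n_embd)
assert tuple(up.permute(0, 2, 1).shape) == (n_expert, n_ff_exp, n_embd)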
@@ -4004,6 +4040,201 @@ def set_vocab(self):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos**0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.deepstack_layers:
+            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            idx = int(prefix)
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            if ".qkv." in name:
+                if data_torch.ndim == 2:
+                    c3, _ = data_torch.shape
+                else:
+                    c3 = data_torch.shape[0]
+                if c3 % 3 != 0:
+                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                base = name.replace("qkv", "{placeholder}")
+                return [
+                    (self.map_tensor_name(base.format(placeholder="q")), wq),
+                    (self.map_tensor_name(base.format(placeholder="k")), wk),
+                    (self.map_tensor_name(base.format(placeholder="v")), wv),
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2
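Two small derivations in the classes above can be checked in isolation. A minimal sketch (not from the commit; the mrope_section values below are made up for illustration):

# image_size is recovered from num_position_embeddings = (image_size / patch_size) ** 2,
# using the defaults the vision class falls back to (2304 and 16).
num_pos, patch_size = 2304, 16
image_size = int(num_pos ** 0.5 * patch_size)
assert image_size == 768  # 48 patches per side * 16 pixels per patch

# mrope_section is padded in place to four entries [time, height, width, extra].
mrope_section = [24, 20, 20]  # hypothetical values, for illustration only
while len(mrope_section) < 4:
    mrope_section.append(0)
assert mrope_section[:4] == [24, 20, 20, 0]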

ggml/include/ggml.h

Lines changed: 1 addition & 0 deletions
@@ -242,6 +242,7 @@
 #define GGML_ROPE_TYPE_NEOX 2
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 
 #define GGML_MROPE_SECTIONS 4
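Since 40 is 0b101000 and GGML_ROPE_TYPE_MROPE is 8 (0b001000), the new IMROPE value keeps the MROPE bit set, so existing `mode & GGML_ROPE_TYPE_MROPE` checks still fire for the interleaved variant while the two modes remain distinguishable. A quick illustrative check (not part of the commit):

GGML_ROPE_TYPE_MROPE = 8    # 0b001000
GGML_ROPE_TYPE_IMROPE = 40  # 0b101000
assert GGML_ROPE_TYPE_IMROPE & GGML_ROPE_TYPE_MROPE   # MROPE-gated paths still trigger
assert GGML_ROPE_TYPE_IMROPE != GGML_ROPE_TYPE_MROPE  # but the modes stay distinct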

ggml/src/ggml-cpu/ops.cpp

Lines changed: 23 additions & 11 deletions
@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
 }
 
 static void ggml_mrope_cache_init(
-    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5509,24 @@
         }
 
         float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
-            theta = theta_h;
-        }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
-            theta = theta_w;
-        }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
         }
 
         rope_yarn(
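To see how the interleaved layout differs from the original contiguous blocks, here is an illustrative Python mirror of the sector-to-axis selection above (one sector per pair of rotary dimensions; the sections split is made up). It is faithful to the code in this hunk, including the detail that the interleaved branch routes sector % 3 == 0 to theta_e rather than theta_t:

def pick_axis(sector: int, sections: list[int], is_imrope: bool) -> str:
    sec_w = sections[0] + sections[1]
    axis = "t"  # theta defaults to theta_t
    if is_imrope:  # interleaved: axes cycle with sector % 3
        if sector % 3 == 1 and sector < 3 * sections[1]:
            axis = "h"
        elif sector % 3 == 2 and sector < 3 * sections[2]:
            axis = "w"
        else:
            axis = "e"
    else:  # contiguous blocks: [t | h | w | e]
        if sections[0] <= sector < sec_w:
            axis = "h"
        elif sec_w <= sector < sec_w + sections[2]:
            axis = "w"
        elif sector >= sec_w + sections[2]:
            axis = "e"
    return axis

sections = [2, 2, 2, 0]  # hypothetical [t, h, w, e] split
print([pick_axis(s, sections, is_imrope=False) for s in range(6)])  # ['t', 't', 'h', 'h', 'w', 'w']
print([pick_axis(s, sections, is_imrope=True) for s in range(6)])   # ['e', 'h', 'w', 'e', 'h', 'w']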
@@ -5589,6 +5599,7 @@ static void ggml_compute_forward_rope_f32(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5627,7 +5638,7 @@ static void ggml_compute_forward_rope_f32(
             const int64_t p_w = pos[i2 + ne2 * 2];
             const int64_t p_e = pos[i2 + ne2 * 3];
             ggml_mrope_cache_init(
-                p_t, p_h, p_w, p_e, sections, is_vision,
+                p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
         }
 
@@ -5775,6 +5786,7 @@ static void ggml_compute_forward_rope_f16(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5813,7 +5825,7 @@ static void ggml_compute_forward_rope_f16(
             const int64_t p_w = pos[i2 + ne2 * 2];
             const int64_t p_e = pos[i2 + ne2 * 3];
             ggml_mrope_cache_init(
-                p_t, p_h, p_w, p_e, sections, is_vision,
+                p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
         }
 
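Both kernels read four position streams (t, h, w, extra) packed back to back in the pos buffer, so stream k for token i2 lives at pos[i2 + ne2 * k]. A small sketch of that layout (illustrative only; the values are made up):

ne2 = 4  # hypothetical number of positions in the batch
pos = [0, 1, 2, 3,   # t stream
       0, 0, 1, 1,   # h stream
       0, 1, 0, 1,   # w stream
       0, 0, 0, 0]   # e (extra) stream
i2 = 2
p_t, p_h, p_w, p_e = (pos[i2 + ne2 * k] for k in range(4))
assert (p_t, p_h, p_w, p_e) == (2, 1, 0, 0)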