@@ -3849,7 +3849,43 @@ def set_gguf_parameters(self):
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # process the experts separately
        name = name.replace("language_model.", "")  # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
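+            # e.g. a name like "...mlp.experts.gate_up_proj" maps to "...mlp.experts.gate_proj.weight" and "...mlp.experts.up_proj.weight"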
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
            # skip visual tensors
            return []
        if name.find("experts") != -1:
@@ -3997,6 +4033,201 @@ def set_vocab(self):
        super().set_vocab()


+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # so image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos ** 0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
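+            # e.g. with the defaults above: sqrt(2304) = 48 patches per side -> image_size = 48 * 16 = 768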
4050+
4051+ # Rename config values for compatibility
4052+ self .hparams_vision ["num_attention_heads" ] = self .hparams_vision .get ("num_heads" )
4053+ self .hparams_vision ["num_hidden_layers" ] = self .hparams_vision .get ("depth" )
4054+
4055+ self .deepstack_layers : list [int ] = list (self .hparams_vision .get ("deepstack_visual_indexes" , []))
4056+
4057+ def set_gguf_parameters (self ):
4058+ super ().set_gguf_parameters ()
4059+ self .gguf_writer .add_clip_projector_type (gguf .VisionProjectorType .QWEN3VL )
4060+ self .gguf_writer .add_vision_use_gelu (True )
4061+
4062+ if self .hparams_vision is not None :
4063+ merge_size = self .hparams_vision .get ("spatial_merge_size" )
4064+ if merge_size is not None :
4065+ self .gguf_writer .add_vision_spatial_merge_size (int (merge_size ))
4066+
4067+ # Use text config's rms_norm_eps for vision attention layernorm eps
4068+ rms_norm_eps = self .global_config .get ("text_config" , {}).get ("rms_norm_eps" , 1e-6 )
4069+ self .gguf_writer .add_vision_attention_layernorm_eps (rms_norm_eps )
4070+
4071+ if self .deepstack_layers :
4072+ self .gguf_writer .add_vision_deepstack_layers (self .deepstack_layers )
4073+
4074+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
4075+ # Skip text model tensors - they go in the text model file
4076+ if name .startswith ("model.language_model." ) or name .startswith ("lm_head." ):
4077+ return []
4078+
4079+ if name .startswith ("model.visual." ):
4080+ name = name .replace ("model.visual." , "visual." , 1 )
4081+
4082+ if name .startswith ("visual.deepstack_merger_list." ):
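+            # e.g. "visual.deepstack_merger_list.0.norm.weight" -> idx = 0, target = "norm.weight"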
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            idx = int(prefix)
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map them to indices 0 and 2, matching Qwen2VL
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            if ".qkv." in name:
+                if data_torch.ndim == 2:
+                    c3, _ = data_torch.shape
+                else:
+                    c3 = data_torch.shape[0]
+                if c3 % 3 != 0:
+                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                base = name.replace("qkv", "{placeholder}")
+                return [
+                    (self.map_tensor_name(base.format(placeholder="q")), wq),
+                    (self.map_tensor_name(base.format(placeholder="k")), wk),
+                    (self.map_tensor_name(base.format(placeholder="v")), wv),
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
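+            # e.g. an illustrative config value of [24, 20, 20] would be padded to [24, 20, 20, 0]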
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.GPT2