@@ -871,44 +871,31 @@ def get_vllm_state_dict(
871871
872872def _extract_short_conv_layer (short_conv , conv_prefix , state_dict , quant_state_dict , get_state_dict ):
873873 """
874- Extracts LFM2 hybrid short convolution layers (non-attention layers).
875-
876- vLLM ShortConv: in_proj (MergedColumnParallelLinear, splits into B, C, x),
877- out_proj (RowParallelLinear),
878- conv (ColumnParallelLinear, weight unsqueezed for conv1d)
879-
880- HF Lfm2ShortConv: in_proj (nn.Linear, hidden -> 3*hidden),
881- out_proj (nn.Linear, hidden -> hidden),
882- conv (nn.Conv1d, depthwise with groups=hidden_size)
874+ Extracts LFM2 short convolution layer weights from vLLM to HF format.
875+ in_proj is extracted directly from base layer since get_state_dict
876+ incorrectly handles MergedColumnParallelLinear's 3-way split.
883877 """
884- # in_proj: vLLM uses MergedColumnParallelLinear with 3 output parts (B, C, x).
885- # HF stores as a single nn.Linear(hidden, 3*hidden). We must extract the full
886- # merged weight directly instead of going through get_state_dict which may
887- # incorrectly slice or return only one shard.
878+ # in_proj: extract full merged weight directly
888879 in_proj = getattr (short_conv .in_proj , "base_layer" , short_conv .in_proj )
889880 in_proj_weight = in_proj .weight
890881 in_proj_weight .requires_grad_ (False )
891882 state_dict [f"{ conv_prefix } .in_proj.weight" ] = in_proj_weight
892883 quant_state_dict [f"{ conv_prefix } .in_proj.weight" ] = in_proj_weight
893- # Handle in_proj bias if present
894884 in_proj_bias = getattr (in_proj , "bias" , None )
895885 if in_proj_bias is not None :
896886 in_proj_bias .requires_grad_ (False )
897887 state_dict [f"{ conv_prefix } .in_proj.bias" ] = in_proj_bias
898888 quant_state_dict [f"{ conv_prefix } .in_proj.bias" ] = in_proj_bias
899889
900- # out_proj: direct mapping
890+ # out_proj
901891 get_state_dict (f"{ conv_prefix } .out_proj" , 0 , state_dict , short_conv .out_proj )
902892
903- # conv: vLLM stores as ColumnParallelLinear with weight shape (out, 1, kernel)
904- # HF expects nn.Conv1d weight with same shape (hidden, 1, kernel) for depthwise conv
893+ # conv (nn.Conv1d weight)
905894 conv_module = short_conv .conv
906895 conv_weight = getattr (conv_module , "base_layer" , conv_module ).weight
907896 conv_weight .requires_grad_ (False )
908897 state_dict [f"{ conv_prefix } .conv.weight" ] = conv_weight
909898 quant_state_dict [f"{ conv_prefix } .conv.weight" ] = conv_weight
910-
911- # Handle conv bias if present
912899 conv_bias = getattr (conv_module , "bias" , None )
913900 if conv_bias is None and hasattr (conv_module , "base_layer" ):
914901 conv_bias = getattr (conv_module .base_layer , "bias" , None )
@@ -919,6 +906,50 @@ def _extract_short_conv_layer(short_conv, conv_prefix, state_dict, quant_state_d
919906 pass
920907
921908
909+ def _extract_lfm2_layer (layer , kk , prefix , state_dict , quant_state_dict , get_state_dict ):
910+ """Extracts all components of a single LFM2 hybrid layer."""
911+ layer_prefix = f"{ prefix } .layers.{ kk } "
912+
913+ # Attention or short_conv
914+ if hasattr (layer , "self_attn" ):
915+ attn_prefix = f"{ layer_prefix } .self_attn"
916+ qkv_proj = layer .self_attn .qkv_proj
917+ get_state_dict (f"{ attn_prefix } .q_proj" , 0 , state_dict , qkv_proj )
918+ get_state_dict (f"{ attn_prefix } .k_proj" , 1 , state_dict , qkv_proj )
919+ get_state_dict (f"{ attn_prefix } .v_proj" , 2 , state_dict , qkv_proj )
920+ get_state_dict (f"{ attn_prefix } .out_proj" , 0 , state_dict , layer .self_attn .out_proj )
921+ elif hasattr (layer , "short_conv" ):
922+ _extract_short_conv_layer (layer .short_conv , f"{ layer_prefix } .conv" ,
923+ state_dict , quant_state_dict , get_state_dict )
924+
925+ # Feed-forward (w1/w3 merged in vLLM, w2 separate)
926+ if hasattr (layer , "feed_forward" ):
927+ ff_prefix = f"{ layer_prefix } .feed_forward"
928+ w1_proj = layer .feed_forward .w1
929+ get_state_dict (f"{ ff_prefix } .w1" , 0 , state_dict , w1_proj )
930+ get_state_dict (f"{ ff_prefix } .w3" , 1 , state_dict , w1_proj )
931+ get_state_dict (f"{ ff_prefix } .w2" , 0 , state_dict , layer .feed_forward .w2 )
932+
933+ # Layer norms
934+ lfm2_norms = [
935+ ("operator_norm" , f"{ layer_prefix } .operator_norm" ),
936+ ("ffn_norm" , f"{ layer_prefix } .ffn_norm" ),
937+ ("self_attn.q_layernorm" , f"{ layer_prefix } .self_attn.q_layernorm" ),
938+ ("self_attn.k_layernorm" , f"{ layer_prefix } .self_attn.k_layernorm" ),
939+ ]
940+ for attr_path , norm_key in lfm2_norms :
941+ obj = layer
942+ for part in attr_path .split ("." ):
943+ obj = getattr (obj , part , None )
944+ if obj is None :
945+ break
946+ if obj is not None and hasattr (obj , "weight" ):
947+ w = obj .weight .data
948+ state_dict [f"{ norm_key } .weight" ] = w
949+ quant_state_dict [f"{ norm_key } .weight" ] = w
950+ pass
951+
952+
922953def _get_vllm_state_dict (llm , return_state_dict = False , config = None , is_vision_model = False ):
923954 # All Unsloth Zoo code licensed under LGPLv3
924955 # Unmerges vLLM modules and returns HF equivalent state_dict
@@ -1139,57 +1170,45 @@ def _is_fused_module(name: str) -> bool:
11391170 packed = packed_modules_mapping .get (name )
11401171 return isinstance (packed , (list , tuple )) and len (packed ) == 1 and packed [0 ] == name
11411172
1142- # All layers
11431173 skipped_layernorms = []
1144- for kk in range (len (vllm_text_model .layers )):
1145- layer = vllm_text_model .layers [kk ]
1146- if hasattr (layer , "self_attn" ):
1147- prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .self_attn"
1148- qkv_proj = layer .self_attn .qkv_proj
1149-
1150- use_fused_qkv = _is_fused_module ("qkv_proj" )
1151- if use_fused_qkv :
1152- # For some model types like phi3 vllm will expect fused qkv (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1153- # so we should not split them here otherwise there will be a size mismatch when activating the adapter
1154- # see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1155- get_state_dict (f"{ prefix } .qkv_proj" , 0 , state_dict , qkv_proj , slice_weights = False )
1156- else :
1157- get_state_dict (f"{ prefix } .q_proj" , 0 , state_dict , qkv_proj )
1158- get_state_dict (f"{ prefix } .k_proj" , 1 , state_dict , qkv_proj )
1159- get_state_dict (f"{ prefix } .v_proj" , 2 , state_dict , qkv_proj )
1160-
1161- # Extract o_proj or out_proj depending on model architecture
1162- # LFM2 uses out_proj, most other models use o_proj
1163- if hasattr (layer .self_attn , "o_proj" ):
1174+ if model_type == "lfm2" :
1175+ for kk in range (len (vllm_text_model .layers )):
1176+ layer = vllm_text_model .layers [kk ]
1177+ _extract_lfm2_layer (layer , kk , vllm_text_model_prefix ,
1178+ state_dict , quant_state_dict , get_state_dict )
1179+ pass
1180+ else :
1181+ for kk in range (len (vllm_text_model .layers )):
1182+ layer = vllm_text_model .layers [kk ]
1183+ if hasattr (layer , "self_attn" ):
1184+ prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .self_attn"
1185+ qkv_proj = layer .self_attn .qkv_proj
11641186 o_proj = layer .self_attn .o_proj
1165- get_state_dict (f"{ prefix } .o_proj" , 0 , state_dict , o_proj )
1166- elif hasattr (layer .self_attn , "out_proj" ):
1167- out_proj = layer .self_attn .out_proj
1168- get_state_dict (f"{ prefix } .out_proj" , 0 , state_dict , out_proj )
1169-
1170- elif hasattr (layer , "cross_attn" ):
1171- prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .cross_attn"
1172- qkv_proj = layer .cross_attn .qkv_proj
1173- o_proj = layer .cross_attn .o_proj
1174- name = re .sub (r"\.(\d+)\." , r"[\1]." , prefix .replace ('model.language_model' ,'language_model.model' , 1 ) + ".qkv_proj" )
1175- cross_attn_layer = eval (f'vllm_internals.{ name } ' )
1176- q_proj = cross_attn_layer .proj ['q_proj_decoder' ]
1177- kv_proj = cross_attn_layer .proj ['kv_proj_encoder' ]
1178- get_state_dict (f"{ prefix } .q_proj" , 0 , state_dict , q_proj )
1179- get_state_dict (f"{ prefix } .k_proj" , 1 , state_dict , kv_proj )
1180- get_state_dict (f"{ prefix } .v_proj" , 2 , state_dict , kv_proj )
1181- get_state_dict (f"{ prefix } .o_proj" , 0 , state_dict , o_proj )
11821187
1183- elif hasattr (layer , "short_conv" ):
1184- # LFM2 hybrid short convolution layers (non-attention layers)
1185- conv_prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .conv"
1186- _extract_short_conv_layer (layer .short_conv , conv_prefix , state_dict , quant_state_dict , get_state_dict )
1187- pass
1188+ use_fused_qkv = _is_fused_module ("qkv_proj" )
1189+ if use_fused_qkv :
1190+ # For some model types like phi3 vllm will expect fused qkv (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1191+ # so we should not split them here otherwise there will be a size mismatch when activating the adapter
1192+ # see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1193+ get_state_dict (f"{ prefix } .qkv_proj" , 0 , state_dict , qkv_proj , slice_weights = False )
1194+ else :
1195+ get_state_dict (f"{ prefix } .q_proj" , 0 , state_dict , qkv_proj )
1196+ get_state_dict (f"{ prefix } .k_proj" , 1 , state_dict , qkv_proj )
1197+ get_state_dict (f"{ prefix } .v_proj" , 2 , state_dict , qkv_proj )
1198+ elif hasattr (layer , "cross_attn" ):
1199+ prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .cross_attn"
1200+ qkv_proj = layer .cross_attn .qkv_proj
1201+ o_proj = layer .cross_attn .o_proj
1202+ name = re .sub (r"\.(\d+)\." , r"[\1]." , prefix .replace ('model.language_model' ,'language_model.model' , 1 ) + ".qkv_proj" )
1203+ cross_attn_layer = eval (f'vllm_internals.{ name } ' )
1204+ q_proj = cross_attn_layer .proj ['q_proj_decoder' ]
1205+ kv_proj = cross_attn_layer .proj ['kv_proj_encoder' ]
1206+ get_state_dict (f"{ prefix } .q_proj" , 0 , state_dict , q_proj )
1207+ get_state_dict (f"{ prefix } .k_proj" , 1 , state_dict , kv_proj )
1208+ get_state_dict (f"{ prefix } .v_proj" , 2 , state_dict , kv_proj )
1209+
1210+ get_state_dict (f"{ prefix } .o_proj" , 0 , state_dict , o_proj )
11881211
1189- # MLP / Feed Forward extraction
1190- # LFM2 uses feed_forward with w1 (gate), w2 (down), w3 (up) — SwiGLU style
1191- # Standard models use mlp with gate_up_proj (merged), down_proj
1192- if hasattr (layer , "mlp" ):
11931212 proj = layer .mlp .gate_up_proj
11941213 use_fused_gate_up = _is_fused_module ("gate_up_proj" )
11951214 if use_fused_gate_up :
@@ -1200,33 +1219,23 @@ def _is_fused_module(name: str) -> bool:
12001219 else :
12011220 get_state_dict (f"{ vllm_text_model_prefix } .layers.{ kk } .mlp.gate_proj" , 0 , state_dict , proj )
12021221 get_state_dict (f"{ vllm_text_model_prefix } .layers.{ kk } .mlp.up_proj" , 1 , state_dict , proj )
1222+
12031223 proj = layer .mlp .down_proj
12041224 get_state_dict (f"{ vllm_text_model_prefix } .layers.{ kk } .mlp.down_proj" , 0 , state_dict , proj )
12051225
1206- elif hasattr (layer , "feed_forward" ):
1207- # LFM2 uses feed_forward with w1 (gate), w3 (up), w2 (down)
1208- # In vLLM, w1 and w3 are merged into a single MergedColumnParallelLinear
1209- ff_prefix = f"{ vllm_text_model_prefix } .layers.{ kk } .feed_forward"
1210- w1_proj = layer .feed_forward .w1 # MergedColumnParallelLinear (w1 at index 0, w3 at index 1)
1211- get_state_dict (f"{ ff_prefix } .w1" , 0 , state_dict , w1_proj )
1212- get_state_dict (f"{ ff_prefix } .w3" , 1 , state_dict , w1_proj )
1213-
1214- w2_proj = layer .feed_forward .w2 # RowParallelLinear
1215- get_state_dict (f"{ ff_prefix } .w2" , 0 , state_dict , w2_proj )
1216- pass
1217-
1218- # Use layernorms from the layer configuration
1219- layernorm_names = [name .format (kk = kk ) for name in layer_config ['layernorms' ]]
1226+ # Use layernorms from the layer configuration
1227+ layernorm_names = [name .format (kk = kk ) for name in layer_config ['layernorms' ]]
12201228
1221- for layernorm_name in layernorm_names :
1222- vllm_name = layernorm_name .replace (f".{ kk } ." , f"[{ kk } ]." ).replace (vllm_text_model_prefix , "vllm_text_model" )
1223- try :
1224- layernorm = eval (vllm_name ).state_dict ()["weight" ]
1225- layernorm_name = f"{ layernorm_name } .weight"
1226- state_dict [layernorm_name ] = layernorm
1227- quant_state_dict [layernorm_name ] = state_dict [layernorm_name ]
1228- except Exception as e :
1229- skipped_layernorms .append (layernorm_name .split ("." )[- 1 ])
1229+ for layernorm_name in layernorm_names :
1230+ vllm_name = layernorm_name .replace (f".{ kk } ." , f"[{ kk } ]." ).replace (vllm_text_model_prefix , "vllm_text_model" )
1231+ try :
1232+ layernorm = eval (vllm_name ).state_dict ()["weight" ]
1233+ layernorm_name = f"{ layernorm_name } .weight"
1234+ state_dict [layernorm_name ] = layernorm
1235+ quant_state_dict [layernorm_name ] = state_dict [layernorm_name ]
1236+ except Exception as e :
1237+ skipped_layernorms .append (layernorm_name .split ("." )[- 1 ])
1238+ pass
12301239 pass
12311240 pass
12321241
@@ -1237,18 +1246,14 @@ def _is_fused_module(name: str) -> bool:
12371246 if is_vision_model :
12381247 # Handle vision-specific layers using dedicated functions
12391248 extract_vision_layers (vllm_internals , state_dict , quant_state_dict , get_state_dict )
1240- # Norm
1241- # For Gemma3 and similar multimodal models, norm should be under model.norm
1242- # For standard models, also under model.norm
1243- # LFM2 uses embedding_norm instead of norm
1244- if hasattr (vllm_text_model , "norm" ):
1245- norm_prefix = f"{ vllm_text_model_prefix } .norm.weight"
1246- state_dict [norm_prefix ] = vllm_text_model .norm .weight .data
1247- quant_state_dict [norm_prefix ] = state_dict [norm_prefix ]
1248- elif hasattr (vllm_text_model , "embedding_norm" ):
1249+ if model_type == "lfm2" :
12491250 norm_prefix = f"{ vllm_text_model_prefix } .embedding_norm.weight"
12501251 state_dict [norm_prefix ] = vllm_text_model .embedding_norm .weight .data
12511252 quant_state_dict [norm_prefix ] = state_dict [norm_prefix ]
1253+ elif hasattr (vllm_text_model , "norm" ):
1254+ norm_prefix = f"{ vllm_text_model_prefix } .norm.weight"
1255+ state_dict [norm_prefix ] = vllm_text_model .norm .weight .data
1256+ quant_state_dict [norm_prefix ] = state_dict [norm_prefix ]
12521257
12531258 # LM Head - Use get_state_dict for consistency
12541259 if not getattr (text_config , "tie_word_embeddings" , False ):
0 commit comments