Skip to content

Commit 6766c4e

Browse files
committed
refactor(lfm2): extract LFM2 layer handling into dedicated function
- Move all LFM2-specific logic into _extract_lfm2_layer() function
- Use outer if/else dispatch by model_type instead of per-layer checks
- Standard model loop is now completely free of LFM2 conditionals
- Also fixes o_proj/mlp being outside self_attn/cross_attn guards in main
1 parent 9d09a2c commit 6766c4e

File tree

1 file changed

+103
-98
lines changed

1 file changed

+103
-98
lines changed

unsloth_zoo/vllm_utils.py

Lines changed: 103 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -871,44 +871,31 @@ def get_vllm_state_dict(
871871

872872
def _extract_short_conv_layer(short_conv, conv_prefix, state_dict, quant_state_dict, get_state_dict):
873873
"""
874-
Extracts LFM2 hybrid short convolution layers (non-attention layers).
875-
876-
vLLM ShortConv: in_proj (MergedColumnParallelLinear, splits into B, C, x),
877-
out_proj (RowParallelLinear),
878-
conv (ColumnParallelLinear, weight unsqueezed for conv1d)
879-
880-
HF Lfm2ShortConv: in_proj (nn.Linear, hidden -> 3*hidden),
881-
out_proj (nn.Linear, hidden -> hidden),
882-
conv (nn.Conv1d, depthwise with groups=hidden_size)
874+
Extracts LFM2 short convolution layer weights from vLLM to HF format.
875+
in_proj is extracted directly from base layer since get_state_dict
876+
incorrectly handles MergedColumnParallelLinear's 3-way split.
883877
"""
884-
# in_proj: vLLM uses MergedColumnParallelLinear with 3 output parts (B, C, x).
885-
# HF stores as a single nn.Linear(hidden, 3*hidden). We must extract the full
886-
# merged weight directly instead of going through get_state_dict which may
887-
# incorrectly slice or return only one shard.
878+
# in_proj: extract full merged weight directly
888879
in_proj = getattr(short_conv.in_proj, "base_layer", short_conv.in_proj)
889880
in_proj_weight = in_proj.weight
890881
in_proj_weight.requires_grad_(False)
891882
state_dict[f"{conv_prefix}.in_proj.weight"] = in_proj_weight
892883
quant_state_dict[f"{conv_prefix}.in_proj.weight"] = in_proj_weight
893-
# Handle in_proj bias if present
894884
in_proj_bias = getattr(in_proj, "bias", None)
895885
if in_proj_bias is not None:
896886
in_proj_bias.requires_grad_(False)
897887
state_dict[f"{conv_prefix}.in_proj.bias"] = in_proj_bias
898888
quant_state_dict[f"{conv_prefix}.in_proj.bias"] = in_proj_bias
899889

900-
# out_proj: direct mapping
890+
# out_proj
901891
get_state_dict(f"{conv_prefix}.out_proj", 0, state_dict, short_conv.out_proj)
902892

903-
# conv: vLLM stores as ColumnParallelLinear with weight shape (out, 1, kernel)
904-
# HF expects nn.Conv1d weight with same shape (hidden, 1, kernel) for depthwise conv
893+
# conv (nn.Conv1d weight)
905894
conv_module = short_conv.conv
906895
conv_weight = getattr(conv_module, "base_layer", conv_module).weight
907896
conv_weight.requires_grad_(False)
908897
state_dict[f"{conv_prefix}.conv.weight"] = conv_weight
909898
quant_state_dict[f"{conv_prefix}.conv.weight"] = conv_weight
910-
911-
# Handle conv bias if present
912899
conv_bias = getattr(conv_module, "bias", None)
913900
if conv_bias is None and hasattr(conv_module, "base_layer"):
914901
conv_bias = getattr(conv_module.base_layer, "bias", None)
@@ -919,6 +906,50 @@ def _extract_short_conv_layer(short_conv, conv_prefix, state_dict, quant_state_d
919906
pass
920907

921908

909+
def _extract_lfm2_layer(layer, kk, prefix, state_dict, quant_state_dict, get_state_dict):
910+
"""Extracts all components of a single LFM2 hybrid layer."""
911+
layer_prefix = f"{prefix}.layers.{kk}"
912+
913+
# Attention or short_conv
914+
if hasattr(layer, "self_attn"):
915+
attn_prefix = f"{layer_prefix}.self_attn"
916+
qkv_proj = layer.self_attn.qkv_proj
917+
get_state_dict(f"{attn_prefix}.q_proj", 0, state_dict, qkv_proj)
918+
get_state_dict(f"{attn_prefix}.k_proj", 1, state_dict, qkv_proj)
919+
get_state_dict(f"{attn_prefix}.v_proj", 2, state_dict, qkv_proj)
920+
get_state_dict(f"{attn_prefix}.out_proj", 0, state_dict, layer.self_attn.out_proj)
921+
elif hasattr(layer, "short_conv"):
922+
_extract_short_conv_layer(layer.short_conv, f"{layer_prefix}.conv",
923+
state_dict, quant_state_dict, get_state_dict)
924+
925+
# Feed-forward (w1/w3 merged in vLLM, w2 separate)
926+
if hasattr(layer, "feed_forward"):
927+
ff_prefix = f"{layer_prefix}.feed_forward"
928+
w1_proj = layer.feed_forward.w1
929+
get_state_dict(f"{ff_prefix}.w1", 0, state_dict, w1_proj)
930+
get_state_dict(f"{ff_prefix}.w3", 1, state_dict, w1_proj)
931+
get_state_dict(f"{ff_prefix}.w2", 0, state_dict, layer.feed_forward.w2)
932+
933+
# Layer norms
934+
lfm2_norms = [
935+
("operator_norm", f"{layer_prefix}.operator_norm"),
936+
("ffn_norm", f"{layer_prefix}.ffn_norm"),
937+
("self_attn.q_layernorm", f"{layer_prefix}.self_attn.q_layernorm"),
938+
("self_attn.k_layernorm", f"{layer_prefix}.self_attn.k_layernorm"),
939+
]
940+
for attr_path, norm_key in lfm2_norms:
941+
obj = layer
942+
for part in attr_path.split("."):
943+
obj = getattr(obj, part, None)
944+
if obj is None:
945+
break
946+
if obj is not None and hasattr(obj, "weight"):
947+
w = obj.weight.data
948+
state_dict[f"{norm_key}.weight"] = w
949+
quant_state_dict[f"{norm_key}.weight"] = w
950+
pass
951+
952+
922953
def _get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False):
923954
# All Unsloth Zoo code licensed under LGPLv3
924955
# Unmerges vLLM modules and returns HF equivalent state_dict
@@ -1139,57 +1170,45 @@ def _is_fused_module(name: str) -> bool:
11391170
packed = packed_modules_mapping.get(name)
11401171
return isinstance(packed, (list, tuple)) and len(packed) == 1 and packed[0] == name
11411172

1142-
# All layers
11431173
skipped_layernorms = []
1144-
for kk in range(len(vllm_text_model.layers)):
1145-
layer = vllm_text_model.layers[kk]
1146-
if hasattr(layer, "self_attn"):
1147-
prefix = f"{vllm_text_model_prefix}.layers.{kk}.self_attn"
1148-
qkv_proj = layer.self_attn.qkv_proj
1149-
1150-
use_fused_qkv = _is_fused_module("qkv_proj")
1151-
if use_fused_qkv:
1152-
# For some model types like phi3 vllm will expect fused qkv (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1153-
# so we should not split them here otherwise there will be a size mismatch when activating the adapter
1154-
# see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1155-
get_state_dict(f"{prefix}.qkv_proj", 0, state_dict, qkv_proj, slice_weights=False)
1156-
else:
1157-
get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj)
1158-
get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj)
1159-
get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
1160-
1161-
# Extract o_proj or out_proj depending on model architecture
1162-
# LFM2 uses out_proj, most other models use o_proj
1163-
if hasattr(layer.self_attn, "o_proj"):
1174+
if model_type == "lfm2":
1175+
for kk in range(len(vllm_text_model.layers)):
1176+
layer = vllm_text_model.layers[kk]
1177+
_extract_lfm2_layer(layer, kk, vllm_text_model_prefix,
1178+
state_dict, quant_state_dict, get_state_dict)
1179+
pass
1180+
else:
1181+
for kk in range(len(vllm_text_model.layers)):
1182+
layer = vllm_text_model.layers[kk]
1183+
if hasattr(layer, "self_attn"):
1184+
prefix = f"{vllm_text_model_prefix}.layers.{kk}.self_attn"
1185+
qkv_proj = layer.self_attn.qkv_proj
11641186
o_proj = layer.self_attn.o_proj
1165-
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
1166-
elif hasattr(layer.self_attn, "out_proj"):
1167-
out_proj = layer.self_attn.out_proj
1168-
get_state_dict(f"{prefix}.out_proj", 0, state_dict, out_proj)
1169-
1170-
elif hasattr(layer, "cross_attn"):
1171-
prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn"
1172-
qkv_proj = layer.cross_attn.qkv_proj
1173-
o_proj = layer.cross_attn.o_proj
1174-
name = re.sub(r"\.(\d+)\.", r"[\1].", prefix.replace('model.language_model','language_model.model', 1) + ".qkv_proj")
1175-
cross_attn_layer = eval(f'vllm_internals.{name}')
1176-
q_proj = cross_attn_layer.proj['q_proj_decoder']
1177-
kv_proj = cross_attn_layer.proj['kv_proj_encoder']
1178-
get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj)
1179-
get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj)
1180-
get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj)
1181-
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
11821187

1183-
elif hasattr(layer, "short_conv"):
1184-
# LFM2 hybrid short convolution layers (non-attention layers)
1185-
conv_prefix = f"{vllm_text_model_prefix}.layers.{kk}.conv"
1186-
_extract_short_conv_layer(layer.short_conv, conv_prefix, state_dict, quant_state_dict, get_state_dict)
1187-
pass
1188+
use_fused_qkv = _is_fused_module("qkv_proj")
1189+
if use_fused_qkv:
1190+
# For some model types like phi3 vllm will expect fused qkv (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1191+
# so we should not split them here otherwise there will be a size mismatch when activating the adapter
1192+
# see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1193+
get_state_dict(f"{prefix}.qkv_proj", 0, state_dict, qkv_proj, slice_weights=False)
1194+
else:
1195+
get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj)
1196+
get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj)
1197+
get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
1198+
elif hasattr(layer, "cross_attn"):
1199+
prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn"
1200+
qkv_proj = layer.cross_attn.qkv_proj
1201+
o_proj = layer.cross_attn.o_proj
1202+
name = re.sub(r"\.(\d+)\.", r"[\1].", prefix.replace('model.language_model','language_model.model', 1) + ".qkv_proj")
1203+
cross_attn_layer = eval(f'vllm_internals.{name}')
1204+
q_proj = cross_attn_layer.proj['q_proj_decoder']
1205+
kv_proj = cross_attn_layer.proj['kv_proj_encoder']
1206+
get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj)
1207+
get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj)
1208+
get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj)
1209+
1210+
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
11881211

1189-
# MLP / Feed Forward extraction
1190-
# LFM2 uses feed_forward with w1 (gate), w2 (down), w3 (up) — SwiGLU style
1191-
# Standard models use mlp with gate_up_proj (merged), down_proj
1192-
if hasattr(layer, "mlp"):
11931212
proj = layer.mlp.gate_up_proj
11941213
use_fused_gate_up = _is_fused_module("gate_up_proj")
11951214
if use_fused_gate_up:
@@ -1200,33 +1219,23 @@ def _is_fused_module(name: str) -> bool:
12001219
else:
12011220
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_proj", 0, state_dict, proj)
12021221
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.up_proj", 1, state_dict, proj)
1222+
12031223
proj = layer.mlp.down_proj
12041224
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.down_proj", 0, state_dict, proj)
12051225

1206-
elif hasattr(layer, "feed_forward"):
1207-
# LFM2 uses feed_forward with w1 (gate), w3 (up), w2 (down)
1208-
# In vLLM, w1 and w3 are merged into a single MergedColumnParallelLinear
1209-
ff_prefix = f"{vllm_text_model_prefix}.layers.{kk}.feed_forward"
1210-
w1_proj = layer.feed_forward.w1 # MergedColumnParallelLinear (w1 at index 0, w3 at index 1)
1211-
get_state_dict(f"{ff_prefix}.w1", 0, state_dict, w1_proj)
1212-
get_state_dict(f"{ff_prefix}.w3", 1, state_dict, w1_proj)
1213-
1214-
w2_proj = layer.feed_forward.w2 # RowParallelLinear
1215-
get_state_dict(f"{ff_prefix}.w2", 0, state_dict, w2_proj)
1216-
pass
1217-
1218-
# Use layernorms from the layer configuration
1219-
layernorm_names = [name.format(kk=kk) for name in layer_config['layernorms']]
1226+
# Use layernorms from the layer configuration
1227+
layernorm_names = [name.format(kk=kk) for name in layer_config['layernorms']]
12201228

1221-
for layernorm_name in layernorm_names:
1222-
vllm_name = layernorm_name.replace(f".{kk}.", f"[{kk}].").replace(vllm_text_model_prefix, "vllm_text_model")
1223-
try:
1224-
layernorm = eval(vllm_name).state_dict()["weight"]
1225-
layernorm_name = f"{layernorm_name}.weight"
1226-
state_dict[layernorm_name] = layernorm
1227-
quant_state_dict[layernorm_name] = state_dict[layernorm_name]
1228-
except Exception as e:
1229-
skipped_layernorms.append(layernorm_name.split(".")[-1])
1229+
for layernorm_name in layernorm_names:
1230+
vllm_name = layernorm_name.replace(f".{kk}.", f"[{kk}].").replace(vllm_text_model_prefix, "vllm_text_model")
1231+
try:
1232+
layernorm = eval(vllm_name).state_dict()["weight"]
1233+
layernorm_name = f"{layernorm_name}.weight"
1234+
state_dict[layernorm_name] = layernorm
1235+
quant_state_dict[layernorm_name] = state_dict[layernorm_name]
1236+
except Exception as e:
1237+
skipped_layernorms.append(layernorm_name.split(".")[-1])
1238+
pass
12301239
pass
12311240
pass
12321241

@@ -1237,18 +1246,14 @@ def _is_fused_module(name: str) -> bool:
12371246
if is_vision_model:
12381247
# Handle vision-specific layers using dedicated functions
12391248
extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_state_dict)
1240-
# Norm
1241-
# For Gemma3 and similar multimodal models, norm should be under model.norm
1242-
# For standard models, also under model.norm
1243-
# LFM2 uses embedding_norm instead of norm
1244-
if hasattr(vllm_text_model, "norm"):
1245-
norm_prefix = f"{vllm_text_model_prefix}.norm.weight"
1246-
state_dict[norm_prefix] = vllm_text_model.norm.weight.data
1247-
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
1248-
elif hasattr(vllm_text_model, "embedding_norm"):
1249+
if model_type == "lfm2":
12491250
norm_prefix = f"{vllm_text_model_prefix}.embedding_norm.weight"
12501251
state_dict[norm_prefix] = vllm_text_model.embedding_norm.weight.data
12511252
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
1253+
elif hasattr(vllm_text_model, "norm"):
1254+
norm_prefix = f"{vllm_text_model_prefix}.norm.weight"
1255+
state_dict[norm_prefix] = vllm_text_model.norm.weight.data
1256+
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
12521257

12531258
# LM Head - Use get_state_dict for consistency
12541259
if not getattr(text_config, "tie_word_embeddings", False):

0 commit comments

Comments (0)