Skip to content

Commit fe9a2a6

Browse files
committed
fix: handle LFM2/Mamba hybrid layers in _get_vllm_state_dict for fast_inference
Fix an `UnboundLocalError` when using `fast_inference=True` with LFM2/Mamba hybrid models (e.g. LiquidAI/LFM2.5-1.2B-Thinking). The crash occurred because `_get_vllm_state_dict` only handled `self_attn` and `cross_attn` layer types, leaving the `prefix` variable unset for short-convolution layers.

Changes:
- Add a `short_conv` branch to extract the conv `in_proj`, `out_proj`, and `conv` weights
- Move `o_proj` extraction inside the attention branches (LFM2 uses `out_proj`)
- Add `feed_forward.w1/w2/w3` MLP handling alongside the standard `mlp` path
- Handle `embedding_norm` (LFM2) in addition to `norm` for the final model norm
- Add LFM2 layer templates to `get_model_layer_config`
- Handle `embedding_norm` in `set_additional_modules` for HF reconstruction

Fixes #4073
1 parent b6dbba1 commit fe9a2a6

File tree

2 files changed

+112
-22
lines changed

2 files changed

+112
-22
lines changed

unsloth_zoo/empty_model.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,10 +442,17 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False):
442442
set_embedding(new_model.model.visual.pos_embed, 'model.visual.pos_embed.weight', None, requires_grad=False)
443443

444444
# Norm
445+
# LFM2 uses embedding_norm instead of norm
445446
norm_key = f"{language_model_prefix}.norm.weight"
446-
norm = quant_state_dict[norm_key]
447-
norm = torch.nn.Parameter(norm, requires_grad = False)
448-
language_model.norm.weight = norm
447+
embedding_norm_key = f"{language_model_prefix}.embedding_norm.weight"
448+
if norm_key in quant_state_dict:
449+
norm = quant_state_dict[norm_key]
450+
norm = torch.nn.Parameter(norm, requires_grad = False)
451+
language_model.norm.weight = norm
452+
elif embedding_norm_key in quant_state_dict:
453+
norm = quant_state_dict[embedding_norm_key]
454+
norm = torch.nn.Parameter(norm, requires_grad = False)
455+
language_model.embedding_norm.weight = norm
449456

450457
# LM Head. Do note that for some models, like Mistral3ForConditionalGeneration,
451458
# there can be mismatch in the value of tie_word_embeddings between config and text_config
@@ -539,6 +546,17 @@ def get_model_layer_config(return_non_layered=True):
539546
"model.layers.{kk}.mlp.up_proj",
540547
"model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture)
541548
"model.layers.{kk}.mlp.down_proj",
549+
550+
# LFM2 hybrid model layers (attention + short convolution)
551+
"model.layers.{kk}.self_attn.out_proj", # LFM2 attention uses out_proj instead of o_proj
552+
"model.layers.{kk}.self_attn.q_layernorm",
553+
"model.layers.{kk}.self_attn.k_layernorm",
554+
"model.layers.{kk}.conv.in_proj",
555+
"model.layers.{kk}.conv.out_proj",
556+
"model.layers.{kk}.conv.conv",
557+
"model.layers.{kk}.feed_forward.w1",
558+
"model.layers.{kk}.feed_forward.w3",
559+
"model.layers.{kk}.feed_forward.w2",
542560
},
543561
'layernorms': {
544562
"model.language_model.layers.{kk}.input_layernorm",
@@ -555,6 +573,13 @@ def get_model_layer_config(return_non_layered=True):
555573
"model.layers.{kk}.post_feedforward_layernorm",
556574
"model.layers.{kk}.self_attn.q_norm",
557575
"model.layers.{kk}.self_attn.k_norm",
576+
577+
# LFM2 hybrid model norms
578+
"model.layers.{kk}.operator_norm", # pre-block norm (replaces input_layernorm)
579+
"model.layers.{kk}.ffn_norm", # post-block norm (replaces post_attention_layernorm)
580+
"model.layers.{kk}.self_attn.q_layernorm", # QK norm inside attention
581+
"model.layers.{kk}.self_attn.k_layernorm", # QK norm inside attention
582+
558583
"model.visual.blocks.{kk}.norm1",
559584
"model.visual.blocks.{kk}.norm2",
560585
"model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm",

unsloth_zoo/vllm_utils.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,6 @@ def _is_fused_module(name: str) -> bool:
10951095
if hasattr(layer, "self_attn"):
10961096
prefix = f"{vllm_text_model_prefix}.layers.{kk}.self_attn"
10971097
qkv_proj = layer.self_attn.qkv_proj
1098-
o_proj = layer.self_attn.o_proj
10991098

11001099
use_fused_qkv = _is_fused_module("qkv_proj")
11011100
if use_fused_qkv:
@@ -1107,6 +1106,16 @@ def _is_fused_module(name: str) -> bool:
11071106
get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj)
11081107
get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj)
11091108
get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
1109+
1110+
# Extract o_proj or out_proj depending on model architecture
1111+
# LFM2 uses out_proj, most other models use o_proj
1112+
if hasattr(layer.self_attn, "o_proj"):
1113+
o_proj = layer.self_attn.o_proj
1114+
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
1115+
elif hasattr(layer.self_attn, "out_proj"):
1116+
out_proj = layer.self_attn.out_proj
1117+
get_state_dict(f"{prefix}.out_proj", 0, state_dict, out_proj)
1118+
11101119
elif hasattr(layer, "cross_attn"):
11111120
prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn"
11121121
qkv_proj = layer.cross_attn.qkv_proj
@@ -1118,22 +1127,72 @@ def _is_fused_module(name: str) -> bool:
11181127
get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj)
11191128
get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj)
11201129
get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj)
1130+
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
1131+
1132+
elif hasattr(layer, "short_conv"):
1133+
# LFM2 hybrid short convolution layers (non-attention layers)
1134+
# vLLM ShortConv: in_proj (MergedColumnParallelLinear, splits into B, C, x),
1135+
# out_proj (RowParallelLinear),
1136+
# conv (ColumnParallelLinear, weight unsqueezed for conv1d)
1137+
# HF Lfm2ShortConv: in_proj (nn.Linear, hidden -> 3*hidden),
1138+
# out_proj (nn.Linear, hidden -> hidden),
1139+
# conv (nn.Conv1d, depthwise with groups=hidden_size)
1140+
conv_prefix = f"{vllm_text_model_prefix}.layers.{kk}.conv"
1141+
short_conv = layer.short_conv
1142+
1143+
# in_proj: vLLM splits into 3 shards via MergedColumnParallelLinear,
1144+
# but HF stores as a single nn.Linear(hidden, 3*hidden)
1145+
get_state_dict(f"{conv_prefix}.in_proj", 0, state_dict, short_conv.in_proj, slice_weights=False)
1146+
1147+
# out_proj: direct mapping
1148+
get_state_dict(f"{conv_prefix}.out_proj", 0, state_dict, short_conv.out_proj)
1149+
1150+
# conv: vLLM stores as ColumnParallelLinear with weight shape (out, 1, kernel)
1151+
# HF expects nn.Conv1d weight with same shape (hidden, 1, kernel) for depthwise conv
1152+
conv_module = short_conv.conv
1153+
conv_weight = getattr(conv_module, "base_layer", conv_module).weight
1154+
conv_weight.requires_grad_(False)
1155+
state_dict[f"{conv_prefix}.conv.weight"] = conv_weight
1156+
quant_state_dict[f"{conv_prefix}.conv.weight"] = conv_weight
1157+
# Handle conv bias if present
1158+
conv_bias = getattr(conv_module, "bias", None)
1159+
if conv_bias is None:
1160+
conv_bias = getattr(getattr(conv_module, "base_layer", conv_module), "bias", None)
1161+
if conv_bias is not None:
1162+
conv_bias.requires_grad_(False)
1163+
state_dict[f"{conv_prefix}.conv.bias"] = conv_bias
1164+
quant_state_dict[f"{conv_prefix}.conv.bias"] = conv_bias
1165+
pass
11211166

1122-
get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj)
1123-
1124-
proj = layer.mlp.gate_up_proj
1125-
use_fused_gate_up = _is_fused_module("gate_up_proj")
1126-
if use_fused_gate_up:
1127-
# For some model types like phi3 vllm will expect fused gate_up_proj (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1128-
# so we should not split them here otherwise there will be a size mismatch when activating the adapter
1129-
# see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1130-
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_up_proj", 0, state_dict, proj, slice_weights=False)
1131-
else:
1132-
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_proj", 0, state_dict, proj)
1133-
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.up_proj", 1, state_dict, proj)
1134-
1135-
proj = layer.mlp.down_proj
1136-
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.down_proj", 0, state_dict, proj)
1167+
# MLP / Feed Forward extraction
1168+
# LFM2 uses feed_forward with w1 (gate), w2 (down), w3 (up) — SwiGLU style
1169+
# Standard models use mlp with gate_up_proj (merged), down_proj
1170+
if hasattr(layer, "mlp"):
1171+
proj = layer.mlp.gate_up_proj
1172+
use_fused_gate_up = _is_fused_module("gate_up_proj")
1173+
if use_fused_gate_up:
1174+
# For some model types like phi3 vllm will expect fused gate_up_proj (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct)
1175+
# so we should not split them here otherwise there will be a size mismatch when activating the adapter
1176+
# see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py
1177+
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_up_proj", 0, state_dict, proj, slice_weights=False)
1178+
else:
1179+
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_proj", 0, state_dict, proj)
1180+
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.up_proj", 1, state_dict, proj)
1181+
1182+
proj = layer.mlp.down_proj
1183+
get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.down_proj", 0, state_dict, proj)
1184+
1185+
elif hasattr(layer, "feed_forward"):
1186+
# LFM2 uses feed_forward with w1 (gate), w3 (up), w2 (down)
1187+
# In vLLM, w1 and w3 are merged into a single MergedColumnParallelLinear
1188+
ff_prefix = f"{vllm_text_model_prefix}.layers.{kk}.feed_forward"
1189+
w1_proj = layer.feed_forward.w1 # MergedColumnParallelLinear (w1 at index 0, w3 at index 1)
1190+
get_state_dict(f"{ff_prefix}.w1", 0, state_dict, w1_proj)
1191+
get_state_dict(f"{ff_prefix}.w3", 1, state_dict, w1_proj)
1192+
1193+
w2_proj = layer.feed_forward.w2 # RowParallelLinear
1194+
get_state_dict(f"{ff_prefix}.w2", 0, state_dict, w2_proj)
1195+
pass
11371196

11381197
# Use layernorms from the layer configuration
11391198
layernorm_names = [name.format(kk=kk) for name in layer_config['layernorms']]
@@ -1160,9 +1219,15 @@ def _is_fused_module(name: str) -> bool:
11601219
# Norm
11611220
# For Gemma3 and similar multimodal models, norm should be under model.norm
11621221
# For standard models, also under model.norm
1163-
norm_prefix = f"{vllm_text_model_prefix}.norm.weight"
1164-
state_dict[norm_prefix] = vllm_text_model.norm.weight.data
1165-
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
1222+
# LFM2 uses embedding_norm instead of norm
1223+
if hasattr(vllm_text_model, "norm"):
1224+
norm_prefix = f"{vllm_text_model_prefix}.norm.weight"
1225+
state_dict[norm_prefix] = vllm_text_model.norm.weight.data
1226+
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
1227+
elif hasattr(vllm_text_model, "embedding_norm"):
1228+
norm_prefix = f"{vllm_text_model_prefix}.embedding_norm.weight"
1229+
state_dict[norm_prefix] = vllm_text_model.embedding_norm.weight.data
1230+
quant_state_dict[norm_prefix] = state_dict[norm_prefix]
11661231

11671232
# LM Head - Use get_state_dict for consistency
11681233
if not getattr(text_config, "tie_word_embeddings", False):

0 commit comments

Comments
 (0)