Commit be44450

[Fix][Spec Decode] Fix llama4 draft loading with different quantization (#27136)
Signed-off-by: linzebing <[email protected]>
1 parent f381cf2 commit be44450

File tree

1 file changed: +17 −10 lines


vllm/model_executor/models/llama4_eagle.py

Lines changed: 17 additions & 10 deletions
@@ -60,16 +60,23 @@ def __init__(
             prefix=maybe_prefix(prefix, "embed_tokens"),
         )
 
-        self.layers = nn.ModuleList(
-            [
-                Llama4DecoderLayer(
-                    vllm_config=vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                    config=self.config,
-                )
-                for i in range(self.config.num_hidden_layers)
-            ]
-        )
+        # Temporarily modify vllm_config.quant_config for draft model layers
+        original_quant_config = vllm_config.quant_config
+        vllm_config.quant_config = quant_config
+        try:
+            self.layers = nn.ModuleList(
+                [
+                    Llama4DecoderLayer(
+                        vllm_config=vllm_config,
+                        prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                        config=self.config,
+                    )
+                    for i in range(self.config.num_hidden_layers)
+                ]
+            )
+        finally:
+            # Restore original quant_config
+            vllm_config.quant_config = original_quant_config
 
         self.fc = torch.nn.Linear(
             self.config.hidden_size * 2, self.config.hidden_size, bias=False
        )
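
For readers unfamiliar with the pattern: the change temporarily swaps vllm_config.quant_config to the draft model's quant_config while the draft decoder layers are constructed, then restores the original in a finally block so the target model's quantization settings are left untouched even if layer construction raises. Below is a minimal, self-contained sketch of that swap-and-restore pattern; the override_attr helper and VllmConfigStub class are illustrative stand-ins, not vLLM APIs.

from contextlib import contextmanager

@contextmanager
def override_attr(obj, name, value):
    # Temporarily set obj.<name> to value; restore the original even on error.
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, original)

class VllmConfigStub:
    # Stand-in for vllm_config; only the attribute being swapped matters here.
    def __init__(self, quant_config):
        self.quant_config = quant_config

vllm_config = VllmConfigStub(quant_config="target-quant")   # target model quantization
draft_quant_config = "draft-quant"                           # draft model quantization

with override_attr(vllm_config, "quant_config", draft_quant_config):
    # Draft layers built inside this block would see the draft quant_config.
    assert vllm_config.quant_config == "draft-quant"

# The original quantization config is restored afterwards.
assert vllm_config.quant_config == "target-quant"

The diff inlines this with an explicit try/finally instead of a context manager, which keeps the fix local to the draft model's constructor.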
