
Commit ba4f826

[BugFix] Fix weight loading for Mixtral with TP (#2208)
1 parent de60a3f commit ba4f826

File tree

1 file changed: +5 -26 lines changed


vllm/model_executor/models/mixtral.py

Lines changed: 5 additions & 26 deletions
@@ -49,7 +49,6 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
@@ -147,7 +122,7 @@ def __init__(
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -427,6 +402,10 @@ def load_weights(self,
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            # Skip experts that are not assigned to this worker.
+            if ("block_sparse_moe.experts." in name
+                    and name not in params_dict):
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
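
The loading pattern this commit settles on can be illustrated with a minimal, self-contained sketch (toy names such as TinyMoE, NUM_EXPERTS, and fake_checkpoint are illustrative only, not vLLM's actual classes): experts not assigned to the current tensor-parallel worker are left as None, so their parameters never appear in the module's parameter dict, and the weight loader simply skips any checkpoint tensor whose name is missing from that dict.

    # Minimal sketch (assumed toy names, not vLLM code) of the pattern above:
    # unassigned experts are None, and load_weights skips their checkpoint tensors.
    import torch
    import torch.nn as nn

    NUM_EXPERTS = 4   # total experts in the checkpoint (illustrative)
    HIDDEN = 8        # tiny hidden size for the example


    class TinyMoE(nn.Module):

        def __init__(self, expert_indices):
            super().__init__()
            # Only the experts assigned to this worker are real modules;
            # the rest stay None, so they contribute no parameters.
            self.experts = nn.ModuleList([
                nn.Linear(HIDDEN, HIDDEN, bias=False)
                if idx in expert_indices else None
                for idx in range(NUM_EXPERTS)
            ])

        def load_weights(self, named_tensors):
            params_dict = dict(self.named_parameters())
            for name, loaded_weight in named_tensors:
                # Skip experts that are not assigned to this worker.
                if "experts." in name and name not in params_dict:
                    continue
                params_dict[name].data.copy_(loaded_weight)


    # Pretend this is TP rank 0 of 2 and it owns experts 0 and 1 only.
    model = TinyMoE(expert_indices={0, 1})

    # A fake "full" checkpoint that still contains all four experts.
    fake_checkpoint = [(f"experts.{i}.weight", torch.randn(HIDDEN, HIDDEN))
                       for i in range(NUM_EXPERTS)]

    model.load_weights(fake_checkpoint)  # experts 2 and 3 are skipped, no KeyError

Run as a single process, the four-expert checkpoint loads without a KeyError even though this worker only materializes experts 0 and 1, which mirrors why the DummyModule placeholder (and its no-op weight loader) is no longer needed.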
