@@ -49,7 +49,6 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
@@ -147,7 +122,7 @@ def __init__(
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -427,6 +402,10 @@ def load_weights(self,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)