Skip to content

Commit 4ff0203

Browse files
authored
Minor fixes for Mixtral (#2015)
1 parent b5f882c commit 4ff0203

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

docs/source/models/supported_models.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,9 @@ Alongside each architecture, we include some popular models that use it.
5050
* - :code:`MistralForCausalLM`
5151
- Mistral, Mistral-Instruct
5252
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
53+
* - :code:`MixtralForCausalLM`
54+
- Mixtral-8x7B, Mixtral-8x7B-Instruct
55+
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
5356
* - :code:`MPTForCausalLM`
5457
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
5558
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.

vllm/model_executor/models/mixtral.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
# See the License for the specific language governing permissions and
2222
# limitations under the License.
2323
"""Inference-only Mixtral model."""
24-
from typing import List, Optional, Tuple, Union
24+
from typing import List, Optional, Tuple
2525

2626
import numpy as np
2727

@@ -453,10 +453,6 @@ def __init__(
453453
assert linear_method is None
454454
self.padding_idx = config.pad_token_id
455455
self.vocab_size = config.vocab_size
456-
self.tok_embeddings: Union[nn.Embedding, None] = None
457-
self.layers: nn.ModuleList = None
458-
self.output: Union[nn.Linear, None] = None
459-
self.sampler: Union[Sampler, None] = None
460456
self.tok_embeddings = VocabParallelEmbedding(
461457
config.vocab_size,
462458
config.hidden_size,
@@ -492,14 +488,14 @@ def forward(
492488
input_metadata,
493489
cache_event,
494490
)
491+
hidden_states = self.norm(hidden_states)
495492
return hidden_states
496493

497494
def sample(
498495
self,
499496
hidden_states: Optional[torch.Tensor],
500497
sampling_metadata: SamplingMetadata,
501498
) -> SamplerOutput:
502-
hidden_states = self.norm(hidden_states)
503499
next_tokens = self.sampler(self.output.weight, hidden_states,
504500
sampling_metadata)
505501
return next_tokens

0 commit comments

Comments
 (0)