|
21 | 21 | # See the License for the specific language governing permissions and
|
22 | 22 | # limitations under the License.
|
23 | 23 | """Inference-only Mixtral model."""
|
24 |
| -from typing import List, Optional, Tuple, Union |
| 24 | +from typing import List, Optional, Tuple |
25 | 25 |
|
26 | 26 | import numpy as np
|
27 | 27 |
|
@@ -453,10 +453,6 @@ def __init__(
|
453 | 453 | assert linear_method is None
|
454 | 454 | self.padding_idx = config.pad_token_id
|
455 | 455 | self.vocab_size = config.vocab_size
|
456 |
| - self.tok_embeddings: Union[nn.Embedding, None] = None |
457 |
| - self.layers: nn.ModuleList = None |
458 |
| - self.output: Union[nn.Linear, None] = None |
459 |
| - self.sampler: Union[Sampler, None] = None |
460 | 456 | self.tok_embeddings = VocabParallelEmbedding(
|
461 | 457 | config.vocab_size,
|
462 | 458 | config.hidden_size,
|
@@ -492,14 +488,14 @@ def forward(
|
492 | 488 | input_metadata,
|
493 | 489 | cache_event,
|
494 | 490 | )
|
| 491 | + hidden_states = self.norm(hidden_states) |
495 | 492 | return hidden_states
|
496 | 493 |
|
497 | 494 | def sample(
|
498 | 495 | self,
|
499 | 496 | hidden_states: Optional[torch.Tensor],
|
500 | 497 | sampling_metadata: SamplingMetadata,
|
501 | 498 | ) -> SamplerOutput:
|
502 |
| - hidden_states = self.norm(hidden_states) |
503 | 499 | next_tokens = self.sampler(self.output.weight, hidden_states,
|
504 | 500 | sampling_metadata)
|
505 | 501 | return next_tokens
|
|
0 commit comments