@@ -899,16 +899,13 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
 @support_torch_compile
 class MiniMaxText01Model(nn.Module):
 
-    def __init__(
-        self,
-        config: MiniMaxConfig,
-        model_config: Optional[ModelConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        scheduler_config=None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config: MiniMaxConfig = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        quant_config = vllm_config.quant_config
+        cache_config = vllm_config.cache_config
+        scheduler_config = vllm_config.scheduler_config
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1138,7 +1135,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         super().__init__()
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
@@ -1151,13 +1147,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.unpadded_vocab_size = self.config.vocab_size
         if hasattr(vllm_config.model_config, "max_model_len"):
             self.config.max_model_len = vllm_config.model_config.max_model_len
-        self.model = MiniMaxText01Model(
-            self.config,
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            quant_config=quant_config,
-            scheduler_config=vllm_config.scheduler_config,
-            prefix=maybe_prefix(prefix, "model"))
+        self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,