@@ -5,7 +5,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,6 +21,7 @@
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
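The import hunk above swaps the Hugging Face `MptConfig` for vLLM's vendored `MPTConfig`. The rest of the diff assumes that, in the vendored class, `attn_config` is a plain dict rather than a sub-config object. A minimal sketch of that assumption (default construction and the exact key set are assumptions, not part of this diff):

# Minimal sketch, assuming the vendored MPTConfig mirrors MosaicML's
# configuration_mpt.py, where attn_config defaults to a dict of settings.
from vllm.transformers_utils.configs.mpt import MPTConfig

config = MPTConfig()                          # assumed to construct with defaults
assert isinstance(config.attn_config, dict)   # a dict, not a sub-config object
print(sorted(config.attn_config))             # expected keys such as "alibi", "clip_qkv", "qk_ln", ...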
@@ -39,21 +39,21 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MptAttention(nn.Module):
+class MPTAttention(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config.clip_qkv
-        self.qk_ln = config.attn_config.qk_ln
-        self.alibi_bias_max = config.attn_config.alibi_bias_max
-        assert not config.attn_config.prefix_lm
-        assert config.attn_config.alibi
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
 
         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
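The hunk above is the behavioural core of the change: with the Hugging Face `MptConfig`, `attn_config` is a sub-config object read through attributes, while the vendored `MPTConfig` keeps it as a plain dict read through keys. A runnable illustration of the two shapes (the stand-in objects below are toys, not the real config classes):

# Illustration only: emulate the two attn_config shapes with stand-in objects.
from types import SimpleNamespace

# Hugging Face-style config: attn_config is an object, fields are attributes.
hf_like = SimpleNamespace(attn_config=SimpleNamespace(clip_qkv=None, qk_ln=False))
print(hf_like.attn_config.clip_qkv)            # attribute access (old "-" lines)

# Vendored-style config: attn_config is a plain dict, fields are keys.
vendored_like = SimpleNamespace(attn_config={"clip_qkv": None, "qk_ln": False})
print(vendored_like.attn_config["clip_qkv"])   # key access (new "+" lines)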
@@ -113,11 +113,11 @@ def forward(
         return output
 
 
-class MptMLP(nn.Module):
+class MPTMLP(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -145,19 +145,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class MptBlock(nn.Module):
+class MPTBlock(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config, linear_method)
+        self.attn = MPTAttention(config, linear_method)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config, linear_method)
+        self.ffn = MPTMLP(config, linear_method)
 
     def forward(
         self,
@@ -182,11 +182,11 @@ def forward(
         return hidden_states
 
 
-class MptModel(nn.Module):
+class MPTModel(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -198,7 +198,7 @@ def __init__(
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MptBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -233,19 +233,19 @@ def forward(
         return hidden_states
 
 
-class MptForCausalLM(nn.Module):
+class MPTForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
         self.linear_method = linear_method
 
-        self.transformer = MptModel(config, linear_method)
+        self.transformer = MPTModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
         self.sampler = Sampler(config.vocab_size)
 
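The final hunk keeps the tied-embedding setup: `MPTForCausalLM` asserts `config.tie_word_embeddings` and reuses the token-embedding matrix `transformer.wte.weight` as the LM-head weight. A small self-contained sketch of that weight-tying pattern (the toy module below is illustrative, not the vLLM class):

# Toy illustration of weight tying: the output projection reuses the input
# embedding matrix instead of allocating a separate lm_head weight.
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):

    def __init__(self, vocab_size: int = 32, d_model: int = 8):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)   # token embeddings
        self.lm_head_weight = self.wte.weight          # tied: same tensor object

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.wte(token_ids)                   # [batch, seq, d_model]
        # Logits computed against the tied weight: hidden @ wte.weight^T
        return torch.nn.functional.linear(hidden, self.lm_head_weight)


logits = TinyTiedLM()(torch.tensor([[1, 2, 3]]))
print(logits.shape)   # torch.Size([1, 3, 32])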