@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
         self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_size=config.intermediate_size,
+            output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
-            input_size=config.intermediate_size,
+            input_size=intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
@@ -110,6 +122,7 @@ def __init__(
             quant_config=quant_config,
             bias=config.mlp_bias,
             prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
         )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ def __init__(
             hidden_size=config.hidden_size,
             ssm_state_size=config.ssm_state_size,
             conv_kernel_size=config.conv_kernel,
-            intermediate_size=config.expand * config.hidden_size,
+            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
             use_conv_bias=config.use_conv_bias,
             use_bias=config.use_bias,
             n_groups=config.n_groups,
@@ -205,7 +218,10 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.hidden_size // self.total_num_heads
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -481,7 +497,7 @@ def get_mamba_state_shape_from_config(
         """
         parallel_config = vllm_config.parallel_config
         hf_config = vllm_config.model_config.hf_config
-        intermediate_size = hf_config.expand * hf_config.hidden_size
+        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=intermediate_size,
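For reference, a minimal standalone sketch of the new per-layer MLP width lookup introduced above. The pattern string and size list below are invented for illustration (they do not come from any real checkpoint), and the meaning of the pattern characters ('M' = Mamba, '*' = attention, '-' = MLP) is assumed from the surrounding code rather than confirmed by this diff.

# Illustrative sketch only: pattern and sizes are made up, not from a real config.
hybrid_override_pattern = "M-M-M*-"      # assumed: 'M' = Mamba, '-' = MLP, '*' = attention
intermediate_size = [1024, 2048, 4096]   # assumed: one entry per '-' (MLP) layer

def resolve_mlp_width(layer_idx: int) -> int:
    # Count the MLP layers up to and including this layer; that count minus one
    # indexes into the per-MLP-layer intermediate_size list, mirroring the diff.
    mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
    if len(intermediate_size) == 1:
        return intermediate_size[0]
    return intermediate_size[mlp_index]

print(resolve_mlp_width(1))  # first '-' layer  -> 1024
print(resolve_mlp_width(6))  # third '-' layer  -> 4096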