Commit 14a5d90

[Model] NemotronH Support (#22349)
Signed-off-by: Daniel Afrimi <[email protected]>
1 parent 951b038 commit 14a5d90

File tree: 2 files changed, +23 -7 lines

vllm/model_executor/models/nemotron_h.py

Lines changed: 21 additions & 5 deletions
@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
         self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_size=config.intermediate_size,
+            output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
-            input_size=config.intermediate_size,
+            input_size=intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
@@ -110,6 +122,7 @@ def __init__(
             quant_config=quant_config,
             bias=config.mlp_bias,
             prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ def __init__(
             hidden_size=config.hidden_size,
             ssm_state_size=config.ssm_state_size,
             conv_kernel_size=config.conv_kernel,
-            intermediate_size=config.expand * config.hidden_size,
+            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
             use_conv_bias=config.use_conv_bias,
             use_bias=config.use_bias,
             n_groups=config.n_groups,
@@ -205,7 +218,10 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.hidden_size // self.total_num_heads
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -481,7 +497,7 @@ def get_mamba_state_shape_from_config(
         """
         parallel_config = vllm_config.parallel_config
         hf_config = vllm_config.model_config.hf_config
-        intermediate_size = hf_config.expand * hf_config.hidden_size
+        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim
 
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=intermediate_size,
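
Note on the per-layer MLP sizing above: a minimal sketch of how the new code resolves intermediate_size when config.intermediate_size is given as a list, assuming a hypothetical hybrid_override_pattern and size list (the values below are illustrative, not taken from a real checkpoint). Note also that the same product, config.mamba_num_heads * config.mamba_head_dim, now defines both the Mamba mixer's intermediate size and the cached state shape in get_mamba_state_shape_from_config, replacing config.expand * config.hidden_size.

# Illustrative sketch only: how NemotronHMLP picks its intermediate size.
# "M" = Mamba block, "*" = attention block, "-" = MLP block.
hybrid_override_pattern = "M-M-M*-"      # hypothetical pattern
intermediate_size = [1024, 2048, 4096]   # hypothetical per-MLP sizes

for layer_idx, block in enumerate(hybrid_override_pattern):
    if block != "-":
        continue  # only MLP layers consume an entry from the list
    # Position of this MLP layer among all MLP layers up to and including it.
    mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
    print(layer_idx, mlp_index, intermediate_size[mlp_index])
# Prints: 1 0 1024 / 3 1 2048 / 6 2 4096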

vllm/transformers_utils/configs/nemotron_h.py

Lines changed: 2 additions & 2 deletions
@@ -151,7 +151,7 @@ def __init__(
         num_hidden_layers=52,
         hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
         num_attention_heads=32,
-        attention_head_dim=128,
+        head_dim=128,
         num_key_value_heads=8,  # nemo: num_query_groups
         mlp_hidden_act="relu2",
         attention_bias=False,
@@ -194,7 +194,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.hybrid_override_pattern = hybrid_override_pattern
         self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
+        self.head_dim = head_dim
         self.sliding_window = sliding_window
         self.max_position_embeddings = max_position_embeddings
         self.attention_dropout = attention_dropout
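
As a companion to the model-side change above, here is a minimal sketch of the head_dim resolution, assuming a hypothetical stand-in config (DummyConfig and its values are made up for illustration): the attention layer now prefers an explicit config.head_dim, which this rename provides, and falls back to hidden_size // num_attention_heads otherwise.

# Illustrative sketch only: head_dim resolution after renaming
# attention_head_dim -> head_dim in NemotronHConfig.
class DummyConfig:           # hypothetical stand-in for NemotronHConfig
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = 128           # may be None or absent on older configs

config = DummyConfig()
if hasattr(config, "head_dim") and config.head_dim is not None:
    head_dim = config.head_dim                                    # explicit value wins
else:
    head_dim = config.hidden_size // config.num_attention_heads   # legacy fallback
print(head_dim)  # 128 either way for these illustrative values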
