
Commit 60bad42

jeejeelee and amd-xiaoyu12 authored and committed
[Model] Improve olmo and olmo2 (vllm-project#23228)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 754e044 commit 60bad42

File tree

3 files changed (+36, -7 lines)


docs/models/supported_models.md

Lines changed: 2 additions & 2 deletions

@@ -384,8 +384,8 @@ th {
 | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
-| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
+| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
 | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
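The newly ticked column marks LoRA support for the OLMo architectures. A minimal sketch of exercising it through vLLM's offline `LLM` API; the adapter name and path below are hypothetical placeholders for a real OLMo LoRA checkpoint:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Base OLMo model with LoRA serving enabled.
llm = LLM(model="allenai/OLMo-1B-hf", enable_lora=True)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# "olmo-adapter" and the path are placeholders for an actual adapter checkpoint.
outputs = llm.generate(
    ["The capital of France is"],
    sampling_params,
    lora_request=LoRARequest("olmo-adapter", 1, "/path/to/olmo-lora-adapter"),
)
print(outputs[0].outputs[0].text)
```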

vllm/model_executor/models/olmo.py

Lines changed: 19 additions & 3 deletions

@@ -47,7 +47,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -91,6 +91,7 @@ def __init__(
             self.total_num_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         # Rotary embeddings.
@@ -114,6 +115,7 @@ def __init__(
             self.hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -142,6 +144,7 @@ def __init__(
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -154,6 +157,7 @@ def __init__(
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
 
         # Activation function.
@@ -165,6 +169,7 @@ def __init__(
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -197,7 +202,7 @@ def __init__(self,
                                        prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = OlmoMLP(config, quant_config)
+        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
         # LayerNorm
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
@@ -326,10 +331,21 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class OlmoForCausalLM(nn.Module, SupportsPP):
+class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
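The new `prefix=` arguments give each parallel linear layer its fully qualified module name (e.g. `model.layers.0.self_attn.qkv_proj`), which name-keyed configuration such as per-layer quantization settings can match against. A small illustrative sketch of how such prefixes compose; the helper and the ignore list below are hypothetical, not vLLM's real internals:

```python
# Illustrative only: prefixes threaded through the constructors compose into
# fully qualified module names that name-based per-layer configuration
# (e.g. skipping quantization for certain layers) can match.
def layer_module_names(layer_prefix: str) -> list[str]:
    return [
        f"{layer_prefix}.self_attn.qkv_proj",
        f"{layer_prefix}.self_attn.o_proj",
        f"{layer_prefix}.mlp.gate_up_proj",
        f"{layer_prefix}.mlp.down_proj",
    ]

# A hypothetical "ignore list" from a quantization config, keyed by name.
ignored_modules = {"model.layers.0.mlp.down_proj"}

for name in layer_module_names("model.layers.0"):
    quantized = name not in ignored_modules
    print(f"{name}: {'quantized' if quantized else 'kept in full precision'}")
```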

vllm/model_executor/models/olmo2.py

Lines changed: 15 additions & 2 deletions

@@ -33,6 +33,7 @@
 from transformers import Olmo2Config
 
 from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@@ -48,7 +49,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
@@ -253,6 +254,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class Olmo2Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -354,10 +356,21 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Olmo2ForCausalLM(nn.Module, SupportsPP):
+class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
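The `packed_modules_mapping` added to both model classes is what lets LoRA adapters trained against the Hugging Face layout (separate `q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj` modules) be applied to vLLM's fused `qkv_proj` and `gate_up_proj` layers. A rough, self-contained sketch of the idea; shapes are made up and this stands outside vLLM's actual LoRA machinery, which also handles tensor parallelism:

```python
import torch

# Mapping as declared on the model classes: each fused vLLM module is backed
# by several per-projection modules in the original HF checkpoint layout.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

# Hypothetical LoRA "B" matrices from an adapter trained on the HF layout
# (output_dim x rank); the sizes here are illustrative only.
rank, hidden = 8, 2048
adapter_b = {
    "q_proj": torch.randn(hidden, rank),
    "k_proj": torch.randn(hidden, rank),
    "v_proj": torch.randn(hidden, rank),
}

# Stacking the per-projection matrices in mapping order yields one tensor whose
# output dimension lines up with the fused qkv_proj output (q | k | v).
fused_b = torch.cat(
    [adapter_b[name] for name in packed_modules_mapping["qkv_proj"]], dim=0)
print(fused_b.shape)  # torch.Size([6144, 8])
```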
