Commit 61244cf

[Bugfix] Fix Dense module loading for sentence-transformers embedding models v2
Signed-off-by: FFFfff1FFFfff <[email protected]>
1 parent 68b254d commit 61244cf

File tree

5 files changed: +255, -72 lines


requirements/test.txt

Lines changed: 0 additions & 1 deletion
@@ -968,7 +968,6 @@ setuptools==77.0.3
     # via
     #   lightning-utilities
     #   pytablewriter
-    #   torch
     #   triton
 shapely==2.1.1
     # via

vllm/model_executor/layers/pooler.py

Lines changed: 27 additions & 4 deletions
@@ -72,6 +72,7 @@ def for_encode(
         pooler_config: PoolerConfig,
         *,
         default_pooling_type: PoolingType = PoolingType.ALL,
+        projector: Optional[nn.Module] = None,
     ):
         resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
             task="encode",
@@ -82,21 +83,22 @@ def for_encode(
         if resolved_config.pooling_type == PoolingType.STEP:
             return StepPooler()

-        return SimplePooler.from_config(resolved_config)
+        return SimplePooler.from_config(resolved_config, projector=projector)

     @staticmethod
     def for_embed(
         pooler_config: PoolerConfig,
         *,
         default_pooling_type: PoolingType = PoolingType.LAST,
+        projector: Optional[nn.Module] = None,
     ):
         resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
             task="embed",
             pooler_config=pooler_config,
             pooling_type=default_pooling_type,
         )

-        return SimplePooler.from_config(resolved_config)
+        return SimplePooler.from_config(resolved_config, projector=projector)

     @staticmethod
     def for_classify(
@@ -470,12 +472,32 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],

 class EmbeddingPoolerHead(PoolerHead):

-    def __init__(self) -> None:
+    def __init__(self, projector: Optional[nn.Module] = None) -> None:
         super().__init__(activation=PoolerNormalize())
+        self.projector = projector
+        self._device_set = False

     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):

+        if self.projector is not None:
+            ref = pooled_data[0] if isinstance(pooled_data,
+                                               list) else pooled_data
+
+            if not self._device_set:
+                self.projector.to(device=ref.device, dtype=torch.float32)
+                self._device_set = True
+
+            def _proj(x: torch.Tensor) -> torch.Tensor:
+                y = self.projector(x.to(torch.float32))
+                return y.to(x.dtype)
+
+            if isinstance(pooled_data, torch.Tensor):
+                pooled_data = _proj(pooled_data)
+            else:
+                pooled_data = [_proj(t) for t in pooled_data]
+            # else: keep as is
+
         pooling_params = get_pooling_params(pooling_metadata)

         # for matryoshka representation
@@ -546,10 +568,11 @@ class SimplePooler(Pooler):
     def from_config(
         cls,
         pooler_config: ResolvedPoolingConfig,
+        projector: Optional[nn.Module] = None,
     ) -> "SimplePooler":
         pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
         if pooler_config.task == "embed":
-            head = EmbeddingPoolerHead()
+            head = EmbeddingPoolerHead(projector=projector)
         elif pooler_config.task == "encode":
             head = RewardPoolerHead()
         else:
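Note on the change above: the projector runs between pooling and normalization, in float32, with the result cast back to the activation dtype. A minimal standalone sketch of that behavior (the 384/256 dimensions and the Tanh activation are illustrative, not taken from this commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical projector mirroring a sentence-transformers Dense module.
hidden_size, out_features = 384, 256
projector = nn.Sequential(nn.Linear(hidden_size, out_features), nn.Tanh()).float()

# Pooled hidden states as EmbeddingPoolerHead would receive them.
pooled = torch.randn(8, hidden_size, dtype=torch.float16)

# Project in float32 for stability, then cast back (mirrors the _proj helper above).
projected = projector(pooled.to(torch.float32)).to(pooled.dtype)

# Normalization happens afterwards (PoolerNormalize in vLLM); plain L2-normalize here.
embeddings = F.normalize(projected.float(), p=2, dim=-1)
print(embeddings.shape)  # torch.Size([8, 256])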

vllm/model_executor/models/adapters.py

Lines changed: 142 additions & 24 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import io
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

@@ -23,6 +24,25 @@
     "LMHeadModel",
 ]

+# Note: projector uses standard nn.Linear to avoid sharding edge-cases
+
+
+def st_activation(name: Optional[str]) -> nn.Module:
+    m = (name or "").lower()
+    if m in ("gelu", "gelu_new"):
+        return nn.GELU()
+    if m == "relu":
+        return nn.ReLU()
+    if m == "tanh":
+        return nn.Tanh()
+    if m == "sigmoid":
+        return nn.Sigmoid()
+    if m == "swish":
+        return nn.SiLU()
+    if m == "identity":
+        return nn.Identity()
+    return nn.Identity()
+

 def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     model_name = orig_model_name
@@ -99,38 +119,137 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
     return ModelForPooling  # type: ignore


-def as_embedding_model(cls: _T) -> _T:
+def _load_st_projector(vllm_config: "VllmConfig") -> Optional[nn.Module]:
+    """Load Sentence Transformers projector from modules.json
+    and Dense folders.
     """
-    Subclass an existing vLLM model to support embeddings.
+    from vllm.transformers_utils.config import (file_or_path_exists,
+                                                get_hf_file_bytes,
+                                                get_hf_file_to_dict)
+
+    model_path = vllm_config.model_config.model
+    revision = vllm_config.model_config.revision
+
+    # Check if modules.json exists and contains Dense modules
+    if not file_or_path_exists(model_path, "modules.json", revision):
+        return None
+
+    modules = get_hf_file_to_dict("modules.json", model_path, revision)
+    if not isinstance(modules, list):
+        return None
+
+    dense_entries = [
+        m for m in modules
+        if m.get("type") == "sentence_transformers.models.Dense"
+    ]
+    if not dense_entries:
+        return None
+
+    # Get dtype and quant config
+    raw_dtype = getattr(vllm_config.model_config, "dtype", None)
+    if isinstance(raw_dtype, str):
+        desired_dtype = getattr(torch, raw_dtype, torch.float32)
+    elif isinstance(raw_dtype, torch.dtype):
+        desired_dtype = raw_dtype
+    else:
+        desired_dtype = torch.float32
+
+    def _load_config_json(path: str) -> Optional[dict]:
+        """Load config.json from a Dense folder."""
+        try:
+            return get_hf_file_to_dict(path, model_path, revision)
+        except Exception:
+            return None
+
+    def _load_dense_weights(
+        folder: str,
+    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """Load weights from a Dense folder, trying safetensors first."""
+        # Try safetensors first
+        try:
+            b = get_hf_file_bytes(f"{folder}/model.safetensors", model_path,
+                                  revision)
+            if b is not None:
+                from safetensors.torch import load as st_load
+                sd = st_load(io.BytesIO(b))
+                w = (sd.get("linear.weight") or sd.get("dense.weight")
+                     or sd.get("weight"))
+                bias = (sd.get("linear.bias") or sd.get("dense.bias")
+                        or sd.get("bias"))
+                return w, bias
+        except Exception:
+            pass
+
+        # Fallback to pytorch_model.bin
+        try:
+            b = get_hf_file_bytes(f"{folder}/pytorch_model.bin", model_path,
+                                  revision)
+            if b is not None:
+                sd = torch.load(io.BytesIO(b), map_location="cpu")
+                w = (sd.get("linear.weight") or sd.get("dense.weight")
+                     or sd.get("weight"))
+                bias = (sd.get("linear.bias") or sd.get("dense.bias")
+                        or sd.get("bias"))
+                return w, bias
+        except Exception:
+            pass
+        return None, None
+
+    # Build projector layers
+    layers: list[nn.Module] = []
+    for i, entry in enumerate(dense_entries):
+        folder = entry.get("path")
+        if not folder:
+            continue
+
+        cfg = _load_config_json(f"{folder}/config.json")
+        if not cfg:
+            continue
+
+        in_features = cfg.get("in_features")
+        out_features = cfg.get("out_features")
+        if in_features is None or out_features is None:
+            continue
+
+        use_bias = cfg.get("bias", True)
+        activation = st_activation(cfg.get("activation_function"))
+
+        # Create a simple nn.Linear for projector to avoid sharding edge-cases
+        linear = nn.Linear(in_features, out_features, bias=use_bias)
+        linear = linear.to(dtype=desired_dtype)

-    By default, the embeddings of the whole prompt are extracted from the
-    normalized hidden state corresponding to the last token.
+        # Load weights
+        weight, bias = _load_dense_weights(folder)
+        if weight is not None:
+            with torch.no_grad():
+                # weight is expected in [out_features, in_features]
+                linear.weight.copy_(weight.to(dtype=linear.weight.dtype))
+                if use_bias and bias is not None and linear.bias is not None:
+                    linear.bias.copy_(bias.to(dtype=linear.bias.dtype))
+
+        layers.append(linear)
+        layers.append(activation)
+
+    if not layers:
+        return None
+    return nn.Sequential(*layers)

-    Note:
-        We assume that no extra layers are added to the original model;
-        please implement your own model if this is not the case.
-    """
-    # Avoid modifying existing embedding models
-    if is_pooling_model(cls):
-        return cls
-
-    # Lazy import
-    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
+def as_embedding_model(cls: _T) -> _T:
+    """Convert a model class to support embedding tasks."""

     class ModelForEmbedding(_create_pooling_model_cls(cls)):

         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
+            from vllm.model_executor.layers.pooler import Pooler

-            self.pooler = DispatchPooler(
-                {
-                    "encode": Pooler.for_encode(pooler_config),
-                    "embed": Pooler.for_embed(pooler_config),
-                }, )
+            # Load ST projector if available
+            projector = _load_st_projector(vllm_config)

-    ModelForEmbedding.__name__ = \
-        _get_pooling_model_name(cls.__name__, "ForEmbedding")
+            # Use existing pooler_config instead of creating new one
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+            self.pooler = Pooler.for_embed(pooler_config, projector=projector)

     return ModelForEmbedding  # type: ignore

@@ -229,7 +348,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             # ForSequenceClassification model.
             return seq_cls_model_loader(self, weights)

-
     ModelForSequenceClassification.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForSequenceClassification")

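For reference, _load_st_projector keys off the standard sentence-transformers checkpoint layout: a modules.json listing the modules and, per Dense module, a folder with config.json plus model.safetensors (or pytorch_model.bin). A rough Python sketch of the fields it reads; the folder name "2_Dense" and the numeric values are illustrative, not taken from any specific model:

# Illustrative on-disk layout inspected by _load_st_projector:
#
#   modules.json
#   2_Dense/config.json
#   2_Dense/model.safetensors      (or 2_Dense/pytorch_model.bin)

# modules.json entry that marks a Dense projection module:
dense_module_entry = {
    "idx": 2,
    "name": "2",
    "path": "2_Dense",
    "type": "sentence_transformers.models.Dense",
}

# 2_Dense/config.json fields used to build the nn.Linear (illustrative values):
dense_config = {
    "in_features": 384,
    "out_features": 256,
    "bias": True,
    "activation_function": "tanh",  # matched case-insensitively by st_activation()
}

# The weight file is expected to expose the parameters under "linear.weight"
# (or "dense.weight"/"weight"), shaped [out_features, in_features], plus an
# optional matching bias.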

vllm/model_executor/models/bert.py

Lines changed: 10 additions & 0 deletions
@@ -24,6 +24,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
@@ -457,6 +458,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

+        # Save vllm_config for projector loading
+        self.vllm_config = vllm_config
+
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

@@ -497,13 +501,19 @@ def _build_model(self,
                          embedding_class=BertEmbedding)

     def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        # Load projector using the stored vllm_config
+        projector = None
+        if hasattr(self, 'vllm_config'):
+            projector = _load_st_projector(self.vllm_config)
+
         return DispatchPooler({
             "encode":
             Pooler.for_encode(pooler_config),
             "embed":
             Pooler.for_embed(
                 pooler_config,
                 default_pooling_type=PoolingType.CLS,
+                projector=projector,
             ),
         })

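A hedged end-to-end check, not part of the commit: with the projector wired in, embeddings returned by vLLM's offline API should have the Dense head's out_features rather than the backbone hidden size. The model id below is a placeholder, and argument names may differ slightly across vLLM versions:

from vllm import LLM

# Placeholder id: any checkpoint whose modules.json contains a
# sentence_transformers.models.Dense entry exercises this code path.
llm = LLM(model="my-org/st-model-with-dense", task="embed")

outputs = llm.embed(["vLLM is a high-throughput inference engine."])

# With the Dense projector applied, this length should equal out_features
# from the Dense module's config.json.
print(len(outputs[0].outputs.embedding))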
