|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
| 4 | +import io |
4 | 5 | from collections.abc import Iterable
|
5 | 6 | from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
6 | 7 |
|
|
23 | 24 | "LMHeadModel",
|
24 | 25 | ]
|
25 | 26 |
|
| 27 | +# Note: projector uses standard nn.Linear to avoid sharding edge-cases |
| 28 | + |
| 29 | + |
| 30 | +def st_activation(name: Optional[str]) -> nn.Module: |
| 31 | + m = (name or "").lower() |
| 32 | + if m in ("gelu", "gelu_new"): |
| 33 | + return nn.GELU() |
| 34 | + if m == "relu": |
| 35 | + return nn.ReLU() |
| 36 | + if m == "tanh": |
| 37 | + return nn.Tanh() |
| 38 | + if m == "sigmoid": |
| 39 | + return nn.Sigmoid() |
| 40 | + if m == "swish": |
| 41 | + return nn.SiLU() |
| 42 | + if m == "identity": |
| 43 | + return nn.Identity() |
| 44 | + return nn.Identity() |
| 45 | + |
26 | 46 |
|
27 | 47 | def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
28 | 48 | model_name = orig_model_name
|
@@ -99,38 +119,137 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
99 | 119 | return ModelForPooling # type: ignore
|
100 | 120 |
|
101 | 121 |
|
102 |
| -def as_embedding_model(cls: _T) -> _T: |
| 122 | +def _load_st_projector(vllm_config: "VllmConfig") -> Optional[nn.Module]: |
| 123 | + """Load Sentence Transformers projector from modules.json |
| 124 | + and Dense folders. |
103 | 125 | """
|
104 |
| - Subclass an existing vLLM model to support embeddings. |
| 126 | + from vllm.transformers_utils.config import (file_or_path_exists, |
| 127 | + get_hf_file_bytes, |
| 128 | + get_hf_file_to_dict) |
| 129 | + |
| 130 | + model_path = vllm_config.model_config.model |
| 131 | + revision = vllm_config.model_config.revision |
| 132 | + |
| 133 | + # Check if modules.json exists and contains Dense modules |
| 134 | + if not file_or_path_exists(model_path, "modules.json", revision): |
| 135 | + return None |
| 136 | + |
| 137 | + modules = get_hf_file_to_dict("modules.json", model_path, revision) |
| 138 | + if not isinstance(modules, list): |
| 139 | + return None |
| 140 | + |
| 141 | + dense_entries = [ |
| 142 | + m for m in modules |
| 143 | + if m.get("type") == "sentence_transformers.models.Dense" |
| 144 | + ] |
| 145 | + if not dense_entries: |
| 146 | + return None |
| 147 | + |
| 148 | + # Get dtype and quant config |
| 149 | + raw_dtype = getattr(vllm_config.model_config, "dtype", None) |
| 150 | + if isinstance(raw_dtype, str): |
| 151 | + desired_dtype = getattr(torch, raw_dtype, torch.float32) |
| 152 | + elif isinstance(raw_dtype, torch.dtype): |
| 153 | + desired_dtype = raw_dtype |
| 154 | + else: |
| 155 | + desired_dtype = torch.float32 |
| 156 | + |
| 157 | + def _load_config_json(path: str) -> Optional[dict]: |
| 158 | + """Load config.json from a Dense folder.""" |
| 159 | + try: |
| 160 | + return get_hf_file_to_dict(path, model_path, revision) |
| 161 | + except Exception: |
| 162 | + return None |
| 163 | + |
| 164 | + def _load_dense_weights( |
| 165 | + folder: str, |
| 166 | + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: |
| 167 | + """Load weights from a Dense folder, trying safetensors first.""" |
| 168 | + # Try safetensors first |
| 169 | + try: |
| 170 | + b = get_hf_file_bytes(f"{folder}/model.safetensors", model_path, |
| 171 | + revision) |
| 172 | + if b is not None: |
| 173 | + from safetensors.torch import load as st_load |
| 174 | + sd = st_load(io.BytesIO(b)) |
| 175 | + w = (sd.get("linear.weight") or sd.get("dense.weight") |
| 176 | + or sd.get("weight")) |
| 177 | + bias = (sd.get("linear.bias") or sd.get("dense.bias") |
| 178 | + or sd.get("bias")) |
| 179 | + return w, bias |
| 180 | + except Exception: |
| 181 | + pass |
| 182 | + |
| 183 | + # Fallback to pytorch_model.bin |
| 184 | + try: |
| 185 | + b = get_hf_file_bytes(f"{folder}/pytorch_model.bin", model_path, |
| 186 | + revision) |
| 187 | + if b is not None: |
| 188 | + sd = torch.load(io.BytesIO(b), map_location="cpu") |
| 189 | + w = (sd.get("linear.weight") or sd.get("dense.weight") |
| 190 | + or sd.get("weight")) |
| 191 | + bias = (sd.get("linear.bias") or sd.get("dense.bias") |
| 192 | + or sd.get("bias")) |
| 193 | + return w, bias |
| 194 | + except Exception: |
| 195 | + pass |
| 196 | + return None, None |
| 197 | + |
| 198 | + # Build projector layers |
| 199 | + layers: list[nn.Module] = [] |
| 200 | + for i, entry in enumerate(dense_entries): |
| 201 | + folder = entry.get("path") |
| 202 | + if not folder: |
| 203 | + continue |
| 204 | + |
| 205 | + cfg = _load_config_json(f"{folder}/config.json") |
| 206 | + if not cfg: |
| 207 | + continue |
| 208 | + |
| 209 | + in_features = cfg.get("in_features") |
| 210 | + out_features = cfg.get("out_features") |
| 211 | + if in_features is None or out_features is None: |
| 212 | + continue |
| 213 | + |
| 214 | + use_bias = cfg.get("bias", True) |
| 215 | + activation = st_activation(cfg.get("activation_function")) |
| 216 | + |
| 217 | + # Create a simple nn.Linear for projector to avoid sharding edge-cases |
| 218 | + linear = nn.Linear(in_features, out_features, bias=use_bias) |
| 219 | + linear = linear.to(dtype=desired_dtype) |
| 220 | + |
| 221 | + # Load weights |
| 222 | + weight, bias = _load_dense_weights(folder) |
| 223 | + if weight is not None: |
| 224 | + with torch.no_grad(): |
| 225 | + # weight is expected in [out_features, in_features] |
| 226 | + linear.weight.copy_(weight.to(dtype=linear.weight.dtype)) |
| 227 | + if use_bias and bias is not None and linear.bias is not None: |
| 228 | + linear.bias.copy_(bias.to(dtype=linear.bias.dtype)) |
| 229 | + |
| 230 | + layers.append(linear) |
| 231 | + layers.append(activation) |
| 232 | + |
| 233 | + if not layers: |
| 234 | + return None |
| 235 | + return nn.Sequential(*layers) |
105 | 236 |
|
106 |
| - By default, the embeddings of the whole prompt are extracted from the |
107 |
| - normalized hidden state corresponding to the last token. |
108 | 237 |
|
109 |
| - Note: |
110 |
| - We assume that no extra layers are added to the original model; |
111 |
| - please implement your own model if this is not the case. |
112 |
| - """ |
113 |
| - # Avoid modifying existing embedding models |
114 |
| - if is_pooling_model(cls): |
115 |
| - return cls |
116 |
| - |
117 |
| - # Lazy import |
118 |
| - from vllm.model_executor.layers.pooler import DispatchPooler, Pooler |
| 238 | +def as_embedding_model(cls: _T) -> _T: |
| 239 | + """Convert a model class to support embedding tasks.""" |
119 | 240 |
|
120 | 241 | class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
121 | 242 |
|
122 | 243 | def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
123 |
| - pooler_config = vllm_config.model_config.pooler_config |
124 |
| - assert pooler_config is not None |
| 244 | + from vllm.model_executor.layers.pooler import Pooler |
125 | 245 |
|
126 |
| - self.pooler = DispatchPooler( |
127 |
| - { |
128 |
| - "encode": Pooler.for_encode(pooler_config), |
129 |
| - "embed": Pooler.for_embed(pooler_config), |
130 |
| - }, ) |
| 246 | + # Load ST projector if available |
| 247 | + projector = _load_st_projector(vllm_config) |
131 | 248 |
|
132 |
| - ModelForEmbedding.__name__ = \ |
133 |
| - _get_pooling_model_name(cls.__name__, "ForEmbedding") |
| 249 | + # Use existing pooler_config instead of creating new one |
| 250 | + pooler_config = vllm_config.model_config.pooler_config |
| 251 | + assert pooler_config is not None |
| 252 | + self.pooler = Pooler.for_embed(pooler_config, projector=projector) |
134 | 253 |
|
135 | 254 | return ModelForEmbedding # type: ignore
|
136 | 255 |
|
@@ -229,7 +348,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
229 | 348 | # ForSequenceClassification model.
|
230 | 349 | return seq_cls_model_loader(self, weights)
|
231 | 350 |
|
232 |
| - |
233 | 351 | ModelForSequenceClassification.__name__ = \
|
234 | 352 | _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
|
235 | 353 |
|
|
0 commit comments