diff --git a/requirements/docs.txt b/requirements/docs.txt
index a24b9c7e924b..a47844ff55d8 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -30,4 +30,4 @@ torch
 transformers
 zmq
 uvloop
-prometheus-client
+prometheus-client
\ No newline at end of file
diff --git a/requirements/test.in b/requirements/test.in
index 6652bfdfe66c..0429b71b1a15 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -52,4 +52,4 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
+terratorch==1.1rc2 # required for PrithviMAE test
\ No newline at end of file
diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py
new file mode 100644
index 000000000000..85604432e61f
--- /dev/null
+++ b/tests/models/language/pooling/test_st_projector.py
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import numpy as np
+import pytest
+from scipy.spatial.distance import cosine
+
+from ...utils import EmbedModelInfo
+
+
+def _get_vllm_embeddings(vllm_runner, model_info: EmbedModelInfo,
+                         test_texts: list[str]):
+    """Get embeddings from vLLM."""
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.architecture == "GteNewModel":
+        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
+
+    with vllm_runner(
+            model_info.name,
+            runner="pooling",
+            max_model_len=None,
+            trust_remote_code=True,
+            **vllm_extra_kwargs,
+    ) as vllm_model:
+        embeddings = vllm_model.encode(test_texts)
+
+    # Extract tensor/numpy data
+    data = []
+    for emb in embeddings:
+        if hasattr(emb, "outputs"):
+            data.append(emb.outputs.data.cpu().numpy())
+        else:
+            data.append(emb.cpu().numpy() if hasattr(emb, "cpu") else emb)
+    return np.array(data)
+
+
+def _get_hf_embeddings(hf_runner, model_info: EmbedModelInfo,
+                       test_texts: list[str]):
+    """Get embeddings from HuggingFace ST interface."""
+    with hf_runner(
+            model_info.name,
+            is_sentence_transformer=True,
+            dtype="float32",
+    ) as hf_model:
+        embeddings = hf_model.encode(test_texts)
+        if hasattr(embeddings, "cpu"):
+            return embeddings.cpu().numpy()
+        return np.array(embeddings)
+
+
+# ST models with projector (Dense) layers
+ST_PROJECTOR_MODELS = [
+    EmbedModelInfo(
+        "TencentBAC/Conan-embedding-v1",
+        architecture="BertModel",
+        enable_test=True,
+    ),
+]
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_st_projector_loading(vllm_runner, model_info: EmbedModelInfo) -> None:
+    """Ensure projector models load and output expected dim."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = ["This is a test sentence."]
+    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+
+    actual_dim = embeddings_data.shape[-1]
+    expected_dim = 1792
+    assert actual_dim == expected_dim, (
+        f"Expected {expected_dim}, got {actual_dim}")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_compare_with_hf_dimensions(hf_runner, vllm_runner,
+                                    model_info: EmbedModelInfo) -> None:
+    """Compare embedding dimensions between vLLM and HuggingFace."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = ["This is a test sentence for dimension comparison."]
+
+    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
+
+    vllm_dim = vllm_data.shape[-1]
+    hf_dim = hf_data.shape[-1]
+
+    assert vllm_dim == hf_dim, ("Embedding dim mismatch: "
+                                f"vLLM {vllm_dim} vs HF {hf_dim}")
+    print(f"✓ Embedding dimensions match: {vllm_dim}")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_embedding_numerical_similarity(hf_runner, vllm_runner,
+                                        model_info: EmbedModelInfo) -> None:
+    """Numerical similarity between vLLM and HF embeddings."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = [
+        "This is a test sentence for numerical comparison.",
+        "Another sentence to verify embedding quality.",
+        "机器学习是人工智能的一个重要分支。",  # Chinese test
+    ]
+
+    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
+
+    assert vllm_data.shape == hf_data.shape, (
+        "Shape mismatch: "
+        f"vLLM {vllm_data.shape} vs HF {hf_data.shape}")
+
+    print(f"Embedding shape: {vllm_data.shape}")
+    print(f"Embedding dimension: {vllm_data.shape[-1]}")
+
+    similarities = []
+    for i, text in enumerate(test_texts):
+        vllm_emb = vllm_data[i]
+        hf_emb = hf_data[i]
+
+        similarity = 1 - cosine(vllm_emb, hf_emb)
+        similarities.append(similarity)
+
+        preview = text[:50] + ("..." if len(text) > 50 else "")
+        print(f"Text {i + 1}: '{preview}'")
+        print(f"  Cosine similarity: {similarity:.6f}")
+
+        min_similarity = 0.95
+        assert similarity > min_similarity, (
+            f"Text {i + 1} similarity too low: "
+            f"{similarity:.6f} < {min_similarity}\n"
+            f"vLLM norm: {np.linalg.norm(vllm_emb):.6f}, "
+            f"HF norm: {np.linalg.norm(hf_emb):.6f}")
+
+    avg_similarity = np.mean(similarities)
+    print(f"\nAverage cosine similarity: {avg_similarity:.6f}")
+
+    assert avg_similarity > 0.98, (
+        f"Average similarity too low: {avg_similarity:.6f} < 0.98")
+    print("✓ All numerical similarity tests passed!")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_embedding_quality_checks(vllm_runner,
+                                  model_info: EmbedModelInfo) -> None:
+    """Basic quality checks: non-zero, non-constant, distinct."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = [
+        "First test sentence.",
+        "Second different sentence.",
+        "Completely different content here.",
+    ]
+
+    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+
+    print(f"Embeddings shape: {embeddings_data.shape}")
+
+    # Non-zero and non-constant
+    for i, emb in enumerate(embeddings_data):
+        norm = np.linalg.norm(emb)
+        print(f"Embedding {i + 1} L2 norm: {norm:.6f}")
+        assert norm > 1e-6, (
+            f"Embedding {i + 1} too close to zero: norm={norm}")
+
+        std = np.std(emb)
+        print(f"Embedding {i + 1} std: {std:.6f}")
+        assert std > 1e-6, (
+            f"Embedding {i + 1} too close to constant: std={std}")
+
+    # Different texts should differ
+    for i in range(len(embeddings_data)):
+        for j in range(i + 1, len(embeddings_data)):
+            sim = 1 - cosine(embeddings_data[i], embeddings_data[j])
+            print(f"Similarity between text {i + 1} and {j + 1}: {sim:.6f}")
+            assert sim < 0.99, ("Embeddings too similar: "
+                                f"{i + 1} vs {j + 1} -> {sim:.6f}")
+
+    print("✓ All embedding quality checks passed!")
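Note on the similarity math used throughout these tests: scipy's cosine is a
*distance* (1 minus cosine similarity), so the tests recover similarity as
1 - cosine(u, v). A minimal self-contained check of that identity:

    import numpy as np
    from scipy.spatial.distance import cosine

    u = np.array([1.0, 0.0])
    v = np.array([1.0, 1.0])
    similarity = 1 - cosine(u, v)  # cos(45 degrees) for these vectors
    assert abs(similarity - np.sqrt(2) / 2) < 1e-9
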
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index e2162e5cbf95..6beade18bd06 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from itertools import groupby
-from typing import Callable, Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union, cast
 
 import torch
 import torch.nn as nn
@@ -67,33 +67,61 @@ class Pooler(nn.Module, ABC):
     """The interface required for all poolers used in pooling models in vLLM."""
 
     @staticmethod
-    def for_encode(pooler_config: PoolerConfig):
+    def for_encode(
+        pooler_config: PoolerConfig,
+        *,
+        default_pooling_type: PoolingType = PoolingType.ALL,
+    ):
         if pooler_config.pooling_type == "STEP":
             return StepPooler()
 
-        resolved_config = ResolvedPoolingConfig(task="encode",
-                                                pooling_type=PoolingType.ALL)
+        # Use original logic: if pooler_config.pooling_type is None, use default
+        if pooler_config.pooling_type is not None:
+            resolved_config = ResolvedPoolingConfig.from_config(
+                task="encode",
+                pooler_config=pooler_config,
+            )
+        else:
+            resolved_config = ResolvedPoolingConfig(
+                task="encode", pooling_type=default_pooling_type)
 
         return SimplePooler.from_config(resolved_config)
 
     @staticmethod
-    def for_embed(pooler_config: PoolerConfig):
-        resolved_config = ResolvedPoolingConfig.from_config(
-            task="embed",
-            pooler_config=pooler_config,
-        )
+    def for_embed(
+        pooler_config: PoolerConfig,
+        *,
+        default_pooling_type: PoolingType = PoolingType.LAST,
+        projector: Optional[nn.Module] = None,
+    ):
+        # Use original logic: if pooler_config.pooling_type is None, use default
+        if pooler_config.pooling_type is not None:
+            resolved_config = ResolvedPoolingConfig.from_config(
+                task="embed",
+                pooler_config=pooler_config,
+            )
+        else:
+            resolved_config = ResolvedPoolingConfig(
+                task="embed", pooling_type=default_pooling_type)
 
-        return SimplePooler.from_config(resolved_config)
+        return SimplePooler.from_config(resolved_config, projector=projector)
 
     @staticmethod
     def for_classify(
         pooler_config: PoolerConfig,
         classifier: Optional[ClassifierFn],
+        *,
+        default_pooling_type: PoolingType = PoolingType.LAST,
     ):
-        resolved_config = ResolvedPoolingConfig.from_config(
-            task="classify",
-            pooler_config=pooler_config,
-        )
+        # Use original logic: if pooler_config.pooling_type is None, use default
+        if pooler_config.pooling_type is not None:
+            resolved_config = ResolvedPoolingConfig.from_config(
+                task="classify",
+                pooler_config=pooler_config,
+            )
+        else:
+            resolved_config = ResolvedPoolingConfig(
+                task="classify", pooling_type=default_pooling_type)
 
         pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
 
@@ -454,24 +482,89 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
 
 class EmbeddingPoolerHead(PoolerHead):
 
-    def __init__(self) -> None:
+    def __init__(self, projector: Optional[nn.Module] = None) -> None:
         super().__init__(activation=PoolerNormalize())
+        self.projector = projector
+        self._dim_checked = False
+
+    def _ensure_projector_device_and_dtype(self,
+                                           ref_tensor: torch.Tensor) -> None:
+        """Ensure projector is on correct device with float32 dtype."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        try:
+            proj_device = next(projector.parameters()).device
+            if proj_device != ref_tensor.device:
+                projector.to(device=ref_tensor.device, dtype=torch.float32)
+                # Ensure all parameters are float32
+                for param in projector.parameters():
+                    param.data = param.data.to(torch.float32)
+        except StopIteration:
+            # Empty projector, skip device check
+            pass
+
+    def _validate_projector_dimensions(self, ref_tensor: torch.Tensor) -> None:
+        """Validate projector input dimensions match pooled output."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        first_linear = None
+        for module in projector.modules():
+            if isinstance(module, nn.Linear):
+                first_linear = module
+                break
+
+        if first_linear is not None:
+            expected_dim = first_linear.in_features
+            actual_dim = ref_tensor.shape[-1]
+            if expected_dim != actual_dim:
+                raise ValueError(
+                    f"Dimension mismatch: Dense projector expects "
+                    f"input dim {expected_dim}, but pooled output "
+                    f"has dim {actual_dim}")
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
+        # Step 1: Apply ST projector (e.g., 1024 → 1792)
+        if self.projector is not None:
+            projector = cast(nn.Module, self.projector)
+            ref = pooled_data[0] if isinstance(pooled_data,
+                                               list) else pooled_data
+
+            # Ensure projector is on correct device with float32 dtype
+            self._ensure_projector_device_and_dtype(ref)
+
+            # Check dimension compatibility on first run
+            if not self._dim_checked:
+                self._validate_projector_dimensions(ref)
+                self._dim_checked = True
+
+            # Apply projection with fp32 computation for stability
+            def _proj(x: torch.Tensor) -> torch.Tensor:
+                orig_dtype = x.dtype
+                y = projector(x.to(torch.float32))
+                return y.to(orig_dtype)
+
+            if isinstance(pooled_data, torch.Tensor):
+                pooled_data = _proj(pooled_data)
+            else:
+                pooled_data = [_proj(t) for t in pooled_data]
+
         pooling_params = get_pooling_params(pooling_metadata)
 
-        # for matryoshka representation
+        # Step 2: Handle Matryoshka dimension truncation if specified
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
         ]
         if any(d is not None for d in dimensions_list):
-            # change the output dimension
             assert len(pooled_data) == len(dimensions_list)
             if len(set(dimensions_list)) == 1 and not isinstance(
                     pooled_data, list):
-                # if all dimensions are the same
+                # All dimensions are the same
                 d = dimensions_list[0]
                 pooled_data = pooled_data[..., :d]
             else:
@@ -480,7 +573,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 for vecs, d in zip(pooled_data, dimensions_list)
             ]
 
-        # for normalize
+        # Step 3: Apply normalization
         flags = [p.normalize for p in pooling_params]
         if len(set(flags)) == 1:
             if flags[0]:
@@ -530,10 +623,11 @@ class SimplePooler(Pooler):
     def from_config(
         cls,
         pooler_config: ResolvedPoolingConfig,
+        projector: Optional[nn.Module] = None,
    ) -> "SimplePooler":
         pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
         if pooler_config.task == "embed":
-            head = EmbeddingPoolerHead()
+            head = EmbeddingPoolerHead(projector=projector)
         elif pooler_config.task == "encode":
             head = RewardPoolerHead()
         else:
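With the pooler changes above, the embedding head runs pool -> optional Dense
projection (computed in float32) -> optional Matryoshka truncation ->
normalization. A standalone sketch of that numeric flow, using the 1024 -> 1792
shapes of the Conan model as illustrative values and a randomly initialized
projector rather than real weights:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    pooled = torch.randn(2, 1024, dtype=torch.float16)  # pooled hidden states
    projector = nn.Sequential(nn.Linear(1024, 1792), nn.Identity()).float()

    # fp32 round trip, as in EmbeddingPoolerHead.forward
    projected = projector(pooled.to(torch.float32)).to(pooled.dtype)

    truncated = projected[..., :256]  # hypothetical Matryoshka dimension
    embedding = F.normalize(truncated.float(), p=2.0, dim=-1)
    assert embedding.shape == (2, 256)
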
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 1dbe70f84a62..1ec6cce59e06 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import logging
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
 
@@ -11,6 +12,8 @@
 
 from .interfaces_base import VllmModelForPooling, is_pooling_model
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
@@ -24,6 +27,162 @@
 ]
 
 
+def _load_weights_from_state_dict(state_dict: dict, linear: nn.Linear,
+                                  use_bias: bool) -> bool:
+    """Load weights from a state dict into a linear layer."""
+    # Try common weight key names
+    weight = None
+    bias = None
+
+    for weight_key in ["linear.weight", "dense.weight", "weight"]:
+        if weight_key in state_dict:
+            weight = state_dict[weight_key]
+            break
+
+    for bias_key in ["linear.bias", "dense.bias", "bias"]:
+        if bias_key in state_dict:
+            bias = state_dict[bias_key]
+            break
+
+    if weight is None:
+        return False
+
+    try:
+        with torch.no_grad():
+            # Ensure weights are float32 for numerical stability
+            linear.weight.copy_(weight.to(torch.float32))
+            if use_bias and bias is not None and linear.bias is not None:
+                linear.bias.copy_(bias.to(torch.float32))
+        return True
+    except Exception:
+        return False
+
+
+def _load_from_safetensors(folder: str, model_path: str, revision: str,
+                           linear: nn.Linear, use_bias: bool) -> bool:
+    """Try loading weights from safetensors file."""
+    from vllm.transformers_utils.config import get_hf_file_bytes
+
+    try:
+        b = get_hf_file_bytes(f"{folder}/model.safetensors", model_path,
+                              revision)
+        if b is not None:
+            import io
+
+            from safetensors.torch import load as st_load
+            sd = st_load(io.BytesIO(b))
+            return _load_weights_from_state_dict(sd, linear, use_bias)
+    except Exception:
+        pass
+    return False
+
+
+def _load_from_pytorch_bin(folder: str, model_path: str, revision: str,
+                           linear: nn.Linear, use_bias: bool) -> bool:
+    """Try loading weights from pytorch_model.bin file."""
+    from vllm.transformers_utils.config import get_hf_file_bytes
+
+    try:
+        b = get_hf_file_bytes(f"{folder}/pytorch_model.bin", model_path,
+                              revision)
+        if b is not None:
+            import io
+            sd = torch.load(io.BytesIO(b), map_location="cpu")
+            return _load_weights_from_state_dict(sd, linear, use_bias)
+    except Exception:
+        pass
+    return False
+
+
+def st_activation(name: Optional[str]) -> nn.Module:
+    """Map a Sentence-Transformers activation_function name to a module.
+
+    ST configs store the full dotted class path (e.g.
+    "torch.nn.modules.activation.Tanh"), so match on the last component.
+    """
+    m = (name or "").rsplit(".", 1)[-1].lower()
+    if m == "gelu":
+        return nn.GELU()
+    if m == "gelu_new":
+        return nn.GELU(approximate="tanh")
+    if m == "relu":
+        return nn.ReLU()
+    if m == "tanh":
+        return nn.Tanh()
+    if m == "sigmoid":
+        return nn.Sigmoid()
+    if m == "swish":
+        return nn.SiLU()
+    return nn.Identity()
+
+
+def _load_st_projector(vllm_config: "VllmConfig") -> Optional[nn.Module]:
+    """Load Sentence-Transformers Dense projection layers."""
+    from vllm.transformers_utils.config import get_hf_file_to_dict
+
+    model_path = vllm_config.model_config.model
+    revision = vllm_config.model_config.revision
+
+    # Read modules.json
+    modules = get_hf_file_to_dict("modules.json", model_path, revision)
+
+    # Handle dict format (some ST variants)
+    if isinstance(modules, dict):
+        modules = modules.get("modules", [])
+    if not isinstance(modules, list):
+        return None
+
+    # Filter Dense modules
+    dense_entries = [
+        m for m in modules
+        if m.get("type") == "sentence_transformers.models.Dense"
+    ]
+    if not dense_entries:
+        return None
+
+    # Build projection layer sequence
+    layers = []
+    for entry in dense_entries:
+        folder = entry.get("path")
+        if not folder:
+            continue
+
+        # Read config
+        cfg = get_hf_file_to_dict(f"{folder}/config.json", model_path,
+                                  revision)
+        if not cfg:
+            continue
+
+        in_features = cfg.get("in_features")
+        out_features = cfg.get("out_features")
+        if in_features is None or out_features is None:
+            continue
+
+        use_bias = cfg.get("bias", True)
+        activation = st_activation(cfg.get("activation_function"))
+
+        # Create the linear layer (its weights are loaded and cast to
+        # float32 below for numerical stability)
+        linear = nn.Linear(in_features, out_features, bias=use_bias)
+
+        # Try loading weights from safetensors first, then pytorch_model.bin
+        weight_loaded = _load_from_safetensors(folder, model_path, revision,
+                                               linear, use_bias)
+        if not weight_loaded:
+            weight_loaded = _load_from_pytorch_bin(folder, model_path,
+                                                   revision, linear, use_bias)
+        if not weight_loaded:
+            logger.warning(
+                "Could not load weights for ST Dense layer %s; "
+                "using random initialization", folder)
+
+        layers.append(linear)
+        layers.append(activation)
+
+    if not layers:
+        return None
+
+    # Ensure the entire module uses float32
+    projector = nn.Sequential(*layers)
+    projector = projector.to(dtype=torch.float32)
+    return projector
+
+
 def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     model_name = orig_model_name
 
@@ -123,11 +282,15 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.pooler = DispatchPooler(
-            {
-                "encode": Pooler.for_encode(pooler_config),
-                "embed": Pooler.for_embed(pooler_config),
-            }, )
+        # Load ST projector for embed task only
+        projector = _load_st_projector(vllm_config)
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(pooler_config, projector=projector),
+        })
 
     ModelForEmbedding.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForEmbedding")
@@ -182,8 +345,8 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
         assert pooler_config is not None
 
         pooling_type_str = pooler_config.pooling_type
-        assert pooling_type_str is not None
-        pooling_type = PoolingType[pooling_type_str]
+        pooling_type = (PoolingType.LAST if pooling_type_str is None else
+                        PoolingType[pooling_type_str])
 
         self.pooler = DispatchPooler({
             "encode":
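For reference, _load_st_projector assumes the standard Sentence-Transformers
repo layout. A sketch of the two files it reads for a model with a single
Dense module (paths and values are illustrative, using the Conan-embedding-v1
dimensions; note that sentence-transformers writes activation_function as a
dotted class path):

    modules.json:
    [
        {"idx": 0, "name": "0", "path": "",
         "type": "sentence_transformers.models.Transformer"},
        {"idx": 1, "name": "1", "path": "1_Pooling",
         "type": "sentence_transformers.models.Pooling"},
        {"idx": 2, "name": "2", "path": "2_Dense",
         "type": "sentence_transformers.models.Dense"}
    ]

    2_Dense/config.json:
    {"in_features": 1024, "out_features": 1792, "bias": true,
     "activation_function": "torch.nn.modules.linear.Identity"}

The Dense weights themselves come from 2_Dense/model.safetensors (preferred)
or 2_Dense/pytorch_model.bin, fetched via get_hf_file_bytes below.
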
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 6638f06f9826..1d2f1a1c8bc7 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -28,8 +28,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from .interfaces import (SupportsCrossEncoding, SupportsQuant,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
@@ -61,13 +60,21 @@ def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        input_shape = input_ids.size()
 
-        token_type_ids = _decode_token_type_ids(input_ids)
-
+        # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)
+
+        # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
 
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
+                                         device=inputs_embeds.device)
+
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -328,7 +335,6 @@ def forward(self, hidden_states: torch.Tensor,
 
 
 @support_torch_compile
-@default_pooling_type("CLS")
 class BertModel(nn.Module, SupportsQuant):
 
     is_pooling_model = True
@@ -344,23 +350,25 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.config = vllm_config.model_config.hf_config
-        self.embeddings = embedding_class(self.config)
+        config = vllm_config.model_config.hf_config
+        self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
 
     def forward(
         self,
         input_ids: torch.Tensor,
-        positions: torch.Tensor,
+        position_ids: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
             hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+                                            position_ids=position_ids,
+                                            token_type_ids=token_type_ids)
         return self.encoder(hidden_states)
 
     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -403,7 +411,6 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-@default_pooling_type("ALL")
 class BertPoolingModel(BertModel):
 
     is_pooling_model = True
@@ -434,7 +441,6 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-@default_pooling_type("CLS")
 class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
@@ -456,17 +462,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
-        self.pooler = self._build_pooler(pooler_config)
+        self.pooler = self._build_pooler(pooler_config, vllm_config)
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
-                          positions=positions,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
 
@@ -488,61 +496,25 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=BertEmbedding)
 
-    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+    def _build_pooler(self, pooler_config: PoolerConfig,
+                      vllm_config: VllmConfig) -> Pooler:
+        from .adapters import _load_st_projector
+        projector = _load_st_projector(vllm_config)
+
         return DispatchPooler({
-            "encode": Pooler.for_encode(pooler_config),
-            "embed": Pooler.for_embed(pooler_config),
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(
+                pooler_config,
+                default_pooling_type=PoolingType.CLS,
+                projector=projector,
+            ),
         })
 
 
-# Here we encode the token type ids together with the input ids.
-# Since we use int 32 for the input IDs and the vocabulary size
-# is way lower than 2**31, there is room to encode additional
-# bits. At the same time, for cross-encoder use cases, the
-# token type ids are only 0 or 1, requiring only 1 bit.
-# This means that we can store the token type ids in the 31st
-# bit. We avoid the 32nd bit because that would produce a negative
-# number, which could be used to signal other things.
-#
-# The reason for all of this is that all the tensors that are
-# passed as input to the forward function of a module marked
-# with @support_torch_compile have to be persistent. So to
-# avoid adding more persistent tensors in the model runner, we
-# encode more information in the same persistent tensor.
-#
-# Since the *ForClassification module is outside of the BertModel
-# which is compiled, we can do the encoding here and then separate
-# the information again in the Embedding layer. Since with bit masks
-# we can do this entirely with torch operations and without branching,
-# it works with torch compile.
-
-TOKEN_TYPE_SHIFT = 30
-
-
-def _encode_token_type_ids(input_ids: torch.Tensor,
-                           token_type_ids: torch.Tensor) -> None:
-    # input_ids can be padded to the right
-    input_ids[:token_type_ids.shape[0]].bitwise_or_(
-        token_type_ids << TOKEN_TYPE_SHIFT)
-
-
-def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
-
-    ids_mask = torch.ones(input_ids.shape,
-                          dtype=torch.int32,
-                          device=input_ids.device) << TOKEN_TYPE_SHIFT
-    tokens_mask = ids_mask.bitwise_not()
-
-    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
-
-    input_ids.bitwise_and_(tokens_mask)
-
-    return token_type_ids
-
-
-@default_pooling_type("CLS")
-class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                    SupportsQuant):
+class BertForSequenceClassification(nn.Module, SupportsV0Only,
+                                    SupportsCrossEncoding, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -600,13 +572,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if token_type_ids is not None:
-            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
-            assert input_ids is not None
-            _encode_token_type_ids(input_ids, token_type_ids)
-
         return self.bert(input_ids=input_ids,
-                         positions=positions,
+                         position_ids=positions,
                          inputs_embeds=inputs_embeds,
-                         intermediate_tensors=intermediate_tensors)
+                         intermediate_tensors=intermediate_tensors,
+                         token_type_ids=token_type_ids)
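For context on the bert.py removal: the deleted helpers packed token type ids
into bit 30 of the int32 input ids so that only one persistent tensor had to
cross the @support_torch_compile boundary. A standalone round trip of that
scheme (a sketch, not vLLM code):

    import torch

    TOKEN_TYPE_SHIFT = 30
    input_ids = torch.tensor([101, 2023, 102], dtype=torch.int32)
    token_type_ids = torch.tensor([0, 1, 1], dtype=torch.int32)

    packed = input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)
    decoded_types = (packed >> TOKEN_TYPE_SHIFT) & 1
    restored_ids = packed & ~(1 << TOKEN_TYPE_SHIFT)

    assert torch.equal(decoded_types, token_type_ids)
    assert torch.equal(restored_ids, input_ids)

The diff instead passes token_type_ids to BertModel as an explicit forward
argument, which is presumably why BertForSequenceClassification is now marked
SupportsV0Only.
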
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 02ea0814ddef..b665a11f6e7d 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -925,3 +925,30 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
                     exc_info=e)
 
     return max_position_embeddings
+
+
+def get_hf_file_bytes(file_name: str,
+                      model: Union[str, Path],
+                      revision: Optional[str] = 'main') -> Optional[bytes]:
+    """Get the contents of a file from a local or HF Hub model repo,
+    returning None if the file cannot be found or read."""
+    file_path = try_get_local_file(model=model,
+                                   file_name=file_name,
+                                   revision=revision)
+
+    if file_path is None:
+        try:
+            hf_hub_file = hf_hub_download(model,
+                                          file_name,
+                                          revision=revision,
+                                          token=_get_hf_token())
+            file_path = Path(hf_hub_file)
+        except Exception:
+            return None
+
+    if file_path is not None and file_path.is_file():
+        try:
+            with open(file_path, 'rb') as file:
+                return file.read()
+        except Exception:
+            return None
+
+    return None
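Taken together, these changes let a Sentence-Transformers checkpoint with a
Dense head serve its full projected dimension. A hypothetical smoke test
against the model used in the tests above (assuming a vLLM build where LLM
accepts the pooling runner argument):

    from vllm import LLM

    llm = LLM(model="TencentBAC/Conan-embedding-v1", runner="pooling",
              trust_remote_code=True)
    (output, ) = llm.embed(["This is a test sentence."])
    # 1792 = projected dim from the Dense head, not the encoder's 1024
    assert len(output.outputs.embedding) == 1792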