feat(model): add embed_sparse task for BGE-M3 server-side sparse aggregation #35001
joeqzzuo wants to merge 3 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run full CI to test the changes comprehensively before merging, for example by adding a `ready` label to the PR.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
The pull request introduces a new `embed_sparse` pooling task for the `BgeM3EmbeddingModel`, enabling server-side sparse vector aggregation. This is a significant improvement, streamlining the process of generating vocabulary-sized sparse vectors directly usable by vector databases. The changes include adding a `BgeM3SparsePooler` class, updating the `BgeM3EmbeddingModel` to incorporate this new pooler, and modifying `vllm/tasks.py` to include `embed_sparse` as a `PoolingTask`. New integration tests have also been added to verify the functionality and correctness of the `embed_sparse` task. The code is generally well-structured and follows existing patterns within the codebase. I've identified a few areas for improvement related to type hinting and potential clarity in the `BgeM3SparsePooler` initialization.
```diff
 import itertools
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
```
The `Set` import is used in `BgeM3SparsePooler`'s `get_supported_tasks` method. It's good practice to import `Set` directly from `typing` rather than `collections.abc` for better compatibility and type checking across Python versions, especially when dealing with type hints.
Suggested change:

```diff
-from collections.abc import Iterable, Set
+from collections.abc import Iterable
+from typing import Set
```
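For context, here is a minimal sketch of how an abstract `Set` type can annotate a supported-tasks method (the class and task names are hypothetical, not the PR's actual code; note that on Python 3.9+ `collections.abc.Set` is also valid in annotations):

```python
from collections.abc import Set


class SparsePoolerSketch:
    """Hypothetical stand-in for a pooler exposing its supported tasks."""

    def get_supported_tasks(self) -> Set[str]:
        # frozenset satisfies the abstract, read-only Set interface
        return frozenset({"embed_sparse"})


assert "embed_sparse" in SparsePoolerSketch().get_supported_tasks()
```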
```python
        self,
        sparse_linear: nn.Module,
        vocab_size: int,
        special_token_ids: list[int],
```
The `special_token_ids` parameter is typed as `list[int]`. While this is technically correct, `Set[int]` might be more appropriate here, given that the order of special tokens doesn't matter and checking for membership (`in self.special_token_ids`) would be more efficient with a set. This also clearly communicates the intent that `special_token_ids` is a collection of unique IDs.
Suggested change:

```diff
-        special_token_ids: list[int],
+        special_token_ids: Set[int],
```
```python
        if self.special_token_ids:
            sparse_vec[self.special_token_ids] = 0.0
```
The `special_token_ids` attribute is initialized as a list but used in contexts where set operations would be more efficient (e.g. the truthiness check `if self.special_token_ids` and membership lookups). Converting it to a set during initialization would improve lookup performance, especially if the list of special tokens grows large.
Suggested change:

```diff
         if self.special_token_ids:
-            sparse_vec[self.special_token_ids] = 0.0
+            sparse_vec[list(self.special_token_ids)] = 0.0
```
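A small sketch of the pattern under discussion (illustrative vocab size and token ids, not the PR's values): a plain list stands in for the vocab-sized score tensor, and a set of special-token ids must be materialized as a sequence before it can drive indexed assignment:

```python
vocab_size = 10
special_token_ids = {0, 2}  # assumed BOS/EOS slots, for illustration only

# Stand-in for the vocab-sized sparse score vector
sparse_vec = [0.5] * vocab_size

# Tensor-style fancy indexing needs a sequence, so convert the set first
for tid in list(special_token_ids):
    sparse_vec[tid] = 0.0

assert sparse_vec[0] == 0.0 and sparse_vec[2] == 0.0
assert sparse_vec[1] == 0.5  # non-special slots keep their scores
```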
```python
                self.bos_token_id,
                self.eos_token_id,
                self.pad_token_id,
            ] if tid is not None and tid >= 0
```
The `special_token_ids` list comprehension currently filters out `None` values. However, `getattr(hf_config, "pad_token_id", 1)` ensures `pad_token_id` is always an `int`. If `bos_token_id` or `eos_token_id` can truly be `None`, then the type hint for `special_token_ids` in `BgeM3SparsePooler` should reflect `list[int | None]`, or the filtering logic should be more explicit about handling `None`, if `tid is not None` is intended to handle more than just the `pad_token_id` default. Given that `BgeM3SparsePooler` expects `list[int]`, it's safer to ensure all elements are ints before passing them.
Suggested change:

```diff
-            ] if tid is not None and tid >= 0
+            ] if tid is not None and tid >= 0 and isinstance(tid, int)
+            ]
```
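The guard being suggested can be sketched as follows (token-id values are placeholders, not the model's real ids; note that `isinstance(tid, int)` already excludes `None`, which makes the explicit `None` check redundant):

```python
# Placeholder special-token ids, mimicking hf_config attributes
bos_token_id = 0
eos_token_id = 2
pad_token_id = 1  # e.g. getattr(hf_config, "pad_token_id", 1)

special_token_ids = [
    tid
    for tid in (bos_token_id, eos_token_id, pad_token_id)
    if isinstance(tid, int) and tid >= 0  # drops None and sentinel values
]

assert special_token_ids == [0, 2, 1]
```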
|
Hi @joeqzzuo, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. Once the hooks are installed, they will run automatically on future commits.
Purpose
Add
embed_sparsepooling task forBgeM3EmbeddingModelto enable server-side sparse vector aggregation.Currently, BGE-M3 sparse retrieval requires a cumbersome 2-step client workflow:
/tokenizeto get token IDs/poolingwithtask=token_classifyto get per-position scores, then manually aggregate viascatter_reduceon the client sideThis PR adds a
BgeM3SparsePoolerthat performsscatter_reduce(index=input_ids, reduce="amax")aggregation server-side, producing vocabulary-sized sparse vectors directly usable by vector databases (Qdrant, Milvus, Vespa, etc.) — in a single API call.This follows the same pattern as
SPLADESparsePoolerinbert.py, adapted for BGE-M3's architecture wheresparse_linearmaps hidden states to a single scalar per position (rather than SPLADE's MLM head which maps to vocab-sized logits).Related issues: #13609 #15384 #18469
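The server-side aggregation described above can be sketched in plain Python (a minimal stand-in for `torch.scatter_reduce` with `reduce="amax"`; the function name and sample values are illustrative, not taken from the PR):

```python
def aggregate_sparse(input_ids, token_scores, vocab_size, special_token_ids=()):
    """Max-reduce per-position scores into a vocab-sized sparse vector."""
    sparse_vec = [0.0] * vocab_size
    for tid, score in zip(input_ids, token_scores):
        # Equivalent effect to scatter_reduce(index=input_ids, reduce="amax")
        sparse_vec[tid] = max(sparse_vec[tid], score)
    for tid in special_token_ids:
        sparse_vec[tid] = 0.0  # zero out BOS/EOS/PAD slots
    return sparse_vec


# Token id 5 appears twice; its slot keeps the larger of its two scores.
vec = aggregate_sparse(
    input_ids=[0, 5, 7, 5, 2],
    token_scores=[0.1, 0.3, 0.9, 0.6, 0.2],
    vocab_size=10,
    special_token_ids=[0, 2],
)
assert vec[5] == 0.6 and vec[7] == 0.9
assert vec[0] == 0.0 and vec[2] == 0.0  # special tokens suppressed
```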
Test Plan
`pytest tests/models/language/pooling/test_bge_m3.py -v -k "embed_sparse"`

Three new integration tests:

- `test_bge_m3_embed_sparse_matches_token_classify`: verifies `embed_sparse` output matches `token_classify` plus client-side aggregation
- `test_bge_m3_embed_sparse_lexical_scores`: verifies lexical matching scores match reference values
- `test_bge_m3_embed_sparse_corner_case`: verifies a short input ("Hi") produces correct sparse output

Test Result
Unit tests (standalone, validating core `scatter_reduce` logic):

Essential Elements of an Effective PR Description Checklist

- Update `supported_models.md` and `examples` for a new model.