Skip to content

Commit 6b501b2

Browse files
author
Sigrid Jin (Sionic AI)
committed
feat: add vision pooling support for jina embeddings v4
Signed-off-by: Sigrid Jin (Sionic AI) <[email protected]>
1 parent 5114a3c commit 6b501b2

File tree

4 files changed

+17
-102
lines changed

4 files changed

+17
-102
lines changed

examples/offline_inference/embed_jina_embeddings_v4.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
"""
4-
Example of offline inference with Jina Embeddings V4 multimodal model.
5-
6-
This example demonstrates:
7-
1. Text-only embeddings
8-
2. Image-only embeddings
9-
3. Cross-modal embeddings (text-to-image similarity)
10-
11-
The model supports both text and vision inputs through a unified architecture.
12-
"""
13-
143
import torch
154

165
from vllm import LLM

vllm/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3256,7 +3256,7 @@ def get_limit_per_prompt(self, modality: str) -> int:
32563256
@config
32573257
@dataclass
32583258
class PoolerConfig:
3259-
"""Configuration for the pooler."""
3259+
"""Controls the behavior of output pooling in pooling models."""
32603260

32613261
pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
32623262
"vision"]] = None

vllm/model_executor/layers/pooler.py

Lines changed: 15 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -625,56 +625,6 @@ def forward(
625625
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
626626

627627

628-
class VisionPooler(Pooler):
629-
630-
@classmethod
631-
def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
632-
return cls(model_config)
633-
634-
def __init__(self, config: ModelConfig):
635-
super().__init__()
636-
self.config = config
637-
638-
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
639-
if task == "embed":
640-
return PoolingParams(pooling_type="vision",
641-
logits_processing_needs_token_ids=True)
642-
return None
643-
644-
def forward(
645-
self,
646-
hidden_states: torch.Tensor,
647-
pooling_metadata: PoolingMetadata,
648-
) -> PoolerOutput:
649-
assert isinstance(pooling_metadata, V1PoolingMetadata)
650-
651-
pooled_outputs = []
652-
for i in range(len(pooling_metadata.prompt_lens)):
653-
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
654-
hf_config.vision_start_token_id).nonzero()[-1].item()
655-
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
656-
hf_config.vision_end_token_id).nonzero()[-1].item()
657-
658-
seq_start = torch.cumsum(
659-
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
660-
dim=0)[i]
661-
seq_len = pooling_metadata.prompt_lens[i]
662-
663-
output = torch.empty(self.config.hidden_size,
664-
device=hidden_states.device,
665-
dtype=hidden_states.dtype)
666-
667-
grid = lambda meta: (self.config.hidden_size, )
668-
mean_pool_with_position_kernel[grid](hidden_states, output,
669-
seq_start, seq_len,
670-
self.config.hidden_size,
671-
start_pos, end_pos + 1)
672-
673-
pooled_outputs.append(output)
674-
675-
return build_output(torch.stack(pooled_outputs))
676-
677-
678628
if HAS_TRITON:
679629

680630
@triton.jit
@@ -688,7 +638,6 @@ def mean_pool_with_position_kernel(
688638
pool_end,
689639
BLOCK_SIZE: tl.constexpr,
690640
):
691-
"""Triton kernel to perform mean pooling over a specified token range."""
692641
pid = tl.program_id(0)
693642

694643
if pid >= hidden_size:
@@ -817,10 +766,12 @@ def forward(
817766

818767
pooled_outputs = []
819768
for i in range(len(pooling_metadata.prompt_lens)):
820-
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
821-
hf_config.vision_start_token_id).nonzero()[-1].item()
822-
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
823-
hf_config.vision_end_token_id).nonzero()[-1].item()
769+
start_pos = (pooling_metadata.prompt_token_ids[i] ==
770+
             self.config.hf_config.vision_start_token_id
771+
             ).nonzero()[-1].item()
772+
end_pos = (pooling_metadata.prompt_token_ids[i] ==
773+
           self.config.hf_config.vision_end_token_id
774+
           ).nonzero()[-1].item()
824775

825776
seq_start = torch.cumsum(
826777
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
@@ -832,41 +783,18 @@ def forward(
832783
dtype=hidden_states.dtype)
833784

834785
grid = lambda meta: (self.config.hidden_size, )
835-
mean_pool_with_position_kernel[grid](hidden_states, output,
836-
seq_start, seq_len,
837-
self.config.hidden_size,
838-
start_pos, end_pos + 1)
786+
if HAS_TRITON:
787+
mean_pool_with_position_kernel[grid](hidden_states, output,
788+
seq_start, seq_len,
789+
self.config.hidden_size,
790+
start_pos, end_pos + 1)
791+
else:
792+
# Fallback to PyTorch implementation if Triton is not available
793+
vision_tokens_range = hidden_states[seq_start + start_pos : seq_start + end_pos + 1]
794+
output = vision_tokens_range.mean(dim=0)
839795

840796
pooled_outputs.append(output)
841797

842798
return build_output(torch.stack(pooled_outputs))
843799

844800

845-
if HAS_TRITON:
846-
847-
@triton.jit
848-
def mean_pool_with_position_kernel(
849-
hidden_states_ptr,
850-
output_ptr,
851-
seq_start,
852-
seq_len,
853-
hidden_size,
854-
pool_start,
855-
pool_end,
856-
BLOCK_SIZE: tl.constexpr,
857-
):
858-
"""Triton kernel to perform mean pooling over a specified token range."""
859-
pid = tl.program_id(0)
860-
861-
if pid >= hidden_size:
862-
return
863-
864-
accumulator = 0.0
865-
for i in range(pool_start, pool_end):
866-
hidden_val = tl.load(hidden_states_ptr +
867-
(seq_start + i) * hidden_size + pid)
868-
accumulator += hidden_val
869-
870-
# Store mean pooled result
871-
result = accumulator / (pool_end - pool_start)
872-
tl.store(output_ptr + pid, result)

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from vllm.config import VllmConfig
88
from vllm.logger import init_logger
9-
from vllm.model_executor.layers.pooler import Pooler, PoolingTask
9+
from vllm.model_executor.layers.pooler import Pooler, PoolingTask, VisionPooler
1010
# yapf: disable
1111
from vllm.model_executor.pooling_metadata import (
1212
PoolingMetadata as V0PoolingMetadata)
@@ -32,8 +32,6 @@
3232

3333

3434
class JinaVLPooler(Pooler):
35-
"""Vision-aware pooler for Jina V4 with special vision token handling."""
36-
3735
def __init__(self, vllm_config: VllmConfig):
3836
super().__init__()
3937
self.vision_pooler = VisionPooler(vllm_config.model_config)

0 commit comments

Comments (0)