
Commit 1405f0c

[Misc] Factor out common _apply_feature_select_strategy (#26003)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 84d5734

File tree: 4 files changed, +40 -39 lines changed


vllm/model_executor/models/llava.py
Lines changed: 3 additions & 16 deletions

@@ -41,7 +41,7 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
-from .vision import get_vision_encoder_info
+from .vision import get_num_selected_vision_tokens, get_vision_encoder_info


 class LlavaImagePixelInputs(TensorSchema):
@@ -147,19 +147,6 @@ def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -169,12 +156,12 @@ def get_num_image_tokens(
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()

-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        return get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )

     def get_image_size_with_most_features(self) -> ImageSize:
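
For a concrete sense of the arithmetic, here is a minimal sketch assuming a CLIP-style encoder at 336x336 with 14x14 patches (typical for LLaVA checkpoints, but an assumption here, not something this diff states). The encoder emits (336/14)^2 = 576 patch tokens plus one class token, and the "default" strategy drops the class token:

    from vllm.model_executor.models.vision import get_num_selected_vision_tokens

    # Assumed encoder geometry: 336x336 input, 14x14 patches, plus a class token.
    encoder_tokens = (336 // 14) ** 2 + 1  # 577
    assert get_num_selected_vision_tokens(encoder_tokens, "default") == 576
    assert get_num_selected_vision_tokens(encoder_tokens, "full") == 577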

vllm/model_executor/models/llava_next.py
Lines changed: 3 additions & 2 deletions

@@ -27,6 +27,7 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
+from .vision import get_num_selected_vision_tokens


 class LlavaNextImagePixelInputs(TensorSchema):
@@ -95,12 +96,12 @@ def get_num_image_tokens(
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()

-        base_feature_size = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        base_feature_size = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )

         num_patch_height, num_patch_width = get_anyres_image_grid_shape(

vllm/model_executor/models/tarsier.py
Lines changed: 6 additions & 17 deletions

@@ -40,7 +40,8 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix)
-from .vision import VisionEncoderInfo, get_vision_encoder_info
+from .vision import (VisionEncoderInfo, get_num_selected_vision_tokens,
+                     get_vision_encoder_info)


 class TarsierImagePixelInputs(TensorSchema):
@@ -201,18 +202,6 @@ def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -221,21 +210,21 @@ def get_num_image_tokens(
     ) -> int:
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
-        num_projected_patches = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        num_projected_patches = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
         if num_projected_patches <= 0:
             default_size = self.get_image_size_with_most_features()
-            num_projected_patches_default = self._apply_feature_select_strategy(
-                hf_config.vision_feature_select_strategy,
+            num_projected_patches_default = get_num_selected_vision_tokens(
                 vision_encoder_info.get_num_image_tokens(
                     image_width=default_size.width,
                     image_height=default_size.height,
                 ),
+                hf_config.vision_feature_select_strategy,
             )
             if num_projected_patches_default <= 0:
                 raise ValueError(
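
The <= 0 guard above exists because the selected token count can collapse to zero; a minimal sketch of that arithmetic, using only the helper added in vision.py below (the token counts are illustrative):

    from vllm.model_executor.models.vision import get_num_selected_vision_tokens

    # "default" drops one token, so a degenerate 1-token encoder output
    # selects 0 tokens; that is what sends Tarsier down the fallback path
    # that recomputes against get_image_size_with_most_features().
    assert get_num_selected_vision_tokens(1, "default") == 0
    assert get_num_selected_vision_tokens(577, "default") == 576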

vllm/model_executor/models/vision.py
Lines changed: 28 additions & 4 deletions

@@ -9,7 +9,6 @@

 import torch
 from transformers import PretrainedConfig
-from typing_extensions import assert_never

 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -22,9 +21,13 @@
 _C = TypeVar("_C", bound=PretrainedConfig)


+class _RootConfig(Protocol[_C]):
+    vision_config: _C
+
+
 class VisionEncoderInfo(ABC, Generic[_C]):

-    def __init__(self, hf_config: _C) -> None:
+    def __init__(self, hf_config: _RootConfig[_C]) -> None:
         super().__init__()

         self.hf_config = hf_config
@@ -95,7 +98,7 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:


 def _get_vision_feature_selector(
-    strategy: VisionFeatureSelectStrategy,
+    strategy: Union[VisionFeatureSelectStrategy, str],
 ) -> Callable[[torch.Tensor], torch.Tensor]:
     if callable(strategy):
         return strategy
@@ -111,7 +114,28 @@ def _get_vision_feature_selector(
     if strategy == "full":
         return lambda feats: feats

-    assert_never(strategy)
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
+
+
+def get_num_selected_vision_tokens(
+    num_vision_tokens: int,
+    strategy: Union[VisionFeatureSelectStrategy, str],
+) -> int:
+    if callable(strategy):
+        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
+        dummy_selected_features = strategy(dummy_features)
+        return dummy_selected_features.shape[1]
+
+    if strategy == "class":
+        return 1
+
+    if strategy == "default":
+        return num_vision_tokens - 1
+
+    if strategy == "full":
+        return num_vision_tokens
+
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")


 def resolve_visual_encoder_outputs(
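
For reference, a short usage sketch of the new helper (the token count 577 is illustrative, and drop_cls is a hypothetical callable strategy, not part of this diff):

    from vllm.model_executor.models.vision import get_num_selected_vision_tokens

    get_num_selected_vision_tokens(577, "class")    # 1: keep only the class token
    get_num_selected_vision_tokens(577, "default")  # 576: drop the class token
    get_num_selected_vision_tokens(577, "full")     # 577: keep everything

    # A callable strategy is probed with a dummy [1, L, 64] tensor and the
    # selected sequence length is read back from dim 1:
    drop_cls = lambda feats: feats[:, 1:]
    assert get_num_selected_vision_tokens(577, drop_cls) == 576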
