|
25 | 25 | from transformers import BatchFeature
|
26 | 26 |
|
27 | 27 | from vllm.config import VllmConfig
|
28 |
| -from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, |
29 |
| - PoolerIdentity, SimplePooler) |
| 28 | +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler |
30 | 29 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
31 | 30 | from vllm.model_executor.models.interfaces import (
|
32 |
| - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) |
| 31 | + IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, |
| 32 | + default_pooling_type) |
33 | 33 | from vllm.model_executor.models.utils import AutoWeightsLoader
|
34 | 34 | from vllm.multimodal import MULTIMODAL_REGISTRY
|
35 | 35 | from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
@@ -142,6 +142,7 @@ def apply(
|
142 | 142 | )
|
143 | 143 |
|
144 | 144 |
|
| 145 | +@default_pooling_type("All") |
145 | 146 | @MULTIMODAL_REGISTRY.register_processor(
|
146 | 147 | PrithviGeoSpatialMAEMultiModalProcessor,
|
147 | 148 | info=PrithviGeoSpatialMAEProcessingInfo,
|
@@ -198,7 +199,11 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
198 | 199 | "Only SemanticSegmentationTask is supported for now "
|
199 | 200 | "by PrithviGeospatialMAE.")
|
200 | 201 |
|
201 |
| - self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) |
| 202 | + pooler_config = vllm_config.model_config.pooler_config |
| 203 | + assert pooler_config is not None |
| 204 | + |
| 205 | + self.pooler = DispatchPooler( |
| 206 | + {"encode": Pooler.for_encode(pooler_config)}, ) |
202 | 207 |
|
203 | 208 | def _parse_and_validate_multimodal_data(
|
204 | 209 | self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
0 commit comments