3
3
from typing import TYPE_CHECKING
4
4
5
5
from vllm .logger import init_logger
6
- from vllm .multimodal import MULTIMODAL_REGISTRY
6
+ from vllm .multimodal import MultiModalRegistry
7
7
from vllm .v1 .request import Request
8
8
9
9
if TYPE_CHECKING :
@@ -67,13 +67,15 @@ def get_freed_ids(self) -> list[tuple[str, int]]:
67
67
def compute_encoder_budget (
68
68
model_config : "ModelConfig" ,
69
69
scheduler_config : "SchedulerConfig" ,
70
+ mm_registry : MultiModalRegistry ,
70
71
) -> tuple [int , int ]:
71
72
"""Compute the encoder cache budget based on the model and scheduler
72
73
configurations.
73
74
74
75
Args:
75
76
model_config: Model configuration.
76
77
scheduler_config: Scheduler configuration.
78
+ mm_registry: Provides information about the token cost.
77
79
78
80
Returns:
79
81
- Compute budget for encoder execution, in unit of number of tokens
@@ -89,21 +91,27 @@ def compute_encoder_budget(
89
91
(
90
92
encoder_compute_budget ,
91
93
encoder_cache_size ,
92
- ) = _compute_encoder_budget_multimodal (model_config , scheduler_config )
94
+ ) = _compute_encoder_budget_multimodal (
95
+ model_config ,
96
+ scheduler_config ,
97
+ mm_registry ,
98
+ )
93
99
94
100
return encoder_compute_budget , encoder_cache_size
95
101
96
102
97
103
def _compute_encoder_budget_multimodal (
98
104
model_config : "ModelConfig" ,
99
105
scheduler_config : "SchedulerConfig" ,
106
+ mm_registry : MultiModalRegistry ,
100
107
) -> tuple [int , int ]:
101
108
"""Compute the encoder cache budget based on the model and scheduler
102
109
configurations for a multimodal model.
103
110
104
111
Args:
105
112
model_config: Model configuration.
106
113
scheduler_config: Scheduler configuration.
114
+ mm_registry: Provides information about the token cost.
107
115
108
116
Returns:
109
117
- Compute budget for encoder execution, in unit of number of tokens
@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
112
120
in the input sequence.
113
121
"""
114
122
115
- max_tokens_by_modality_dict = MULTIMODAL_REGISTRY . get_max_tokens_per_item_by_nonzero_modality ( # noqa: E501
116
- model_config )
123
+ max_tokens_by_modality_dict = mm_registry \
124
+ . get_max_tokens_per_item_by_nonzero_modality ( model_config )
117
125
118
126
if not max_tokens_by_modality_dict :
119
127
logger .warning (
0 commit comments