Skip to content

Commit 566e0cc

Browse files
author
jibxie
committed
Add vLLM version check for compatibility
1 parent c85d972 commit 566e0cc

File tree

2 files changed

+41
-23
lines changed

2 files changed

+41
-23
lines changed

src/model.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from vllm.lora.request import LoRARequest
4343
from vllm.sampling_params import SamplingParams
4444
from vllm.utils import random_uuid
45+
from vllm.version import __version__ as _VLLM_VERSION
4546

4647
from utils.metrics import VllmStatLogger
4748

@@ -54,12 +55,6 @@ class TritonPythonModel:
5455
def auto_complete_config(auto_complete_model_config):
5556
inputs = [
5657
{"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
57-
{
58-
"name": "image",
59-
"data_type": "TYPE_STRING",
60-
"dims": [-1], # can be multiple images as separate elements
61-
"optional": True,
62-
},
6358
{
6459
"name": "stream",
6560
"data_type": "TYPE_BOOL",
@@ -79,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
7974
"optional": True,
8075
},
8176
]
77+
if _VLLM_VERSION >= "0.6.3.post1":
78+
inputs.append({
79+
"name": "image",
80+
"data_type": "TYPE_STRING",
81+
"dims": [-1], # can be multiple images as separate elements
82+
"optional": True,
83+
})
84+
8285
outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
8386

8487
# Store the model configuration as a dictionary.
@@ -394,22 +397,23 @@ async def generate(self, request):
394397
if isinstance(prompt, bytes):
395398
prompt = prompt.decode("utf-8")
396399

397-
image_input_tensor = pb_utils.get_input_tensor_by_name(
398-
request, "image"
399-
)
400-
if image_input_tensor:
401-
image_list = []
402-
for image_raw in image_input_tensor.as_numpy():
403-
image_data = base64.b64decode(image_raw.decode("utf-8"))
404-
image = Image.open(BytesIO(image_data)).convert("RGB")
405-
image_list.append(image)
406-
if len(image_list) > 0:
407-
prompt = {
408-
"prompt": prompt,
409-
"multi_modal_data": {
410-
"image": image_list
400+
if _VLLM_VERSION >= "0.6.3.post1":
401+
image_input_tensor = pb_utils.get_input_tensor_by_name(
402+
request, "image"
403+
)
404+
if image_input_tensor:
405+
image_list = []
406+
for image_raw in image_input_tensor.as_numpy():
407+
image_data = base64.b64decode(image_raw.decode("utf-8"))
408+
image = Image.open(BytesIO(image_data)).convert("RGB")
409+
image_list.append(image)
410+
if len(image_list) > 0:
411+
prompt = {
412+
"prompt": prompt,
413+
"multi_modal_data": {
414+
"image": image_list
415+
}
411416
}
412-
}
413417

414418
stream = pb_utils.get_input_tensor_by_name(request, "stream")
415419
if stream:

src/utils/metrics.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
3333
from vllm.engine.metrics import Stats as VllmStats
3434
from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
35-
35+
from vllm.version import __version__ as _VLLM_VERSION
3636

3737
class TritonMetrics:
3838
def __init__(self, labels: List[str], max_model_len: int):
@@ -76,6 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
7676
description="Number of generation tokens processed.",
7777
kind=pb_utils.MetricFamily.HISTOGRAM,
7878
)
79+
# 'best_of' metric has been hidden since vllm 0.6.3
80+
# https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
81+
if _VLLM_VERSION < "0.6.3":
82+
self.histogram_best_of_request_family = pb_utils.MetricFamily(
83+
name="vllm:request_params_best_of",
84+
description="Histogram of the best_of request parameter.",
85+
kind=pb_utils.MetricFamily.HISTOGRAM,
86+
)
7987
self.histogram_n_request_family = pb_utils.MetricFamily(
8088
name="vllm:request_params_n",
8189
description="Histogram of the n request parameter.",
@@ -154,6 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
154162
buckets=build_1_2_5_buckets(max_model_len),
155163
)
156164
)
165+
if _VLLM_VERSION < "0.6.3":
166+
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
167+
labels=labels,
168+
buckets=[1, 2, 5, 10, 20],
169+
)
157170
self.histogram_n_request = self.histogram_n_request_family.Metric(
158171
labels=labels,
159172
buckets=[1, 2, 5, 10, 20],
@@ -240,7 +253,8 @@ def log(self, stats: VllmStats) -> None:
240253
),
241254
(self.metrics.histogram_n_request, stats.n_requests),
242255
]
243-
256+
if _VLLM_VERSION < "0.6.3":
257+
histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
244258
for metric, data in counter_metrics:
245259
self._log_counter(metric, data)
246260
for metric, data in histogram_metrics:

0 commit comments

Comments (0)