Skip to content

Commit 7d52e4e

Browse files
Merge pull request #1101 from terrastackai/cp_fixes_vllm_plugins
Various fixes to vLLM plugins
2 parents 56acae5 + 907a19a commit 7d52e4e

File tree

5 files changed

+52
-13
lines changed

5 files changed

+52
-13
lines changed

integrationtests/vLLM/test_segmentation_io_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_serving_segmentation_plugin(get_server, model_name, input_name):
7777
# This is just in case the test ends up with a GPU of less memory than an A100-80GB.
7878
# Just to avoid OOMing in the CI
7979
"--max-num-seqs",
80-
"8",
80+
"32",
8181
"--io-processor-plugin",
8282
io_processor_plugin,
8383
"--model-impl",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ geobenchv2 = [
141141

142142
vllm = [
143143
"geobenchv2==0.9",
144-
"vllm>=0.12,<=0.14.0",
144+
"vllm>=0.12,!=0.15.*",
145145
]
146146

147147
vllm_test = [

terratorch/vllm/plugins/segmentation/segmentation_io_processor.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from einops import rearrange
2020
import logging
2121
from terratorch.vllm.plugins import generate_datamodule
22+
from terratorch.vllm.utils import check_vllm_version
2223
import uuid
2324
import warnings
2425
from vllm.config import VllmConfig
@@ -325,7 +326,7 @@ def pre_process(
325326
# Just run the async function from a synchronous context.
326327
# Since we are already in the vLLM server event loop we use that one.
327328
loop = asyncio.get_event_loop()
328-
loop.run_until_complete(self.pre_process_async(prompt, request_id, **kwargs))
329+
return loop.run_until_complete(self.pre_process_async(prompt, request_id, **kwargs))
329330

330331

331332
async def pre_process_async(
@@ -414,16 +415,23 @@ async def pre_process_async(
414415
window["image"] = window["image"][None, :, :, :]
415416
window = self.datamodule.aug(window)["image"]
416417

417-
prompt = {
418-
"prompt_token_ids": [1],
419-
"multi_modal_data": {
420-
"pixel_values": window.to(torch.float16)[0],
421-
}
418+
multi_modal_data = {
419+
"pixel_values": window.to(torch.float16)[0],
422420
}
423-
424421
# not all models use location coordinates, so we don't bother sending them to vLLM if not needed
425422
if "location_coords" in self.model_config["input"]["data"]:
426-
prompt["multi_modal_data"]["location_coords"] = location_coords
423+
multi_modal_data["location_coords"] = location_coords
424+
425+
# after v0.14.0 vLLM has changed the input structure for multimodal data
426+
if check_vllm_version("0.14.0", ">"):
427+
multi_modal_data = {
428+
"image": multi_modal_data
429+
}
430+
431+
prompt = {
432+
"prompt_token_ids": [1],
433+
"multi_modal_data": multi_modal_data
434+
}
427435

428436
prompts.append(prompt)
429437

terratorch/vllm/plugins/segmentation/terramind_segmentation_io_processor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from terratorch.tasks.tiled_inference import generate_tiled_inference_output, prepare_tiled_inference_input
2323
from terratorch.vllm.plugins import generate_datamodule
2424
from terratorch.cli_tools import write_tiff
25+
from terratorch.vllm.utils import check_vllm_version
2526
from .utils import download_file_async, get_filename_from_url, path_or_tmpdir, to_base64_tiff
2627

2728
from .types import PluginConfig, RequestData, RequestOutput, TiledInferenceParameters
@@ -146,7 +147,7 @@ def pre_process(
146147
# Just run the async function from a synchronous context.
147148
# Since we are already in the vLLM server event loop we use that one.
148149
loop = asyncio.get_event_loop()
149-
loop.run_until_complete(self.pre_process_async(prompt, request_id, **kwargs))
150+
return loop.run_until_complete(self.pre_process_async(prompt, request_id, **kwargs))
150151

151152

152153
async def pre_process_async(
@@ -193,10 +194,17 @@ async def pre_process_async(
193194
for tile in prompt_data:
194195
reshaped_tile = tensor_reshape_fn(tile.input_data)
195196
# TODO: Check if there's a better way of getting the data in the correct data type out of the box.
196-
vllm_input = {mod: tensor.to(torch.float16) for mod, tensor in reshaped_tile.items()}
197+
multi_modal_data = {mod: tensor.to(torch.float16) for mod, tensor in reshaped_tile.items()}
198+
199+
# after v0.14.0 vLLM has changed the input structure for multimodal data
200+
if check_vllm_version("0.14.0", ">"):
201+
multi_modal_data = {
202+
"image": multi_modal_data
203+
}
204+
197205
prompt = {
198206
"prompt_token_ids": [1],
199-
"multi_modal_data": vllm_input
207+
"multi_modal_data": multi_modal_data
200208
}
201209

202210
prompts.append(prompt)

terratorch/vllm/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@
99
from typing import List,Dict
1010
from enum import Enum
1111

12+
from packaging import version
13+
from vllm import __version__ as vllm_version
14+
15+
16+
def check_vllm_version(target_version: str, comparison: str):
    """Compare the installed vLLM version against a target version string.

    Args:
        target_version: Version to compare against, e.g. ``"0.14.0"``.
        comparison: Comparison operator as a string — one of
            ``"=="``, ``"!="``, ``"<"``, ``"<="``, ``">"``, ``">="``.

    Returns:
        bool: Result of ``installed_vllm_version <comparison> target_version``.

    Raises:
        ValueError: If ``comparison`` is not one of the supported operators.
    """
    # Dispatch table instead of an if/elif ladder; each entry applies the
    # corresponding comparison to two parsed Version objects.
    comparators = {
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<": lambda a, b: a < b,
        "<=": lambda a, b: a <= b,
        ">": lambda a, b: a > b,
        ">=": lambda a, b: a >= b,
    }
    if comparison not in comparators:
        raise ValueError(f"Invalid comparison operator: {comparison}")
    return comparators[comparison](
        version.parse(vllm_version), version.parse(target_version)
    )
34+
1235
class InputTypeEnum(str, Enum):
1336
tensor= 'torch.Tensor'
1437

0 commit comments

Comments
 (0)