Skip to content

Commit 97234be

Browse files
[Misc] Manage HTTP connections in one place (#6600)
1 parent c051bfe commit 97234be

File tree

7 files changed

+215
-85
lines changed

7 files changed

+215
-85
lines changed

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from vllm import LLM, SamplingParams
1717
from vllm.assets.image import ImageAsset
1818
from vllm.config import TokenizerPoolConfig
19+
from vllm.connections import global_http_connection
1920
from vllm.distributed import (destroy_distributed_environment,
2021
destroy_model_parallel)
2122
from vllm.inputs import TextPrompt
@@ -74,6 +75,13 @@ def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
7475
"""Singleton instance of :class:`_ImageAssets`."""
7576

7677

78+
@pytest.fixture(autouse=True)
79+
def init_test_http_connection():
80+
# pytest_asyncio may use a different event loop per test
81+
# so we need to make sure the async client is created anew
82+
global_http_connection.reuse_client = False
83+
84+
7785
def cleanup():
7886
destroy_model_parallel()
7987
destroy_distributed_environment()

tests/entrypoints/openai/test_vision.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import openai
44
import pytest
5-
import pytest_asyncio
65

7-
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
6+
from vllm.multimodal.utils import encode_image_base64, fetch_image
87

98
from ...utils import VLLM_PATH, RemoteOpenAIServer
109

@@ -42,11 +41,10 @@ def client(server):
4241
return server.get_async_client()
4342

4443

45-
@pytest_asyncio.fixture(scope="session")
46-
async def base64_encoded_image() -> Dict[str, str]:
44+
@pytest.fixture(scope="session")
45+
def base64_encoded_image() -> Dict[str, str]:
4746
return {
48-
image_url:
49-
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
47+
image_url: encode_image_base64(fetch_image(image_url))
5048
for image_url in TEST_IMAGE_URLS
5149
}
5250

tests/multimodal/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
from PIL import Image
99

10-
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
10+
from vllm.multimodal.utils import async_fetch_image, fetch_image
1111

1212
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
1313
TEST_IMAGE_URLS = [
@@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
3737
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
3838

3939

40-
@pytest.mark.asyncio(scope="module")
40+
@pytest.mark.asyncio
4141
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
4242
async def test_fetch_image_http(image_url: str):
4343
image_sync = fetch_image(image_url)
44-
image_async = await ImageFetchAiohttp.fetch_image(image_url)
44+
image_async = await async_fetch_image(image_url)
4545
assert _image_equals(image_sync, image_async)
4646

4747

48-
@pytest.mark.asyncio(scope="module")
48+
@pytest.mark.asyncio
4949
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
5050
@pytest.mark.parametrize("suffix", get_supported_suffixes())
5151
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
@@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
7878
else:
7979
pass # Lossy format; only check that image can be opened
8080

81-
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
81+
data_image_async = await async_fetch_image(data_url)
8282
assert _image_equals(data_image_sync, data_image_async)

vllm/assets/image.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import shutil
21
from dataclasses import dataclass
32
from functools import lru_cache
43
from typing import Literal
54

6-
import requests
75
from PIL import Image
86

7+
from vllm.connections import global_http_connection
8+
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
9+
910
from .base import get_cache_dir
1011

1112

@@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image:
2223
if not image_path.exists():
2324
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
2425

25-
with requests.get(f"{base_url}/{filename}", stream=True) as response:
26-
response.raise_for_status()
27-
28-
with image_path.open("wb") as f:
29-
shutil.copyfileobj(response.raw, f)
26+
global_http_connection.download_file(f"{base_url}/{filename}",
27+
image_path,
28+
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
3029

3130
return Image.open(image_path)
3231

vllm/connections.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from pathlib import Path
2+
from typing import Any, Mapping, Optional
3+
from urllib.parse import urlparse
4+
5+
import aiohttp
6+
import requests
7+
8+
from vllm.version import __version__ as VLLM_VERSION
9+
10+
11+
class HTTPConnection:
    """Helper class to send HTTP requests.

    Holds one synchronous ``requests.Session`` and one asynchronous
    ``aiohttp.ClientSession``. Both are created lazily on first use. While
    ``reuse_client`` is true the same client objects are handed out on every
    call; when it is false a fresh client is created on each call.
    """

    def __init__(self, *, reuse_client: bool = True) -> None:
        """Create the helper without opening any connection.

        Args:
            reuse_client: Whether the sync/async clients are cached and
                reused across requests. This is a public attribute and may
                be toggled after construction.
        """
        super().__init__()

        self.reuse_client = reuse_client

        self._sync_client: Optional[requests.Session] = None
        self._async_client: Optional[aiohttp.ClientSession] = None

    def get_sync_client(self) -> requests.Session:
        """Return the synchronous session, creating it when needed."""
        # NOTE(review): when ``reuse_client`` is false the previously stored
        # session is replaced without being closed -- confirm that callers
        # do not rely on this path outside of tests.
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside async event loop
    async def get_async_client(self) -> aiohttp.ClientSession:
        """Return the asynchronous session, creating it when needed."""
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession()

        return self._async_client

    def _validate_http_url(self, url: str):
        """Raise ``ValueError`` unless ``url`` has scheme http or https."""
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> Mapping[str, str]:
        """Build request headers: the vLLM User-Agent plus ``extras``.

        Entries in ``extras`` take precedence over the default User-Agent.
        """
        return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Send a synchronous GET request and return the response.

        Args:
            url: Target URL; must use the http or https scheme.
            stream: Forwarded to ``requests`` as ``stream``.
            timeout: Forwarded to ``requests`` as ``timeout``.
            extra_headers: Additional headers merged over the defaults.

        Returns:
            The response object returned by the session's ``get``. The HTTP
            status is not checked here.

        Raises:
            ValueError: If ``url`` is not an http(s) URL.
        """
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Prepare an asynchronous GET request.

        Args:
            url: Target URL; must use the http or https scheme.
            timeout: Total timeout in seconds, or ``None`` to pass ``None``
                through to aiohttp unchanged.
            extra_headers: Additional headers merged over the defaults.

        Returns:
            The object returned by ``ClientSession.get``, which the methods
            of this class enter with ``async with``.

        Raises:
            ValueError: If ``url`` is not an http(s) URL.
        """
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        # aiohttp documents the per-request timeout as a ClientTimeout
        # object, so wrap the plain number. ``None`` is forwarded as-is so
        # that its meaning stays exactly what it was before.
        request_timeout = (None if timeout is None else
                           aiohttp.ClientTimeout(total=timeout))

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=request_timeout)

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        """GET ``url`` and return the raw response body.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        """Asynchronous counterpart of :meth:`get_bytes`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        """GET ``url`` and return the response body decoded as text.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        """Asynchronous counterpart of :meth:`get_text`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        """GET ``url`` and return the response body parsed as JSON.

        The result is whatever the JSON document decodes to, not
        necessarily a string.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        """Asynchronous counterpart of :meth:`get_json`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Download ``url`` to ``save_path`` and return ``save_path``.

        Args:
            url: Target URL; must use the http or https scheme.
            save_path: Destination file; opened in binary write mode.
            timeout: Forwarded to ``requests`` as ``timeout``.
            chunk_size: Number of bytes requested per ``iter_content`` chunk.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        # stream=True defers reading the body until iter_content() is
        # consumed; without it the whole payload is buffered in memory
        # first and writing it in chunks saves nothing.
        with self.get_response(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Asynchronous counterpart of :meth:`download_file`.

        The body is read with ``iter_chunked(chunk_size)``. The file writes
        themselves are ordinary blocking calls.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path
164+
165+
166+
global_http_connection = HTTPConnection()
167+
"""The global :class:`HTTPConnection` instance used by vLLM."""

vllm/multimodal/utils.py

Lines changed: 22 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,12 @@
11
import base64
22
from io import BytesIO
3-
from typing import Optional, Union
4-
from urllib.parse import urlparse
3+
from typing import Union
54

6-
import aiohttp
7-
import requests
85
from PIL import Image
96

7+
from vllm.connections import global_http_connection
108
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
119
from vllm.multimodal.base import MultiModalDataDict
12-
from vllm.version import __version__ as VLLM_VERSION
13-
14-
15-
def _validate_remote_url(url: str, *, name: str):
16-
parsed_url = urlparse(url)
17-
if parsed_url.scheme not in ["http", "https"]:
18-
raise ValueError(f"Invalid '{name}': A valid '{name}' "
19-
"must have scheme 'http' or 'https'.")
20-
21-
22-
def _get_request_headers():
23-
return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
2410

2511

2612
def _load_image_from_bytes(b: bytes):
@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
4228
By default, the image is converted into RGB format.
4329
"""
4430
if image_url.startswith('http'):
45-
_validate_remote_url(image_url, name="image_url")
46-
47-
headers = _get_request_headers()
48-
49-
with requests.get(url=image_url, headers=headers) as response:
50-
response.raise_for_status()
51-
image_raw = response.content
31+
image_raw = global_http_connection.get_bytes(
32+
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
5233
image = _load_image_from_bytes(image_raw)
5334

5435
elif image_url.startswith('data:image'):
@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
6041
return image.convert(image_mode)
6142

6243

63-
class ImageFetchAiohttp:
64-
aiohttp_client: Optional[aiohttp.ClientSession] = None
65-
66-
@classmethod
67-
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
68-
if cls.aiohttp_client is None:
69-
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
70-
connector = aiohttp.TCPConnector()
71-
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
72-
connector=connector)
73-
74-
return cls.aiohttp_client
75-
76-
@classmethod
77-
async def fetch_image(
78-
cls,
79-
image_url: str,
80-
*,
81-
image_mode: str = "RGB",
82-
) -> Image.Image:
83-
"""
84-
Asynchronously load a PIL image from a HTTP or base64 data URL.
85-
86-
By default, the image is converted into RGB format.
87-
"""
88-
89-
if image_url.startswith('http'):
90-
_validate_remote_url(image_url, name="image_url")
91-
92-
client = cls.get_aiohttp_client()
93-
headers = _get_request_headers()
44+
async def async_fetch_image(image_url: str,
45+
*,
46+
image_mode: str = "RGB") -> Image.Image:
47+
"""
48+
Asynchronously load a PIL image from a HTTP or base64 data URL.
9449
95-
async with client.get(url=image_url, headers=headers) as response:
96-
response.raise_for_status()
97-
image_raw = await response.read()
98-
image = _load_image_from_bytes(image_raw)
50+
By default, the image is converted into RGB format.
51+
"""
52+
if image_url.startswith('http'):
53+
image_raw = await global_http_connection.async_get_bytes(
54+
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
55+
image = _load_image_from_bytes(image_raw)
9956

100-
elif image_url.startswith('data:image'):
101-
image = _load_image_from_data_url(image_url)
102-
else:
103-
raise ValueError(
104-
"Invalid 'image_url': A valid 'image_url' must start "
105-
"with either 'data:image' or 'http'.")
57+
elif image_url.startswith('data:image'):
58+
image = _load_image_from_data_url(image_url)
59+
else:
60+
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
61+
"with either 'data:image' or 'http'.")
10662

107-
return image.convert(image_mode)
63+
return image.convert(image_mode)
10864

10965

11066
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
111-
image = await ImageFetchAiohttp.fetch_image(image_url)
67+
image = await async_fetch_image(image_url)
11268
return {"image": image}
11369

11470

0 commit comments

Comments
 (0)