Skip to content

Commit 06d347b

Browse files
[Misc] Extend Diffusion Benchmark script to other backends (#875)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent af11b02 commit 06d347b

File tree

2 files changed

+233
-127
lines changed

2 files changed

+233
-127
lines changed

benchmarks/diffusion/backends.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import base64
2+
import mimetypes
3+
import os
4+
import time
5+
import uuid
6+
from dataclasses import dataclass, field
7+
from typing import Any
8+
9+
import aiohttp
10+
from tqdm import tqdm
11+
12+
13+
@dataclass
class RequestFuncInput:
    """Parameters for a single benchmark request to a diffusion backend."""

    prompt: str  # text prompt sent to the model
    api_url: str  # full endpoint URL the request is POSTed to
    model: str  # model name placed in the request payload
    width: int | None = None  # requested output width in pixels
    height: int | None = None  # requested output height in pixels
    num_frames: int | None = None  # number of frames (video generation)
    num_inference_steps: int | None = None  # number of denoising steps
    seed: int | None = None  # RNG seed; 0 is a valid value
    fps: int | None = None  # frames per second (video generation)
    timestamp: float | None = None  # presumably the scheduled issue time -- TODO confirm against caller
    slo_ms: float | None = None  # latency SLO in milliseconds, if enforced
    extra_body: dict[str, Any] = field(default_factory=dict)  # backend-specific extra payload fields
    image_paths: list[str] | None = None  # local images to inline as base64 data URLs
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))  # unique id per request
29+
30+
31+
@dataclass
class RequestFuncOutput:
    """Result of a single benchmark request."""

    success: bool = False  # True iff the HTTP request returned 200
    latency: float = 0.0  # end-to-end wall time in seconds
    error: str = ""  # error description when success is False
    start_time: float = 0.0  # perf_counter() value when the request began
    response_body: dict[str, Any] = field(default_factory=dict)  # decoded JSON response on success
    peak_memory_mb: float = 0.0  # server-reported peak memory, when exposed
    slo_achieved: bool | None = None  # None when no SLO was configured
40+
41+
42+
def _guess_mime_type(path: str) -> str:
43+
mime, _ = mimetypes.guess_type(path)
44+
return mime or "application/octet-stream"
45+
46+
47+
def _encode_image_as_data_url(path: str) -> str:
    """Read the file at *path* and return its content as a base64 ``data:`` URL."""
    mime = _guess_mime_type(path)
    with open(path, "rb") as img_file:
        raw = img_file.read()
    b64 = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime};base64,{b64}"
52+
53+
54+
async def async_request_chat_completions(
    input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Send one request to an OpenAI-style ``/v1/chat/completions`` endpoint.

    Optional generation parameters (size, frames, steps, seed, fps) are
    forwarded through ``extra_body``; values already present in
    ``input.extra_body`` take precedence (``setdefault``). Local images are
    embedded in the message as base64 data URLs.

    Returns a ``RequestFuncOutput`` with latency, success flag, response
    body on success, and — when ``input.slo_ms`` is set — whether the
    latency SLO was met. Failures are reported via ``output.error``, never
    raised, so one bad request does not abort the benchmark run.
    """
    output = RequestFuncOutput()
    output.start_time = time.perf_counter()

    # Start from the caller-provided extras so they win over derived values.
    extra_body = dict(input.extra_body)
    if input.width and input.height:
        extra_body.setdefault("height", input.height)
        extra_body.setdefault("width", input.width)
    if input.num_frames:
        extra_body.setdefault("num_frames", input.num_frames)
    if input.num_inference_steps:
        extra_body.setdefault("num_inference_steps", input.num_inference_steps)
    if input.seed is not None:  # seed=0 is a valid seed, so compare to None
        extra_body.setdefault("seed", input.seed)
    if input.fps:
        extra_body.setdefault("fps", input.fps)

    if input.image_paths:
        # Multimodal message: optional text part followed by one image part
        # per input file, each inlined as a base64 data URL.
        content = []
        if input.prompt:
            content.append({"type": "text", "text": input.prompt})
        for img_path in input.image_paths:
            if not os.path.exists(img_path):
                # Fail fast without issuing the HTTP request; record the
                # elapsed time so this exit path is consistent with the rest.
                output.error = f"Image file not found: {img_path}"
                output.success = False
                output.latency = time.perf_counter() - output.start_time
                if pbar:
                    pbar.update(1)
                return output
            content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": _encode_image_as_data_url(img_path)},
                }
            )
        messages = [{"role": "user", "content": content}]
    else:
        messages = [{"role": "user", "content": input.prompt}]

    payload = {
        "model": input.model,
        "messages": messages,
    }
    if extra_body:
        payload["extra_body"] = extra_body

    try:
        async with session.post(input.api_url, json=payload) as response:
            if response.status == 200:
                resp_json = await response.json()
                output.response_body = resp_json
                output.success = True
                # Server-reported peak memory, when the backend exposes it.
                if "peak_memory_mb" in resp_json:
                    output.peak_memory_mb = resp_json["peak_memory_mb"]
            else:
                output.error = f"HTTP {response.status}: {await response.text()}"
                output.success = False
    except Exception as e:
        # Network/protocol failures are reported, not raised.
        output.error = str(e)
        output.success = False

    output.latency = time.perf_counter() - output.start_time

    if output.success and input.slo_ms is not None:
        output.slo_achieved = (output.latency * 1000.0) <= float(input.slo_ms)

    if pbar:
        pbar.update(1)
    return output
126+
127+
128+
async def async_request_openai_images(
    input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Send one request to OpenAI's ``/v1/images/generations`` endpoint and
    return a ``RequestFuncOutput`` describing latency and success."""
    result = RequestFuncOutput()
    result.start_time = time.perf_counter()

    # Default to a 1024x1024 canvas when no explicit size was requested.
    size_str = f"{input.width or 1024}x{input.height or 1024}"

    payload: dict[str, Any] = {
        "model": input.model,
        "prompt": input.prompt,
        "n": 1,
        "size": size_str,
        "response_format": "b64_json",
    }

    # Optional generation knobs.
    if input.seed is not None:
        payload["seed"] = input.seed
    if input.num_inference_steps is not None:
        payload["num_inference_steps"] = input.num_inference_steps

    # Fold in caller-provided extras without overriding what is already set.
    for extra_key, extra_value in input.extra_body.items():
        payload.setdefault(extra_key, extra_value)

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer EMPTY",
    }

    try:
        async with session.post(input.api_url, json=payload, headers=headers) as response:
            if response.status != 200:
                result.error = f"HTTP {response.status}: {await response.text()}"
                result.success = False
            else:
                body = await response.json()
                result.response_body = body
                result.success = True
                # Server-reported peak memory, when the backend exposes it.
                if "peak_memory_mb" in body.get("usage", {}):
                    result.peak_memory_mb = body["usage"]["peak_memory_mb"]
    except Exception as exc:
        result.error = str(exc)
        result.success = False

    result.latency = time.perf_counter() - result.start_time

    if result.success and input.slo_ms is not None:
        result.slo_achieved = (result.latency * 1000.0) <= float(input.slo_ms)

    if pbar:
        pbar.update(1)
    return result
193+
194+
195+
# Maps a benchmark backend name to (async request function, default endpoint path).
backends_function_mapping = {
    "vllm-omni": (async_request_chat_completions, "/v1/chat/completions"),
    "openai": (async_request_openai_images, "/v1/images/generations"),
}

0 commit comments

Comments
 (0)