Skip to content

Commit 2e7ea27

Browse files
committed
Create OmniOpenAIServeImage class to manage edit_images and generate_images
1 parent a42b748 commit 2e7ea27

File tree

8 files changed

+711
-625
lines changed

8 files changed

+711
-625
lines changed

vllm_omni/entrypoints/openai/api_server.py

Lines changed: 57 additions & 614 deletions
Large diffs are not rendered by default.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from vllm.entrypoints.openai.engine.protocol import (
2+
ModelCard,
3+
ModelList,
4+
ModelPermission,
5+
)
6+
from vllm.entrypoints.openai.models.protocol import BaseModelPath
7+
8+
9+
class DiffusionServingModels:
    """Lightweight stand-in for OpenAIServingModels on diffusion-only servers.

    vLLM's /v1/models route looks up `app.state.openai_serving_models` and
    calls `show_available_models()`. A pure diffusion deployment does not
    build the full OpenAIServingModels (it relies on LLM-specific
    processors), so this minimal fallback serves the model listing instead.
    """

    def __init__(self, base_model_paths: list[BaseModelPath]) -> None:
        # Static list of the base models this server exposes.
        self._base_model_paths = base_model_paths

    async def show_available_models(self) -> ModelList:
        """Return a ModelList with one ModelCard per configured base model."""
        cards = []
        for model in self._base_model_paths:
            cards.append(
                ModelCard(
                    id=model.name,
                    root=model.model_path,
                    permission=[ModelPermission()],
                )
            )
        return ModelList(data=cards)

vllm_omni/entrypoints/openai/image_api_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import base64
1212
import io
13+
import json
14+
from typing import Any
1315

1416
import PIL.Image
1517

@@ -63,3 +65,25 @@ def encode_image_base64(image: PIL.Image.Image) -> str:
6365
image.save(buffer, format="PNG")
6466
buffer.seek(0)
6567
return base64.b64encode(buffer.read()).decode("utf-8")
68+
69+
70+
def apply_stage_default_sampling_params(
71+
default_params_json: str | None,
72+
sampling_params: Any,
73+
stage_key: str,
74+
) -> None:
75+
"""
76+
Update a stage's sampling parameters with vLLM-Omni defaults.
77+
78+
Args:
79+
default_params_json: JSON string of stage-keyed default parameters
80+
sampling_params: The sampling parameters object to update
81+
stage_key: The stage ID/key in the pipeline
82+
"""
83+
if default_params_json is not None:
84+
default_params_dict = json.loads(default_params_json)
85+
if stage_key in default_params_dict:
86+
stage_defaults = default_params_dict[stage_key]
87+
for param_name, param_value in stage_defaults.items():
88+
if hasattr(sampling_params, param_name):
89+
setattr(sampling_params, param_name, param_value)

vllm_omni/entrypoints/openai/protocol/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionStreamResponse
55
from vllm_omni.entrypoints.openai.protocol.images import (
66
ImageData,
7+
ImageEditRequest,
8+
ImageEditResponse,
79
ImageGenerationRequest,
810
ImageGenerationResponse,
911
ResponseFormat,
@@ -19,6 +21,8 @@
1921
"ImageData",
2022
"ImageGenerationRequest",
2123
"ImageGenerationResponse",
24+
"ImageEditRequest",
25+
"ImageEditResponse",
2226
"ResponseFormat",
2327
"VideoData",
2428
"VideoGenerationRequest",

vllm_omni/entrypoints/openai/protocol/images.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
for text-to-image generation, with vllm-omni specific extensions.
88
"""
99

10+
import json
1011
from enum import Enum
1112
from typing import Any
1213

@@ -127,5 +128,84 @@ class ImageGenerationResponse(BaseModel):
127128

128129
created: int = Field(..., description="Unix timestamp of when the generation completed")
129130
data: list[ImageData] = Field(..., description="Array of generated images")
130-
output_format: str = Field(None, description="The output format of the image generation")
131-
size: str = Field(None, description="The size of the image generated")
131+
132+
133+
class ImageEditResponse(BaseModel):
    """
    OpenAI-compatible image edit response (images/edits endpoint).

    Returns the edited images with metadata.
    """

    created: int = Field(..., description="Unix timestamp of when the generation completed")
    data: list[ImageData] = Field(..., description="Array of generated images")
    output_format: str = Field(..., description="The output format of the image generation")
    size: str = Field(..., description="The size of the image generated")
144+
145+
146+
class ImageEditRequest(BaseModel):
    """
    OpenAI images/edits compatible request, with vllm-omni extensions
    for diffusion control (negative prompt, step count, guidance scales,
    seed, generator device, and an optional per-request LoRA adapter).
    """

    prompt: str = Field(..., description="Text description of the desired image edit")
    model: str | None = Field(
        default=None,
        description="Model to use (optional, uses server's configured model if omitted)",
    )
    n: int = Field(default=1, ge=1, le=10, description="Number of images to generate")
    size: str | None = Field(
        default=None,
        description="Image dimensions in WIDTHxHEIGHT format (e.g., '1024x1024', uses model defaults if omitted)",
    )
    response_format: ResponseFormat = Field(default=ResponseFormat.B64_JSON, description="Format of the returned image")
    user: str | None = Field(default=None, description="User identifier for tracking")

    # vllm-omni extensions for diffusion control
    negative_prompt: str | None = Field(default=None, description="Text describing what to avoid in the image")
    num_inference_steps: int | None = Field(
        default=None,
        ge=1,
        le=200,
        description="Number of diffusion sampling steps (uses model defaults if not specified)",
    )
    guidance_scale: float | None = Field(
        default=None,
        ge=0.0,
        le=20.0,
        description="Classifier-free guidance scale (uses model defaults if not specified)",
    )
    true_cfg_scale: float | None = Field(
        default=None,
        ge=0.0,
        le=20.0,
        description="True CFG scale (model-specific parameter, may be ignored if not supported)",
    )
    seed: int | None = Field(default=None, description="Random seed for reproducibility")
    generator_device: str | None = Field(
        default=None,
        description="Device for the seeded torch.Generator (e.g. 'cpu', 'cuda'). Defaults to the runner's device.",
    )
    lora: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Optional LoRA adapter for this request. Expected shape: "
            "{name/path/scale/int_id}. Field names are flexible "
            "(e.g. name|lora_name|adapter, path|lora_path|local_path, "
            "scale|lora_scale, int_id|lora_int_id)."
        ),
    )

    # NOTE: mode="before" is required. With pydantic v2's default "after"
    # mode, a JSON-string `lora` value is rejected by type validation
    # (field is typed dict | None) before this hook ever runs, making the
    # string-handling branch below unreachable. "before" lets us coerce
    # the string into a dict first, as the field description promises.
    @field_validator("lora", mode="before")
    @classmethod
    def validate_lora(cls, v):
        """Validate the LoRA field - must be a dict, or a JSON string encoding one.

        Returns the dict (parsed if a string was supplied) or None.
        Raises ValueError for any other type or malformed JSON.
        """
        if v is None or isinstance(v, dict):
            return v
        if isinstance(v, str):
            try:
                parsed = json.loads(v)
            except json.JSONDecodeError as err:
                # Chain the cause so the original parse error stays visible.
                raise ValueError("LoRA field must be a valid JSON string representing a dict") from err
            if isinstance(parsed, dict):
                return parsed
            raise ValueError("LoRA field must be a JSON object (dict)")
        raise ValueError("LoRA field must be either a dict or a JSON string representing a dict")

0 commit comments

Comments
 (0)