|
7 | 7 | for text-to-image generation, with vllm-omni specific extensions. |
8 | 8 | """ |
9 | 9 |
|
| 10 | +import json |
10 | 11 | from enum import Enum |
11 | 12 | from typing import Any |
12 | 13 |
|
@@ -127,5 +128,84 @@ class ImageGenerationResponse(BaseModel): |
127 | 128 |
|
128 | 129 | created: int = Field(..., description="Unix timestamp of when the generation completed") |
129 | 130 | data: list[ImageData] = Field(..., description="Array of generated images") |
130 | | - output_format: str = Field(None, description="The output format of the image generation") |
131 | | - size: str = Field(None, description="The size of the image generated") |
| 131 | + |
| 132 | + |
class ImageEditResponse(BaseModel):
    """
    OpenAI DALL-E compatible image edit response.

    Returns the edited images together with metadata about the result:
    the completion timestamp, the output format, and the image size.
    Unlike the optional fields of a request, ``output_format`` and ``size``
    are required here.
    """

    created: int = Field(..., description="Unix timestamp of when the generation completed")
    data: list[ImageData] = Field(..., description="Array of generated images")
    output_format: str = Field(..., description="The output format of the image generation")
    size: str = Field(..., description="The size of the image generated")
| 144 | + |
| 145 | + |
class ImageEditRequest(BaseModel):
    """
    Request body for an image edit.

    Carries the edit prompt and the standard image options (``model``, ``n``,
    ``size``, ``response_format``, ``user``), plus vllm-omni extensions for
    diffusion control (negative prompt, sampling steps, guidance scales,
    seed, generator device, and an optional per-request LoRA adapter).

    Every field except ``prompt`` is optional; fields left as ``None`` are
    described as falling back to model or server defaults.
    """

    prompt: str = Field(..., description="Text description of the desired image edit")
    model: str | None = Field(
        default=None,
        description="Model to use (optional, uses server's configured model if omitted)",
    )
    n: int = Field(default=1, ge=1, le=10, description="Number of images to generate")
    size: str | None = Field(
        default=None,
        description="Image dimensions in WIDTHxHEIGHT format (e.g., '1024x1024', uses model defaults if omitted)",
    )
    response_format: ResponseFormat = Field(default=ResponseFormat.B64_JSON, description="Format of the returned image")
    user: str | None = Field(default=None, description="User identifier for tracking")

    # vllm-omni extensions for diffusion control
    negative_prompt: str | None = Field(default=None, description="Text describing what to avoid in the image")
    num_inference_steps: int | None = Field(
        default=None,
        ge=1,
        le=200,
        description="Number of diffusion sampling steps (uses model defaults if not specified)",
    )
    guidance_scale: float | None = Field(
        default=None,
        ge=0.0,
        le=20.0,
        description="Classifier-free guidance scale (uses model defaults if not specified)",
    )
    true_cfg_scale: float | None = Field(
        default=None,
        ge=0.0,
        le=20.0,
        description="True CFG scale (model-specific parameter, may be ignored if not supported)",
    )
    seed: int | None = Field(default=None, description="Random seed for reproducibility")
    generator_device: str | None = Field(
        default=None,
        description="Device for the seeded torch.Generator (e.g. 'cpu', 'cuda'). Defaults to the runner's device.",
    )
    lora: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Optional LoRA adapter for this request. Expected shape: "
            "{name/path/scale/int_id}. Field names are flexible "
            "(e.g. name|lora_name|adapter, path|lora_path|local_path, "
            "scale|lora_scale, int_id|lora_int_id)."
        ),
    )

    # mode="before" is required: with the default "after" mode the
    # ``dict[str, Any] | None`` type check runs first and rejects every str,
    # so the JSON-string branch below could never execute.
    @field_validator("lora", mode="before")
    @classmethod
    def validate_lora(cls, v):
        """
        Accept the LoRA spec either as a dict or as a JSON-encoded object.

        Args:
            v: Raw input for ``lora`` before type validation.

        Returns:
            The decoded dict when ``v`` is a JSON string holding an object;
            otherwise ``v`` unchanged, so that the field's declared type
            (``dict[str, Any] | None``) performs the remaining validation.

        Raises:
            ValueError: If ``v`` is a string that is not valid JSON, or is
                valid JSON that does not decode to an object.
        """
        if not isinstance(v, str):
            # None, dicts and anything else are left to the field's own
            # type validation, which already rejects non-dict values.
            return v

        try:
            decoded = json.loads(v)
        except json.JSONDecodeError as err:
            raise ValueError("LoRA field must be a valid JSON string representing a dict") from err

        if not isinstance(decoded, dict):
            raise ValueError("LoRA field must be a JSON object (dict)")
        return decoded
0 commit comments