|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import base64 |
| 4 | +import json |
4 | 5 | import random |
5 | 6 | import time |
6 | 7 | from http import HTTPStatus |
|
17 | 18 |
|
18 | 19 | from vllm_omni.entrypoints.async_omni import AsyncOmni |
19 | 20 | from vllm_omni.entrypoints.openai.image_api_utils import ( |
20 | | - apply_stage_default_sampling_params, |
21 | 21 | encode_image_base64, |
22 | 22 | parse_size, |
23 | 23 | ) |
@@ -76,6 +76,28 @@ def for_diffusion( |
76 | 76 | stage_configs=stage_configs, |
77 | 77 | ) |
78 | 78 |
|
| 79 | + @staticmethod |
| 80 | + def apply_stage_default_sampling_params( |
| 81 | + default_params_json: str | None, |
| 82 | + sampling_params: Any, |
| 83 | + stage_key: str, |
| 84 | + ) -> None: |
| 85 | + """ |
| 86 | + Update a stage's sampling parameters with vLLM-Omni defaults. |
| 87 | +
|
| 88 | + Args: |
| 89 | + default_params_json: JSON string of stage-keyed default parameters |
| 90 | + sampling_params: The sampling parameters object to update |
| 91 | + stage_key: The stage ID/key in the pipeline |
| 92 | + """ |
| 93 | + if default_params_json is not None: |
| 94 | + default_params_dict = json.loads(default_params_json) |
| 95 | + if stage_key in default_params_dict: |
| 96 | + stage_defaults = default_params_dict[stage_key] |
| 97 | + for param_name, param_value in stage_defaults.items(): |
| 98 | + if hasattr(sampling_params, param_name): |
| 99 | + setattr(sampling_params, param_name, param_value) |
| 100 | + |
79 | 101 | async def _generate_with_async_omni( |
80 | 102 | self, |
81 | 103 | gen_params: OmniDiffusionSamplingParams, |
@@ -260,7 +282,7 @@ async def generate_image( |
260 | 282 | status_code=HTTPStatus.INTERNAL_SERVER_ERROR, |
261 | 283 | detail="No diffusion stage configured for image generation.", |
262 | 284 | ) |
263 | | - apply_stage_default_sampling_params( |
| 285 | + self.apply_stage_default_sampling_params( |
264 | 286 | default_sample_param, |
265 | 287 | gen_params, |
266 | 288 | str(diffusion_stage_ids[0]), |
@@ -371,7 +393,7 @@ async def edit_images( |
371 | 393 | detail="No diffusion stage configured for image generation.", |
372 | 394 | ) |
373 | 395 | diffusion_stage_id = diffusion_stage_ids[0] |
374 | | - apply_stage_default_sampling_params( |
| 396 | + self.apply_stage_default_sampling_params( |
375 | 397 | default_sample_param, |
376 | 398 | gen_params, |
377 | 399 | str(diffusion_stage_id), |
|
0 commit comments