Skip to content

Commit 02d6caf

Browse files
committed
Fix bugs
Signed-off-by: bash000000 <m2588953@outlook.com>
1 parent ee3c94b commit 02d6caf

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

vllm_omni/entrypoints/openai/image_api_utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
import base64
1212
import io
13-
import json
14-
from typing import Any
1513

1614
import PIL.Image
1715

@@ -65,25 +63,3 @@ def encode_image_base64(image: PIL.Image.Image) -> str:
6563
image.save(buffer, format="PNG")
6664
buffer.seek(0)
6765
return base64.b64encode(buffer.read()).decode("utf-8")
68-
69-
70-
def apply_stage_default_sampling_params(
71-
default_params_json: str | None,
72-
sampling_params: Any,
73-
stage_key: str,
74-
) -> None:
75-
"""
76-
Update a stage's sampling parameters with vLLM-Omni defaults.
77-
78-
Args:
79-
default_params_json: JSON string of stage-keyed default parameters
80-
sampling_params: The sampling parameters object to update
81-
stage_key: The stage ID/key in the pipeline
82-
"""
83-
if default_params_json is not None:
84-
default_params_dict = json.loads(default_params_json)
85-
if stage_key in default_params_dict:
86-
stage_defaults = default_params_dict[stage_key]
87-
for param_name, param_value in stage_defaults.items():
88-
if hasattr(sampling_params, param_name):
89-
setattr(sampling_params, param_name, param_value)

vllm_omni/entrypoints/openai/serving_image.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import base64
4+
import json
45
import random
56
import time
67
from http import HTTPStatus
@@ -17,7 +18,6 @@
1718

1819
from vllm_omni.entrypoints.async_omni import AsyncOmni
1920
from vllm_omni.entrypoints.openai.image_api_utils import (
20-
apply_stage_default_sampling_params,
2121
encode_image_base64,
2222
parse_size,
2323
)
@@ -76,6 +76,28 @@ def for_diffusion(
7676
stage_configs=stage_configs,
7777
)
7878

79+
@staticmethod
80+
def apply_stage_default_sampling_params(
81+
default_params_json: str | None,
82+
sampling_params: Any,
83+
stage_key: str,
84+
) -> None:
85+
"""
86+
Update a stage's sampling parameters with vLLM-Omni defaults.
87+
88+
Args:
89+
default_params_json: JSON string of stage-keyed default parameters
90+
sampling_params: The sampling parameters object to update
91+
stage_key: The stage ID/key in the pipeline
92+
"""
93+
if default_params_json is not None:
94+
default_params_dict = json.loads(default_params_json)
95+
if stage_key in default_params_dict:
96+
stage_defaults = default_params_dict[stage_key]
97+
for param_name, param_value in stage_defaults.items():
98+
if hasattr(sampling_params, param_name):
99+
setattr(sampling_params, param_name, param_value)
100+
79101
async def _generate_with_async_omni(
80102
self,
81103
gen_params: OmniDiffusionSamplingParams,
@@ -260,7 +282,7 @@ async def generate_image(
260282
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
261283
detail="No diffusion stage configured for image generation.",
262284
)
263-
apply_stage_default_sampling_params(
285+
self.apply_stage_default_sampling_params(
264286
default_sample_param,
265287
gen_params,
266288
str(diffusion_stage_ids[0]),
@@ -371,7 +393,7 @@ async def edit_images(
371393
detail="No diffusion stage configured for image generation.",
372394
)
373395
diffusion_stage_id = diffusion_stage_ids[0]
374-
apply_stage_default_sampling_params(
396+
self.apply_stage_default_sampling_params(
375397
default_sample_param,
376398
gen_params,
377399
str(diffusion_stage_id),

0 commit comments

Comments
 (0)