
Commit 4296e37

Optimize image editing
1 parent b926572 commit 4296e37

File tree

2 files changed (+17 −11 lines)


gpt_server/model_worker/qwen_image_edit.py

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,5 @@
 import asyncio
 
-import io
 import os
 from typing import List
 import uuid
@@ -14,7 +13,7 @@
 )
 from gpt_server.utils import STATIC_DIR
 import torch
-from diffusers import QwenImageEditPipeline
+from diffusers import QwenImageEditPlusPipeline
 
 root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 
@@ -40,9 +39,8 @@ def __init__(
             conv_template,
             model_type="image",
         )
-        backend = os.environ["backend"]
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.pipe = QwenImageEditPipeline.from_pretrained(model_path)
+        self.pipe = QwenImageEditPlusPipeline.from_pretrained(model_path)
         self.pipe.to(torch.bfloat16)
         self.pipe.to(self.device)
         self.pipe.set_progress_bar_config(disable=None)
@@ -51,16 +49,18 @@ def __init__(
     async def get_image_output(self, params):
         prompt = params["prompt"]
         response_format = params.get("response_format", "b64_json")
-        bytes_io = await load_base64_or_url(params["image"])
-        image = bytesio2image(bytes_io)
+        image: list = params["image"]
+        image = [bytesio2image(await load_base64_or_url(img)) for img in image]
+        # bytes_io = await load_base64_or_url(params["image"])
+        # image = bytesio2image(bytes_io)
         inputs = {
             "image": image,
             "prompt": prompt,
             "negative_prompt": None,
             "generator": torch.manual_seed(0),
             "true_cfg_scale": 4.0,
             "negative_prompt": " ",
-            "num_inference_steps": 50,
+            "num_inference_steps": 40,
         }
         with torch.inference_mode():
             output = await asyncio.to_thread(self.pipe, **inputs)
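For reference, the worker-side change amounts to loading QwenImageEditPlusPipeline and passing it a list of PIL images instead of a single image. Below is a minimal standalone sketch of that flow, not part of the commit; the checkpoint path, image files, and prompt are placeholders, and the call arguments simply mirror the inputs dict in the worker above.

# Sketch only (not part of the commit): drive QwenImageEditPlusPipeline with
# several reference images, mirroring the worker's inputs dict.
import torch
from diffusers import QwenImageEditPlusPipeline
from PIL import Image

pipe = QwenImageEditPlusPipeline.from_pretrained("path/to/qwen-image-edit-checkpoint")  # placeholder path
pipe.to(torch.bfloat16)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
pipe.set_progress_bar_config(disable=None)

# The worker now builds a list of PIL images from the request payload.
images = [Image.open("scene.png"), Image.open("reference.png")]  # placeholder files
with torch.inference_mode():
    output = pipe(
        image=images,
        prompt="Place the object from the second image into the scene of the first image",
        negative_prompt=" ",
        true_cfg_scale=4.0,
        num_inference_steps=40,
        generator=torch.manual_seed(0),
    )
output.images[0].save("edited.png")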

gpt_server/serving/openai_api_server.py

Lines changed: 10 additions & 4 deletions
@@ -1229,7 +1229,9 @@ async def get_images_edits(payload: Dict[str, Any]):
 @app.post("/v1/images/edits", dependencies=[Depends(check_api_key)])
 async def images_edits(
     model: str = Form(...),
-    image: UploadFile = File(media_type="application/octet-stream"),
+    image: Union[UploadFile, List[UploadFile]] = File(
+        ..., media_type="application/octet-stream"
+    ),
     prompt: Optional[Union[str, List[str]]] = Form(None),
     # negative_prompt: Optional[Union[str, List[str]]] = Form(None),
     response_format: Optional[str] = Form("url"),
@@ -1240,10 +1242,14 @@ async def images_edits(
     error_check_ret = check_model(model)
     if error_check_ret is not None:
         return error_check_ret
+    images = None
+    if not isinstance(image, list):  # single file
+        images = [image]
+    else:
+        images = image
+    image = [base64.b64encode(await img.read()).decode("utf-8") for img in images]
     payload = {
-        "image": base64.b64encode(await image.read()).decode(
-            "utf-8"
-        ),  # bytes → Base64 string
+        "image": image,  # bytes → Base64 strings
         "model": model,
         "prompt": prompt,
         "output_format": output_format,
