Skip to content

Commit 7378a97

Browse files
committed
增强图片生成,支持 base64 和 url
1 parent 5fcbaa9 commit 7378a97

File tree

4 files changed

+52
-18
lines changed

4 files changed

+52
-18
lines changed

gpt_server/model_worker/flux.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import io
44
import os
55
from typing import List
6+
import uuid
67
from loguru import logger
78
import shortuuid
89
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
910
from gpt_server.model_worker.utils import pil_to_base64
1011
import torch
1112
from diffusers import FluxPipeline
13+
from gpt_server.utils import STATIC_DIR
1214

1315
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
1416

@@ -44,6 +46,7 @@ def __init__(
4446

4547
async def get_image_output(self, params):
4648
prompt = params["prompt"]
49+
response_format = params.get("response_format", "b64_json")
4750
image = self.pipe(
4851
prompt,
4952
height=1024,
@@ -53,17 +56,39 @@ async def get_image_output(self, params):
5356
max_sequence_length=512,
5457
generator=torch.Generator(self.device).manual_seed(0),
5558
).images[0]
56-
base64 = pil_to_base64(pil_img=image)
57-
result = {
58-
"created": shortuuid.random(),
59-
"data": [{"b64_json": base64}],
60-
"usage": {
61-
"total_tokens": 0,
62-
"input_tokens": 0,
63-
"output_tokens": 0,
64-
"input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
65-
},
66-
}
59+
result = {}
60+
if response_format == "b64_json":
61+
# Convert PIL image to base64
62+
base64 = pil_to_base64(pil_img=image)
63+
result = {
64+
"created": shortuuid.random(),
65+
"data": [{"b64_json": base64}],
66+
"usage": {
67+
"total_tokens": 0,
68+
"input_tokens": 0,
69+
"output_tokens": 0,
70+
"input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
71+
},
72+
}
73+
return result
74+
elif response_format == "url":
75+
# 生成唯一文件名(避免冲突)
76+
file_name = str(uuid.uuid4()) + ".png"
77+
save_path = STATIC_DIR / file_name
78+
image.save(save_path, format="PNG")
79+
WORKER_PORT = os.environ["WORKER_PORT"]
80+
WORKER_HOST = os.environ["WORKER_HOST"]
81+
url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
82+
result = {
83+
"created": shortuuid.random(),
84+
"data": [{"url": url}],
85+
"usage": {
86+
"total_tokens": 0,
87+
"input_tokens": 0,
88+
"output_tokens": 0,
89+
"input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
90+
},
91+
}
6792
return result
6893

6994

gpt_server/openai_api_protocol/custom_api_protocol.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ class ImagesGenRequest(BaseModel):
2121
default="png",
2222
description="png, jpeg, or webp",
2323
)
24-
model_type: Literal["t2v", "t2i"] = Field(
25-
default="t2i",
26-
description="t2v: 文生视频 t2i: 文生图",
24+
# model_type: Literal["t2v", "t2i"] = Field(
25+
# default="t2i",
26+
# description="t2v: 文生视频 t2i: 文生图",
27+
# )
28+
response_format: Literal["url", "b64_json"] = Field(
29+
default="url",
30+
description="生成图像时返回的格式。必须为“url”或“b64_json”之一。URL仅在图像生成后60分钟内有效。",
2731
)
2832

2933

gpt_server/serving/openai_api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ async def speech(request: ImagesGenRequest):
736736
"model": request.model,
737737
"prompt": request.prompt,
738738
"output_format": request.output_format,
739-
"model_type": request.model_type,
739+
"response_format": request.response_format,
740740
}
741741
result = await get_images_gen(payload=payload)
742742
return result

tests/test_image_gen.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
from openai import OpenAI
33

44
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8082/v1")
5-
6-
img = client.images.generate(model="flux", prompt="A red pig")
7-
5+
# 两种响应方式
6+
## response_format = "url" 默认为 url
7+
img = client.images.generate(model="flux", prompt="A red pig", response_format="url")
8+
print(img.data[0])
9+
## response_format = "b64_json"
10+
img = client.images.generate(
11+
model="flux", prompt="A red pig", response_format="b64_json"
12+
)
813
image_bytes = base64.b64decode(img.data[0].b64_json)
914
with open("output.png", "wb") as f:
1015
f.write(image_bytes)

0 commit comments

Comments
 (0)