Skip to content

Commit a8dfb1f

Browse files
committed
add diffusers 版本升级为0.5.0
1 parent 337fa4e commit a8dfb1f

File tree

9 files changed

+192
-31
lines changed

9 files changed

+192
-31
lines changed

gpt_server/model_worker/base/base_model_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,6 @@ def transcription(self, params):
189189

190190
def generate_voice_stream(self, params):
    """Stream synthesized speech for *params*.

    Base-class stub; subclasses that support voice streaming must override.
    """
    raise NotImplementedError
192+
193+
def get_image_output(self, params):
    """Generate an image result for *params*.

    Base-class stub; image workers (e.g. a Flux worker) must override.
    """
    raise NotImplementedError

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def __init__(
5555
multimodal: bool = False,
5656
):
5757
is_vision = False
58-
if model_type != "asr" and model_type != "tts":
58+
if model_type in ["image"]:
59+
pass
60+
elif model_type not in ["asr", "tts"]:
5961
try:
6062
self.model_config = AutoConfig.from_pretrained(
6163
model_path, trust_remote_code=True
@@ -406,6 +408,16 @@ async def api_get_embeddings(request: Request):
406408
return JSONResponse(content=embedding)
407409

408410

411+
@app.post("/worker_get_image_output")
async def api_get_image_output(request: Request):
    """Worker HTTP endpoint: generate an image for the posted JSON params.

    Renamed from ``api_get_embeddings`` — the original copy-pasted name
    redefined the embeddings handler function in this module.
    """
    params = await request.json()
    await acquire_worker_semaphore()
    logger.debug(f"params {params}")
    try:
        # Delegate to the concrete worker implementation (e.g. FluxWorker).
        output = await worker.get_image_output(params)
    finally:
        # Always release, so a failed generation cannot leak a concurrency slot.
        release_worker_semaphore()
    return JSONResponse(content=output)
419+
420+
409421
@app.post("/worker_get_classify")
410422
async def api_get_classify(request: Request):
411423
params = await request.json()

gpt_server/model_worker/flux.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
3+
import io
4+
import os
5+
from typing import List
6+
from loguru import logger
7+
import shortuuid
8+
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
9+
from gpt_server.model_worker.utils import pil_to_base64
10+
import torch
11+
from diffusers import FluxPipeline
12+
13+
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
14+
15+
16+
class FluxWorker(ModelWorkerBase):
    """Model worker serving FLUX text-to-image generation via diffusers.

    Registers itself with ``model_type="image"`` and answers
    ``get_image_output`` requests with an OpenAI-images-style payload.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        # NOTE(review): value is never used — the lookup only fails fast with
        # KeyError when the "backend" env var is missing. Confirm intent.
        backend = os.environ["backend"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # FLUX weights are published in bf16; loading in bfloat16 halves memory.
        self.pipe = FluxPipeline.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        ).to(self.device)

        logger.warning(f"模型:{model_names[0]}")

    def _generate(self, prompt: str):
        """Run the (blocking) diffusion pipeline and return the first PIL image."""
        return self.pipe(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            # Fixed seed keeps outputs reproducible for identical prompts.
            generator=torch.Generator(self.device).manual_seed(0),
        ).images[0]

    async def get_image_output(self, params):
        """Generate one image for ``params["prompt"]``.

        Honors ``params["output_format"]`` (png / jpeg / webp), defaulting to
        PNG, and returns a dict shaped like the OpenAI images response with
        the encoded image in ``data[0]["b64_json"]``.
        """
        prompt = params["prompt"]
        # The API layer forwards output_format; fall back to PNG when absent.
        img_format = params.get("output_format") or "png"
        # 50 diffusion steps block for many seconds — run off the event loop
        # so the worker's HTTP server stays responsive.
        image = await asyncio.to_thread(self._generate, prompt)
        base64 = pil_to_base64(pil_img=image, format=img_format.upper())
        result = {
            "created": shortuuid.random(),
            "data": [{"b64_json": base64}],
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }
        return result
69+
70+
# Allow launching this worker directly as a standalone process.
if __name__ == "__main__":
    FluxWorker.run()

gpt_server/model_worker/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
import base64
55
import io
66

7+
from PIL.Image import Image
8+
9+
10+
# Convert a PIL image to base64 text.
def pil_to_base64(pil_img: Image, format: str = "PNG"):
    """Serialize *pil_img* in *format* (PNG by default) and return the
    result as a base64-encoded UTF-8 string."""
    buf = io.BytesIO()
    # The format must be explicit: an in-memory buffer has no filename
    # extension for Pillow to infer it from.
    pil_img.save(buf, format=format)
    return base64.b64encode(buf.getvalue()).decode("utf-8")
15+
716

817
def extract_base64(data_url: str):
918
"""从Data URL中提取纯Base64数据"""

gpt_server/openai_api_protocol/custom_api_protocol.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
from pydantic import Field, BaseModel
1515

1616

17+
class ImagesGenRequest(BaseModel):
    """Request body for ``/v1/images/generations`` (OpenAI-compatible)."""

    # Text description of the desired image.
    prompt: str
    # Name of a registered image model/worker.
    model: str
    # png, jpeg, or webp; defaults to png (matches the OpenAI images API),
    # so existing clients that omit it keep working.
    output_format: str = "png"
21+
22+
1723
# copy from https://github.com/remsky/Kokoro-FastAPI/blob/master/api/src/routers/openai_compatible.py
1824
class OpenAISpeechRequest(BaseModel):
1925
model: str = Field(

gpt_server/serving/openai_api_server.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,35 @@ async def generate_completion(payload: Dict[str, Any], worker_addr: str):
713713
ModerationsRequest,
714714
SpeechRequest,
715715
OpenAISpeechRequest,
716+
ImagesGenRequest,
716717
)
718+
719+
720+
async def get_images_gen(payload: Dict[str, Any]):
    """Forward an image-generation payload to the worker serving its model.

    Resolves the worker address for ``payload["model"]``, POSTs the payload
    to that worker's ``/worker_get_image_output`` endpoint, and returns the
    decoded JSON response.
    """
    worker_addr = get_worker_address(payload["model"])
    raw_response = await fetch_remote(
        worker_addr + "/worker_get_image_output", payload
    )
    return json.loads(raw_response)
728+
729+
730+
@app.post("/v1/images/generations", dependencies=[Depends(check_api_key)])
async def images_generations(request: ImagesGenRequest):
    """OpenAI-compatible endpoint: generate an image from a text prompt.

    Renamed from ``speech`` — the original copy-pasted name collided with
    the audio speech handler in this module.
    """
    error_check_ret = check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    payload = {
        "model": request.model,
        "prompt": request.prompt,
        "output_format": request.output_format,
    }
    result = await get_images_gen(payload=payload)
    return result
743+
744+
717745
import edge_tts
718746
import uuid
719747

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "gpt_server"
3-
version = "0.4.7"
3+
version = "0.5.0"
44
description = "gpt_server是一个用于生产级部署LLMs或Embedding的开源框架。"
55
readme = "README.md"
66
license = { text = "Apache 2.0" }
@@ -28,6 +28,7 @@ dependencies = [
2828
"sglang[all]>=0.4.6.post5",
2929
"flashinfer-python",
3030
"flashtts>=0.1.7",
31+
"diffusers>=0.33.1",
3132
]
3233

3334
[tool.uv]

0 commit comments

Comments
 (0)