Commit f5b44c4

committed
Add qwen edits (not yet supported)
1 parent 22dc397 commit f5b44c4


4 files changed (+257, -9 lines)

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
"""This code is not currently in use."""

from typing import List, Dict, Optional, Any
from multiprocessing import Process
from sqlmodel import SQLModel, Field, create_engine, Session, select
from datetime import datetime
import json
from uuid import uuid4


# Database model
class ProcessRecord(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True, description="Primary key ID")
    pid: int | None = Field(default=None, description="Process ID")
    args: str = Field(default="", description="Process arguments")
    status: str = Field(
        default="created", description="Process status"
    )  # created, started, stopped
    created_at: datetime = Field(default_factory=datetime.now, description="Creation time")
    started_at: Optional[datetime] = Field(default=None, description="Start time")
    stopped_at: Optional[datetime] = Field(default=None, description="Stop time")


class ProcessManager:
    def __init__(self, write_db: bool = False, db_url: str = "sqlite:///processes.db"):
        """Process manager class.

        Parameters
        ----------
        write_db : bool, optional
            Whether to write process information to the database, by default False
        db_url : str, optional
            Database connection URL, by default "sqlite:///processes.db"
        """
        self.processes: List[Dict[Process, dict]] = []
        self.write_db = write_db
        if self.write_db:
            self.engine = create_engine(db_url)
            # Create the tables
            SQLModel.metadata.create_all(self.engine)

    def add_process(
        self,
        target,
        args=(),
    ):
        p = Process(target=target, args=args)
        # Mask to the signed 64-bit range so the id fits a SQLite INTEGER primary key
        process_id = uuid4().int & ((1 << 63) - 1)
        self.processes.append({p: {"args": args, "process_id": process_id}})
        if self.write_db:
            # Record the process in the database
            with Session(self.engine) as session:
                process_record = ProcessRecord(
                    id=process_id,
                    pid=None,
                    args=json.dumps(args, ensure_ascii=False),
                    status="created",
                )
                session.add(process_record)
                session.commit()
                session.refresh(process_record)

    def start_all(self):
        for process in self.processes:
            for _process, process_info in process.items():
                _process.start()
                process_info["pid"] = _process.pid
                if self.write_db:
                    process_id = process_info["process_id"]
                    # Update the database record
                    with Session(self.engine) as session:
                        # Look the record up by id (simplified; a sturdier identifier may be needed)
                        statement = select(ProcessRecord).where(
                            ProcessRecord.id == process_id
                        )
                        result = session.exec(statement)
                        process_record = result.first()
                        if process_record:
                            process_record.pid = _process.pid
                            process_record.status = "started"
                            process_record.started_at = datetime.now()
                            session.add(process_record)
                            session.commit()
                            session.refresh(process_record)

    def join_all(self):
        for process in self.processes:
            for _process, process_info in process.items():
                _process.join()
                if self.write_db:
                    process_id = process_info["process_id"]
                    # Mark the record as stopped in the database
                    with Session(self.engine) as session:
                        statement = select(ProcessRecord).where(
                            ProcessRecord.id == process_id
                        )
                        results = session.exec(statement)
                        record = results.first()
                        if record:
                            record.status = "stopped"
                            record.stopped_at = datetime.now()
                            session.add(record)
                            session.commit()
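
As a reference for how this class might be driven if it were ever wired in (the module itself says it is not currently used), here is a minimal usage sketch; the `work` function, job names, and the DB path are placeholders, not part of the commit:

import time
from multiprocessing import current_process

# Assumes ProcessManager from the module above is importable or in scope.


def work(name: str):
    # Placeholder worker function used only for this sketch.
    print(f"{name} running in pid {current_process().pid}")
    time.sleep(1)


if __name__ == "__main__":
    manager = ProcessManager(write_db=True, db_url="sqlite:///processes.db")
    manager.add_process(target=work, args=("job-1",))
    manager.add_process(target=work, args=("job-2",))
    manager.start_all()  # spawns both processes and records pid / status="started"
    manager.join_all()   # waits for them and records status="stopped"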
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import asyncio

import io
import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import (
    pil_to_base64,
    load_base64_or_url,
    bytesio2image,
)
from gpt_server.utils import STATIC_DIR
import torch
from diffusers import QwenImageEditPipeline

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class QwenImageEditWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        backend = os.environ["backend"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = QwenImageEditPipeline.from_pretrained(model_path)
        self.pipe.to(torch.bfloat16)
        self.pipe.to(self.device)
        self.pipe.set_progress_bar_config(disable=None)
        logger.warning(f"Model: {model_names[0]}")

    async def get_image_output(self, params):
        prompt = params["prompt"]
        response_format = params.get("response_format", "b64_json")
        bytes_io = await load_base64_or_url(params["image"])
        image = bytesio2image(bytes_io)
        inputs = {
            "image": image,
            "prompt": prompt,
            "generator": torch.manual_seed(0),
            "true_cfg_scale": 4.0,
            "negative_prompt": " ",
            "num_inference_steps": 50,
        }
        with torch.inference_mode():
            output = self.pipe(**inputs)
            image = output.images[0]

        result = {}
        if response_format == "b64_json":
            # Convert PIL image to base64
            base64 = pil_to_base64(pil_img=image)
            result = {
                "created": shortuuid.random(),
                "data": [{"b64_json": base64}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
            return result
        elif response_format == "url":
            # Generate a unique file name (avoid collisions)
            file_name = str(uuid.uuid4()) + ".png"
            save_path = STATIC_DIR / file_name
            image.save(save_path, format="PNG")
            WORKER_PORT = os.environ["WORKER_PORT"]
            WORKER_HOST = os.environ["WORKER_HOST"]
            url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
            result = {
                "created": shortuuid.random(),
                "data": [{"url": url}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
            return result


if __name__ == "__main__":
    QwenImageEditWorker.run()
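
For orientation, here is a sketch of the params dict that get_image_output reads and what a call might look like once a worker instance exists; the worker bootstrap is omitted and the file path and prompt are illustrative only:

import base64


async def demo(worker):  # worker: an initialized QwenImageEditWorker (assumed)
    with open("input.png", "rb") as f:  # placeholder input image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    params = {
        "prompt": "Replace the background with a sunset",  # edit instruction
        "image": image_b64,             # base64 string or an http(s) URL
        "response_format": "b64_json",  # or "url" for a /static/ link
    }
    result = await worker.get_image_output(params)
    print(result["data"][0].keys())     # dict_keys(['b64_json']) or dict_keys(['url'])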

gpt_server/model_worker/utils.py

Lines changed: 17 additions & 8 deletions
@@ -4,7 +4,7 @@
 import base64
 import io
 import os
-from PIL.Image import Image
+from PIL import Image
 import re


@@ -14,38 +14,47 @@ def is_base64_image(data_string):


 # Convert to Base64
-def pil_to_base64(pil_img: Image, format: str = "PNG"):
+def pil_to_base64(pil_img: Image.Image, format: str = "PNG"):
     buffered = io.BytesIO()
     pil_img.save(buffered, format=format)  # explicitly use the PNG format
     return base64.b64encode(buffered.getvalue()).decode("utf-8")


-def extract_base64(data_url: str):
+def _extract_base64(data_url: str):
     """Extract the raw Base64 payload from a Data URL"""
     return data_url.split(",", 1)[-1]  # split after the first comma


-async def get_bytes_from_url(url: str) -> bytes:
+async def _get_bytes_from_url(url: str) -> bytes:
     async with httpx.AsyncClient() as client:
         response = await client.get(url)
         if response.status_code != 200:
             raise HTTPException(status_code=400, detail="Unable to download data from the given URL")
         return response.content


-async def load_base64_or_url(base64_or_url):
+def bytesio2image(bytes_io: io.BytesIO) -> Image.Image:
+    return Image.open(bytes_io)
+
+
+def bytes2image(bytes_: bytes) -> Image.Image:
+    bytes_io = io.BytesIO(bytes_)
+    return Image.open(bytes_io)
+
+
+async def load_base64_or_url(base64_or_url) -> io.BytesIO:
     # Decide how to read it based on the reference_audio content
     if base64_or_url.startswith("http://") or base64_or_url.startswith("https://"):
-        audio_bytes = await get_bytes_from_url(base64_or_url)
+        audio_bytes = await _get_bytes_from_url(base64_or_url)
     else:
         try:
             if "data:" in base64_or_url:
-                base64_or_url = extract_base64(data_url=base64_or_url)
+                base64_or_url = _extract_base64(data_url=base64_or_url)
             audio_bytes = base64.b64decode(base64_or_url)
         except Exception as e:
             logger.warning("Invalid base64 data: " + str(e))
             raise HTTPException(status_code=400, detail="Invalid base64 data: " + str(e))
-    # Wrap the bytes with BytesIO, then read them as a numpy array with soundfile
+    # Wrap the bytes with BytesIO
     try:
         bytes_io = io.BytesIO(audio_bytes)
     except Exception as e:
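
A short sketch of how the reworked helpers compose, assuming they are imported from gpt_server.model_worker.utils as in the worker above; the input file is a placeholder:

import asyncio
import base64

from gpt_server.model_worker.utils import load_base64_or_url, bytesio2image


async def main():
    # Plain base64, a data: URL, or an http(s) URL are all accepted.
    with open("input.png", "rb") as f:  # placeholder image
        data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")
    bytes_io = await load_base64_or_url(data_url)  # -> io.BytesIO
    image = bytesio2image(bytes_io)                # -> PIL.Image.Image
    print(image.size)


if __name__ == "__main__":
    asyncio.run(main())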

gpt_server/serving/openai_api_server.py

Lines changed: 30 additions & 1 deletion
@@ -714,9 +714,37 @@ async def generate_completion(payload: Dict[str, Any], worker_addr: str):
     SpeechRequest,
     OpenAISpeechRequest,
     ImagesGenRequest,
+    ImagesEditsRequest,
 )


+async def get_images_edits(payload: Dict[str, Any]):
+    model_name = payload["model"]
+    worker_addr = get_worker_address(model_name)
+
+    transcription = await fetch_remote(
+        worker_addr + "/worker_get_image_output", payload
+    )
+    return json.loads(transcription)
+
+
+@app.post("/v1/images/edits", dependencies=[Depends(check_api_key)])
+async def images_edits(request: ImagesEditsRequest):
+    """Image editing"""
+    error_check_ret = check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+    payload = {
+        "image": request.image,
+        "model": request.model,
+        "prompt": request.prompt,
+        "output_format": request.output_format,
+        "response_format": request.response_format,
+    }
+    result = await get_images_edits(payload=payload)
+    return result
+
+
 async def get_images_gen(payload: Dict[str, Any]):
     model_name = payload["model"]
     worker_addr = get_worker_address(model_name)
@@ -728,7 +756,8 @@ async def get_images_gen(payload: Dict[str, Any]):


 @app.post("/v1/images/generations", dependencies=[Depends(check_api_key)])
-async def speech(request: ImagesGenRequest):
+async def images_generations(request: ImagesGenRequest):
+    """Text-to-image generation"""
     error_check_ret = check_model(request)
     if error_check_ret is not None:
         return error_check_ret
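
To show how the new route is meant to be hit, here is a client-side sketch; the base URL, API key, and model name are placeholders, and the payload keys simply mirror the ones the endpoint forwards above:

import httpx

payload = {
    "model": "qwen-image-edit",                # placeholder deployed model name
    "prompt": "Remove the text from the sign",
    "image": "https://example.com/sign.png",   # URL or base64, per the worker
    "output_format": "png",
    "response_format": "url",                  # or "b64_json"
}
resp = httpx.post(
    "http://localhost:8082/v1/images/edits",   # placeholder gateway address
    json=payload,
    headers={"Authorization": "Bearer EMPTY"}, # placeholder API key
    timeout=300,
)
print(resp.json()["data"][0])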
