
Commit 5fcbaa9

Add partial text-to-video code (untested)

1 parent 396de4a commit 5fcbaa9

File tree

5 files changed: +115 -7 lines changed

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 7 additions & 3 deletions
@@ -5,6 +5,7 @@
 from abc import ABC, abstractmethod
 from fastapi import BackgroundTasks, Request, FastAPI
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.staticfiles import StaticFiles
 from fastchat.utils import SEQUENCE_LENGTH_KEYS
 from loguru import logger
 import os
@@ -16,11 +17,12 @@
     AutoConfig,
 )
 import uuid
-from gpt_server.utils import get_free_tcp_port
+from gpt_server.utils import get_free_tcp_port, STATIC_DIR, local_ip
 from gpt_server.model_worker.base.base_model_worker import BaseModelWorker
 
 worker = None
 app = FastAPI()
+app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 
 
 def get_context_length_(config):
@@ -263,6 +265,8 @@ def run(cls):
         controller_address = args.controller_address
 
         port = get_free_tcp_port()
+        os.environ["WORKER_PORT"] = str(port)
+        os.environ["WORKER_HOST"] = str(local_ip)
         worker_addr = f"http://{host}:{port}"
 
         @app.on_event("startup")
@@ -409,9 +413,9 @@ async def api_get_embeddings(request: Request):
     params = await request.json()
     await acquire_worker_semaphore()
     logger.debug(f"params {params}")
-    embedding = await worker.get_image_output(params)
+    result = await worker.get_image_output(params)
     release_worker_semaphore()
-    return JSONResponse(content=embedding)
+    return JSONResponse(content=result)
 
 
 @app.post("/worker_get_classify")

gpt_server/model_worker/wan.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import asyncio
+
+import io
+import os
+from typing import List
+import uuid
+from loguru import logger
+import shortuuid
+from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
+from gpt_server.model_worker.utils import pil_to_base64
+from gpt_server.utils import STATIC_DIR
+import torch
+from diffusers import AutoencoderKLWan, WanPipeline
+from diffusers.utils import export_to_video
+
+root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+
+class WanWorker(ModelWorkerBase):
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        conv_template: str = None,  # type: ignore
+    ):
+        super().__init__(
+            controller_addr,
+            worker_addr,
+            worker_id,
+            model_path,
+            model_names,
+            limit_worker_concurrency,
+            conv_template,
+            model_type="image",
+        )
+        backend = os.environ["backend"]
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        vae = AutoencoderKLWan.from_pretrained(
+            model_path, subfolder="vae", torch_dtype=torch.float32
+        )
+        self.pipe = WanPipeline.from_pretrained(
+            model_path, vae=vae, torch_dtype=torch.bfloat16
+        ).to(self.device)
+        logger.warning(f"Model: {model_names[0]}")
+
+    async def get_image_output(self, params):
+        prompt = params["prompt"]
+        negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+        output = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=480,
+            width=832,
+            num_frames=81,
+            guidance_scale=5.0,
+        ).frames[0]
+
+        # Generate a unique file name (avoids collisions)
+        file_name = str(uuid.uuid4()) + ".mp4"
+        save_path = STATIC_DIR / file_name
+        export_to_video(output, save_path, fps=15)
+        WORKER_PORT = os.environ["WORKER_PORT"]
+        WORKER_HOST = os.environ["WORKER_HOST"]
+        url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
+        result = {
+            "created": shortuuid.random(),
+            "data": [{"url": url}],
+            "usage": {
+                "total_tokens": 0,
+                "input_tokens": 0,
+                "output_tokens": 0,
+                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
+            },
+        }
+        return result
+
+
+if __name__ == "__main__":
+    WanWorker.run()
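For reference, the worker's return value follows the OpenAI images-response shape, with data[0].url pointing at the exported .mp4 that the new /static mount in model_worker_base.py serves. An illustrative instance (all values made up):

example_result = {
    "created": "mJ9sKx8pQ2nRtVwY7bCzAd",  # shortuuid.random(), not a Unix timestamp
    "data": [
        {"url": "http://192.168.1.10:21001/static/3f2a0c9e-7d41-4b2a-9c8e-1a2b3c4d5e6f.mp4"}
    ],
    "usage": {
        "total_tokens": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
    },
}

The usage counters are hard-coded to zero; video generation has no token accounting here.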

gpt_server/openai_api_protocol/custom_api_protocol.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ class ImagesGenRequest(BaseModel):
         default="png",
         description="png, jpeg, or webp",
     )
+    model_type: Literal["t2v", "t2i"] = Field(
+        default="t2i",
+        description="t2v: text-to-video, t2i: text-to-image",
+    )
 
 
 # copy from https://github.com/remsky/Kokoro-FastAPI/blob/master/api/src/routers/openai_compatible.py
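A minimal reconstruction of the request model around the new field (model, prompt, and output_format are inferred from the payload built in openai_api_server.py below; the real class may carry more fields):

from typing import Literal

from pydantic import BaseModel, Field


class ImagesGenRequest(BaseModel):
    model: str
    prompt: str
    output_format: str = Field(default="png", description="png, jpeg, or webp")
    # New in this commit: routes a request to video or image generation.
    model_type: Literal["t2v", "t2i"] = Field(
        default="t2i",
        description="t2v: text-to-video, t2i: text-to-image",
    )

Defaulting to "t2i" keeps existing image clients working unchanged; only callers that want video need to send the field.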

gpt_server/serving/openai_api_server.py

Lines changed: 1 addition & 1 deletion
@@ -732,11 +732,11 @@ async def speech(request: ImagesGenRequest):
     error_check_ret = check_model(request)
     if error_check_ret is not None:
         return error_check_ret
-
     payload = {
         "model": request.model,
         "prompt": request.prompt,
         "output_format": request.output_format,
+        "model_type": request.model_type,
     }
     result = await get_images_gen(payload=payload)
     return result
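The route decorator sits outside this hunk, so the path below is an assumption (an OpenAI-style images endpoint); the port and model name are likewise placeholders. A sketch of a text-to-video call end to end:

import requests

resp = requests.post(
    "http://localhost:8082/v1/images/generations",  # assumed path and port
    json={"model": "wan", "prompt": "a cat surfing a wave", "model_type": "t2v"},
)
print(resp.json()["data"][0]["url"])  # URL of the exported .mp4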

gpt_server/utils.py

Lines changed: 20 additions & 3 deletions
@@ -10,8 +10,13 @@
 import psutil
 from rich import print
 import signal
+from pathlib import Path
 
+ENV = os.environ
 logger.add("logs/gpt_server.log", rotation="100 MB", level="INFO")
+root_dir = Path(__file__).parent
+STATIC_DIR = root_dir / "static"
+os.makedirs(STATIC_DIR, exist_ok=True)
 
 
 def kill_child_processes(parent_pid, including_parent=False):
@@ -111,8 +116,7 @@ def start_api_server(config: dict):
 
 def get_model_types():
     model_types = []
-    root_dir = os.path.dirname(__file__)
-    model_worker_path = os.path.join(root_dir, "model_worker")
+    model_worker_path = root_dir / "model_worker"
     # Walk the directory and its subdirectories
     for root, dirs, files in os.walk(model_worker_path):
         for file in files:
@@ -352,6 +356,18 @@ def is_port_in_use(port):
     return True
 
 
+def get_physical_ip():
+    import socket
+
+    local_ip = socket.gethostbyname(socket.getfqdn(socket.gethostname()))
+    return local_ip
+
+
+try:
+    local_ip = get_physical_ip()
+except Exception as e:
+    local_ip = ENV.get("local_ip", "127.0.0.1")
+
 model_type_mapping = {
     "yi": "yi",
     "qwen": "qwen",
@@ -374,12 +390,13 @@ def is_port_in_use(port):
     from lmdeploy.archs import get_model_arch
     from lmdeploy.cli.utils import get_chat_template
 
+    print(local_ip)
     ckpt = "/home/dev/model/Qwen/Qwen3-32B/"  # internlm2
     chat_template = get_chat_template(ckpt)
     model_type = get_names_from_model(ckpt)
     arch = get_model_arch(ckpt)
+
     print(chat_template)
     # print(arch)
     print(model_type)
     print(model_type[1] == "base")
-    print()
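The local_ip resolution is best-effort: resolve the host's FQDN first, fall back to a local_ip environment variable, then to loopback. The same chain as a standalone sketch:

import os
import socket


def resolve_local_ip() -> str:
    # Mirrors get_physical_ip() plus the commit's try/except fallback.
    try:
        return socket.gethostbyname(socket.getfqdn(socket.gethostname()))
    except Exception:
        return os.environ.get("local_ip", "127.0.0.1")


print(resolve_local_ip())

On machines whose hostname maps to a loopback entry in /etc/hosts this can still return 127.x.x.x, so the local_ip environment variable serves as a manual override.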
