|
| 1 | +import asyncio |
| 2 | + |
| 3 | +import io |
| 4 | +import os |
| 5 | +from typing import List |
| 6 | +import uuid |
| 7 | +from loguru import logger |
| 8 | +import shortuuid |
| 9 | +from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase |
| 10 | +from gpt_server.model_worker.utils import pil_to_base64 |
| 11 | +from gpt_server.utils import STATIC_DIR |
| 12 | +import torch |
| 13 | +from diffusers import AutoencoderKLWan, WanPipeline |
| 14 | +from diffusers.utils import export_to_video |
| 15 | + |
| 16 | +root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) |
| 17 | + |
| 18 | + |
| 19 | +class WanWorker(ModelWorkerBase): |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + controller_addr: str, |
| 23 | + worker_addr: str, |
| 24 | + worker_id: str, |
| 25 | + model_path: str, |
| 26 | + model_names: List[str], |
| 27 | + limit_worker_concurrency: int, |
| 28 | + conv_template: str = None, # type: ignore |
| 29 | + ): |
| 30 | + super().__init__( |
| 31 | + controller_addr, |
| 32 | + worker_addr, |
| 33 | + worker_id, |
| 34 | + model_path, |
| 35 | + model_names, |
| 36 | + limit_worker_concurrency, |
| 37 | + conv_template, |
| 38 | + model_type="image", |
| 39 | + ) |
| 40 | + backend = os.environ["backend"] |
| 41 | + self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| 42 | + vae = AutoencoderKLWan.from_pretrained( |
| 43 | + model_path, subfolder="vae", torch_dtype=torch.float32 |
| 44 | + ) |
| 45 | + self.pipe = WanPipeline.from_pretrained( |
| 46 | + model_path, vae=vae, torch_dtype=torch.bfloat16 |
| 47 | + ).to(self.device) |
| 48 | + logger.warning(f"模型:{model_names[0]}") |
| 49 | + |
| 50 | + async def get_image_output(self, params): |
| 51 | + prompt = params["prompt"] |
| 52 | + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" |
| 53 | + output = self.pipe( |
| 54 | + prompt=prompt, |
| 55 | + negative_prompt=negative_prompt, |
| 56 | + height=480, |
| 57 | + width=832, |
| 58 | + num_frames=81, |
| 59 | + guidance_scale=5.0, |
| 60 | + ).frames[0] |
| 61 | + |
| 62 | + # 生成唯一文件名(避免冲突) |
| 63 | + file_name = str(uuid.uuid4()) + ".mp4" |
| 64 | + save_path = STATIC_DIR / file_name |
| 65 | + export_to_video(output, save_path, fps=15) |
| 66 | + WORKER_PORT = os.environ["WORKER_PORT"] |
| 67 | + WORKER_HOST = os.environ["WORKER_HOST"] |
| 68 | + url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}" |
| 69 | + result = { |
| 70 | + "created": shortuuid.random(), |
| 71 | + "data": [{"url": url}], |
| 72 | + "usage": { |
| 73 | + "total_tokens": 0, |
| 74 | + "input_tokens": 0, |
| 75 | + "output_tokens": 0, |
| 76 | + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, |
| 77 | + }, |
| 78 | + } |
| 79 | + return result |
| 80 | + |
| 81 | + |
| 82 | +if __name__ == "__main__": |
| 83 | + WanWorker.run() |
0 commit comments