Skip to content

Commit ee87d4f

Browse files
committed
Integrate Zhipu AI
1 parent 3f1220b commit ee87d4f

File tree

3 files changed

+135
-28
lines changed

3 files changed

+135
-28
lines changed

api/zhipu.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import logging
2+
import time
3+
4+
from zhipuai import ZhipuAI
5+
6+
# 设置日志
7+
logging.basicConfig(level=logging.WARNING)
8+
logger = logging.getLogger(__name__)
9+
10+
# 初始化客户端(API Key 通过环境变量传入)
11+
client = ZhipuAI()
12+
13+
14+
def generate_zhipu(prompt, image_url=None, model="cogvideox-2", quality="speed", with_audio=False, size="1920x1080",
                   fps=30, poll_interval=2, timeout=900):
    """Create a Zhipu AI video-generation task and wait for its result.

    Args:
        prompt: Text description of the video to generate.
        image_url: Optional image URL; when truthy it is sent as ``image_url``
            so the request becomes image-to-video.
        model: Model identifier passed to the API.
        quality: Output mode passed to the API as ``quality``.
        with_audio: Passed to the API as ``with_audio``.
        size: Resolution string; omitted for ``cogvideox-flash``.
        fps: Frame rate; omitted for ``cogvideox-flash``.
        poll_interval: Seconds to sleep between status polls.
        timeout: Maximum number of seconds to wait for a terminal status.

    Returns:
        A ``(video_url, message)`` tuple. ``video_url`` is ``None`` whenever
        the task fails, times out, yields no URL, or an exception is raised.
    """
    # Terminal failure spellings. The vendor API is documented as reporting
    # "FAIL"; "FAILED" is kept because the previous implementation checked it.
    failure_states = ("FAIL", "FAILED")

    try:
        params = {
            "model": model,
            "prompt": prompt,
            "quality": quality,
            "with_audio": with_audio,
        }
        # size/fps are not sent for the flash model.
        if model != "cogvideox-flash":
            params["size"] = size
            params["fps"] = fps
        # Presence of an image switches the request to image-to-video.
        if image_url:
            params["image_url"] = image_url

        # Submit the generation request.
        response = client.videos.generations(**params)
        task_id = response.id
        logger.debug("智谱AI任务创建成功: %s", task_id)

        # Poll until the task reaches a terminal state or the deadline passes.
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        deadline = time.monotonic() + timeout
        while True:
            result = client.videos.retrieve_videos_result(id=task_id)
            status = result.task_status
            logger.debug("任务状态: %s", status)
            if status == "SUCCESS" or status in failure_states:
                break
            if time.monotonic() >= deadline:
                return None, f"任务超时: 超过 {timeout} 秒仍未完成 (task_id={task_id})"
            time.sleep(poll_interval)

        # Dump the full result object to help debugging.
        logger.debug("任务结果完整内容: %s", vars(result))

        if status != "SUCCESS":
            return None, f"任务失败: {status}"

        # video_result may be a list of items or a single object with .url.
        video_result = result.video_result
        if isinstance(video_result, list) and video_result:
            video_url = video_result[0].url
        elif hasattr(video_result, "url"):
            video_url = video_result.url
        else:
            return None, "任务成功但未找到视频 URL"

        if video_url:
            return video_url, "任务完成"
        return None, "任务成功但未找到视频 URL"
    except Exception as e:
        # Boundary handler: callers expect a (None, message) tuple, never a raise.
        logger.exception("智谱AI生成失败: %s", e)
        return None, f"智谱AI生成失败: {str(e)}"

main.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,64 +7,58 @@
77
from api.ark import generate_volcengine
88
from api.bailian import generate_aliyun, MODEL_MAPPING
99
from api.tebi import upload_file_to_tebi
10+
from api.zhipu import generate_zhipu
1011

11-
# 设置全局日志级别为 WARNING,减少详细输出
12+
# 设置全局日志级别为 WARNING
1213
logging.basicConfig(level=logging.WARNING)
1314
logging.getLogger("httpx").setLevel(logging.WARNING)
1415

1516

1617
# 检查图片是否符合火山引擎要求
1718
def validate_image(file_path):
    """Check an image file against the size, format and dimension limits.

    Args:
        file_path: Path of the image file on disk.

    Returns:
        ``(True, message)`` when every check passes, otherwise
        ``(False, reason)``. Any exception raised while inspecting the file
        is converted into a ``(False, reason)`` result.
    """
    max_bytes = 10 * 1024 * 1024
    # Ordered tuple (not a set) so the error message lists formats in a
    # stable order on every run.
    valid_formats = ("JPEG", "PNG", "WEBP", "BMP", "TIFF")

    try:
        if os.path.getsize(file_path) > max_bytes:
            return False, "图片文件大小超过 10MB"

        with Image.open(file_path) as img:
            if img.format not in valid_formats:
                return False, f"图片格式不支持,仅支持 {', '.join(valid_formats)}"
            width, height = img.size

        aspect_ratio = width / height
        if not (0.4 <= aspect_ratio <= 2.5):
            return False, "图片宽高比需在 2:5 到 5:2 之间 (0.4 - 2.5)"
        if min(width, height) < 300:
            return False, "图片短边像素需大于等于 300px"
        if max(width, height) > 6000:
            return False, "图片长边像素需小于等于 6000px"
        return True, "图片验证通过"
    except Exception as e:
        # Unreadable/missing/corrupt files are reported, not raised.
        return False, f"图片验证失败: {str(e)}"
4741

4842

4943
# 判断是否需要图片上传
50-
def is_image_required(platform, aliyun_model, ark_duration):
44+
def is_image_required(platform, aliyun_model, ark_duration, zhipu_model):
    """Return whether the current selection takes an uploaded image.

    The answer depends on the platform: the selected duration for
    "火山引擎", the selected model for "阿里云百炼" and "智谱AI". Any other
    platform never requires an image.
    """
    aliyun_image_models = ["通义万相-图生视频2.1-Turbo", "通义万相-图生视频2.1-Plus"]
    zhipu_image_model = "CogVideoX-2 (图生视频)"

    if platform == "火山引擎":
        needs_image = ark_duration == 5
        return needs_image
    if platform == "阿里云百炼":
        needs_image = aliyun_model in aliyun_image_models
        return needs_image
    if platform == "智谱AI":
        needs_image = zhipu_model == zhipu_image_model
        return needs_image
    return False
5652

5753

5854
# 生成视频逻辑
59-
def generate_video(platform, prompt, image_file, aliyun_model, ark_ratio, ark_duration, bailian_size):
55+
def generate_video(platform, prompt, image_file, aliyun_model, ark_ratio, ark_duration, bailian_size, zhipu_model,
56+
zhipu_quality, zhipu_audio, zhipu_size, zhipu_fps):
6057
image_url = None
61-
if is_image_required(platform, aliyun_model, ark_duration) and image_file:
62-
# 验证图片
58+
if is_image_required(platform, aliyun_model, ark_duration, zhipu_model) and image_file:
6359
is_valid, message = validate_image(image_file)
6460
if not is_valid:
6561
return None, message
66-
67-
# 上传到 Tebi
6862
image_url = upload_file_to_tebi(image_file)
6963
if not image_url:
7064
return None, "图片上传到 Tebi 失败"
@@ -73,15 +67,18 @@ def generate_video(platform, prompt, image_file, aliyun_model, ark_ratio, ark_du
7367
video_url, status = generate_volcengine(prompt, image_url, ark_ratio, ark_duration)
7468
elif platform == "阿里云百炼":
7569
video_url, status = generate_aliyun(prompt, image_url, aliyun_model, bailian_size)
70+
elif platform == "智谱AI":
71+
model = "cogvideox-2" if "CogVideoX-2" in zhipu_model else "cogvideox-flash"
72+
video_url, status = generate_zhipu(prompt, image_url, model, zhipu_quality, zhipu_audio, zhipu_size, zhipu_fps)
7673
else:
7774
return None, "请选择有效平台"
7875

7976
return video_url, status
8077

8178

8279
# 更新图片上传区域可见性
83-
def update_image_visibility(platform, aliyun_model, ark_duration):
84-
return gr.update(visible=is_image_required(platform, aliyun_model, ark_duration))
80+
def update_image_visibility(platform, aliyun_model, ark_duration, zhipu_model):
    """Build the Gradio update that shows or hides the image upload field."""
    show_upload = is_image_required(platform, aliyun_model, ark_duration, zhipu_model)
    return gr.update(visible=show_upload)
8582

8683

8784
# Gradio 界面
@@ -90,10 +87,9 @@ def update_image_visibility(platform, aliyun_model, ark_duration):
9087
gr.Markdown("输入提示词并选择平台生成视频,支持上传图片用于文图生视频。")
9188

9289
with gr.Row():
93-
# 左侧输入区
9490
with gr.Column(scale=1):
9591
platform = gr.Dropdown(
96-
choices=["火山引擎", "阿里云百炼"],
92+
choices=["火山引擎", "阿里云百炼", "智谱AI"],
9793
label="选择平台",
9894
value="火山引擎"
9995
)
@@ -103,6 +99,12 @@ def update_image_visibility(platform, aliyun_model, ark_duration):
10399
value="通义万相-文生视频2.1-Turbo",
104100
visible=False
105101
)
102+
zhipu_model = gr.Dropdown(
103+
choices=["CogVideoX-2 (文生视频)", "CogVideoX-2 (图生视频)", "CogVideoX-Flash"],
104+
label="智谱AI模型(仅智谱AI生效)",
105+
value="CogVideoX-2 (文生视频)",
106+
visible=False
107+
)
106108
prompt = gr.Textbox(label="提示词", placeholder="请输入生成视频的描述,例如:一只猫在草地上奔跑")
107109
image_file = gr.File(label="上传图片(用于文图生视频)", type="filepath")
108110

@@ -127,11 +129,30 @@ def update_image_visibility(platform, aliyun_model, ark_duration):
127129
value="1280*720"
128130
)
129131

132+
# 智谱AI参数
133+
with gr.Group(visible=False) as zhipu_params:
134+
zhipu_quality = gr.Dropdown(
135+
choices=["speed", "quality"],
136+
label="输出模式",
137+
value="speed"
138+
)
139+
zhipu_audio = gr.Checkbox(label="生成AI音效", value=False)
140+
zhipu_size = gr.Dropdown(
141+
choices=["720x480", "1024x1024", "1280x960", "960x1280", "1920x1080", "1080x1920", "2048x1080",
142+
"3840x2160"],
143+
label="分辨率",
144+
value="1920x1080"
145+
)
146+
zhipu_fps = gr.Dropdown(
147+
choices=[30, 60],
148+
label="帧率 (FPS)",
149+
value=30
150+
)
151+
130152
with gr.Row():
131153
submit_btn = gr.Button("生成视频")
132154
clear_btn = gr.Button("清除")
133155

134-
# 右侧输出区
135156
with gr.Column(scale=1):
136157
video_output = gr.Video(label="生成结果")
137158
status_output = gr.Textbox(label="状态")
@@ -141,36 +162,57 @@ def update_image_visibility(platform, aliyun_model, ark_duration):
141162
def update_visibility(platform, selected_zhipu_model=None):
    """Compute visibility updates for every platform-specific control.

    Args:
        platform: Currently selected platform label.
        selected_zhipu_model: Currently selected Zhipu model label. When
            omitted, falls back to ``zhipu_model.value``
            (NOTE(review): presumably the dropdown's configured value rather
            than the live selection - confirm against Gradio).

    Returns:
        A tuple of ``gr.update`` objects, in the same order as the
        ``outputs`` list of the ``platform.change`` listener below.
    """
    if selected_zhipu_model is None:
        selected_zhipu_model = zhipu_model.value

    is_aliyun = platform == "阿里云百炼"
    is_zhipu = platform == "智谱AI"
    # Dropdown labels are mixed-case ("CogVideoX-Flash"), so the flash check
    # must be case-insensitive; a case-sensitive match never fires.
    is_flash = "cogvideox-flash" in str(selected_zhipu_model or "").lower()
    show_zhipu_tuning = is_zhipu and not is_flash

    return (
        gr.update(visible=is_aliyun),  # aliyun_model
        gr.update(visible=is_zhipu),  # zhipu_model
        gr.update(visible=platform == "火山引擎"),  # ark_params
        gr.update(visible=is_aliyun),  # bailian_params
        gr.update(visible=is_zhipu),  # zhipu_params
        gr.update(visible=show_zhipu_tuning),  # zhipu_quality
        gr.update(visible=show_zhipu_tuning),  # zhipu_size
        gr.update(visible=show_zhipu_tuning)  # zhipu_fps
    )


platform.change(
    fn=update_visibility,
    # Pass the live model selection so the flash check reflects the UI state.
    inputs=[platform, zhipu_model],
    outputs=[aliyun_model, zhipu_model, ark_params, bailian_params, zhipu_params, zhipu_quality, zhipu_size,
             zhipu_fps]
)
154181
platform.change(
155182
fn=update_image_visibility,
156-
inputs=[platform, aliyun_model, ark_duration],
183+
inputs=[platform, aliyun_model, ark_duration, zhipu_model],
157184
outputs=image_file
158185
)
159186
aliyun_model.change(
160187
fn=update_image_visibility,
161-
inputs=[platform, aliyun_model, ark_duration],
188+
inputs=[platform, aliyun_model, ark_duration, zhipu_model],
162189
outputs=image_file
163190
)
164191
ark_duration.change(
165192
fn=update_image_visibility,
166-
inputs=[platform, aliyun_model, ark_duration],
193+
inputs=[platform, aliyun_model, ark_duration, zhipu_model],
167194
outputs=image_file
168195
)
196+
zhipu_model.change(
197+
fn=update_image_visibility,
198+
inputs=[platform, aliyun_model, ark_duration, zhipu_model],
199+
outputs=image_file
200+
)
201+
zhipu_model.change(
    # Hide quality/size/fps when the flash model is selected. The dropdown
    # labels are mixed-case ("CogVideoX-Flash"), so compare case-insensitively;
    # `or ""` guards against a cleared (None) selection.
    fn=lambda zhipu_model: tuple(
        gr.update(visible="cogvideox-flash" not in (zhipu_model or "").lower())
        for _ in range(3)
    ),
    inputs=zhipu_model,
    outputs=[zhipu_quality, zhipu_size, zhipu_fps]
)
169210

170211
# 提交生成
171212
submit_btn.click(
172213
fn=generate_video,
173-
inputs=[platform, prompt, image_file, aliyun_model, ark_ratio, ark_duration, bailian_size],
214+
inputs=[platform, prompt, image_file, aliyun_model, ark_ratio, ark_duration, bailian_size, zhipu_model,
215+
zhipu_quality, zhipu_audio, zhipu_size, zhipu_fps],
174216
outputs=[video_output, status_output]
175217
)
176218

@@ -181,4 +223,4 @@ def update_visibility(platform):
181223
outputs=[video_output, status_output, image_file, prompt]
182224
)
183225

184-
demo.launch(server_name="0.0.0.0", quiet=True)
226+
demo.launch(server_name="0.0.0.0")

requirements.txt

116 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)