Skip to content

Commit c8c7b62

Browse files
update diffusers code
1 parent a8205b5 commit c8c7b62

File tree

3 files changed

+49
-62
lines changed

3 files changed

+49
-62
lines changed

inference/cli_demo.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
44
and video-to-video (v2v), depending on the input data and different weight.
55
6-
- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
7-
- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
8-
- image-to-video: THUDM/CogVideoX-5b-I2V
6+
- text-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
7+
- video-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
8+
- image-to-video: THUDM/CogVideoX-5b-I2V or THUDM/CogVideoX1.5-5b-I2V
99
1010
Running the Script:
1111
To run the script, use the following command with appropriate arguments:
1212
1313
```bash
14-
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
14+
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
1515
```
1616
1717
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
@@ -23,7 +23,6 @@
2323
import torch
2424
from diffusers import (
2525
CogVideoXPipeline,
26-
CogVideoXDDIMScheduler,
2726
CogVideoXDPMScheduler,
2827
CogVideoXImageToVideoPipeline,
2928
CogVideoXVideoToVideoPipeline,
@@ -37,6 +36,7 @@ def generate_video(
3736
model_path: str,
3837
lora_path: str = None,
3938
lora_rank: int = 128,
39+
num_frames=81,
4040
output_path: str = "./output.mp4",
4141
image_or_video_path: str = "",
4242
num_inference_steps: int = 50,
@@ -45,6 +45,7 @@ def generate_video(
4545
dtype: torch.dtype = torch.bfloat16,
4646
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
4747
seed: int = 42,
48+
fps: int = 8,
4849
):
4950
"""
5051
Generates a video based on the given prompt and saves it to the specified path.
@@ -56,11 +57,13 @@ def generate_video(
5657
- lora_rank (int): The rank of the LoRA weights.
5758
- output_path (str): The path where the generated video will be saved.
5859
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
60+
- num_frames (int): Number of frames to generate.
5961
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
6062
- num_videos_per_prompt (int): Number of videos to generate per prompt.
6163
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
6264
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').·
6365
- seed (int): The seed for reproducibility.
66+
- fps (int): The frames per second for the generated video.
6467
"""
6568

6669
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
@@ -109,11 +112,11 @@ def generate_video(
109112
if generate_type == "i2v":
110113
video_generate = pipe(
111114
prompt=prompt,
112-
image=image, # The path of the image to be used as the background of the video
115+
image=image, # The path of the image, the resolution of video will be the same as the image for CogVideoX1.5-5B-I2V, otherwise it will be 720 * 480
113116
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
114117
num_inference_steps=num_inference_steps, # Number of inference steps
115-
num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.30.3` and after.
116-
use_dynamic_cfg=True, # This id used for DPM Sechduler, for DDIM scheduler, it should be False
118+
num_frames=num_frames, # Number of frames to generate
119+
use_dynamic_cfg=True, # This id used for DPM scheduler, for DDIM scheduler, it should be False
117120
guidance_scale=guidance_scale,
118121
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
119122
).frames[0]
@@ -122,7 +125,7 @@ def generate_video(
122125
prompt=prompt,
123126
num_videos_per_prompt=num_videos_per_prompt,
124127
num_inference_steps=num_inference_steps,
125-
num_frames=49,
128+
num_frames=num_frames,
126129
use_dynamic_cfg=True,
127130
guidance_scale=guidance_scale,
128131
generator=torch.Generator().manual_seed(seed),
@@ -133,13 +136,12 @@ def generate_video(
133136
video=video, # The path of the video to be used as the background of the video
134137
num_videos_per_prompt=num_videos_per_prompt,
135138
num_inference_steps=num_inference_steps,
136-
# num_frames=49,
139+
num_frames=num_frames,
137140
use_dynamic_cfg=True,
138141
guidance_scale=guidance_scale,
139142
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
140143
).frames[0]
141-
# 5. Export the generated frames to a video file. fps must be 8 for original video.
142-
export_to_video(video_generate, output_path, fps=8)
144+
export_to_video(video_generate, output_path, fps=fps)
143145

144146

145147
if __name__ == "__main__":
@@ -152,24 +154,18 @@ def generate_video(
152154
help="The path of the image to be used as the background of the video",
153155
)
154156
parser.add_argument(
155-
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
157+
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model use"
156158
)
157159
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
158160
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
159-
parser.add_argument(
160-
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
161-
)
161+
parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
162162
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
163-
parser.add_argument(
164-
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
165-
)
163+
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
164+
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
165+
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
166166
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
167-
parser.add_argument(
168-
"--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
169-
)
170-
parser.add_argument(
171-
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
172-
)
167+
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
168+
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
173169
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
174170

175171
args = parser.parse_args()
@@ -180,11 +176,13 @@ def generate_video(
180176
lora_path=args.lora_path,
181177
lora_rank=args.lora_rank,
182178
output_path=args.output_path,
179+
num_frames=args.num_frames,
183180
image_or_video_path=args.image_or_video_path,
184181
num_inference_steps=args.num_inference_steps,
185182
guidance_scale=args.guidance_scale,
186183
num_videos_per_prompt=args.num_videos_per_prompt,
187184
dtype=dtype,
188185
generate_type=args.generate_type,
189186
seed=args.seed,
187+
fps=args.fps,
190188
)

inference/cli_demo_quantization.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
Note:
55
6-
Must install the `torchao`,`torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature.
6+
Must install the `torchao`,`torch` library FROM SOURCE to use the quantization feature.
77
Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization.
88
99
ALL quantization schemes must use with NVIDIA GPUs.
@@ -51,6 +51,9 @@ def generate_video(
5151
num_videos_per_prompt: int = 1,
5252
quantization_scheme: str = "fp8",
5353
dtype: torch.dtype = torch.bfloat16,
54+
num_frames: int = 81,
55+
fps: int = 8,
56+
seed: int = 42,
5457
):
5558
"""
5659
Generates a video based on the given prompt and saves it to the specified path.
@@ -65,7 +68,6 @@ def generate_video(
6568
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
6669
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
6770
"""
68-
6971
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
7072
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
7173
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
@@ -80,54 +82,38 @@ def generate_video(
8082
torch_dtype=dtype,
8183
)
8284
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
83-
84-
# Using with compile will run faster. First time infer will cost ~30min to compile.
85-
# pipe.transformer.to(memory_format=torch.channels_last)
86-
87-
# for FP8 should remove pipe.enable_model_cpu_offload()
8885
pipe.enable_model_cpu_offload()
89-
90-
# This is not for FP8 and INT8 and should remove this line
91-
# pipe.enable_sequential_cpu_offload()
9286
pipe.vae.enable_slicing()
9387
pipe.vae.enable_tiling()
88+
9489
video = pipe(
9590
prompt=prompt,
9691
num_videos_per_prompt=num_videos_per_prompt,
9792
num_inference_steps=num_inference_steps,
98-
num_frames=49,
93+
num_frames=num_frames,
9994
use_dynamic_cfg=True,
10095
guidance_scale=guidance_scale,
101-
generator=torch.Generator(device="cuda").manual_seed(42),
96+
generator=torch.Generator(device="cuda").manual_seed(seed),
10297
).frames[0]
10398

104-
export_to_video(video, output_path, fps=8)
99+
export_to_video(video, output_path, fps=fps)
105100

106101

107102
if __name__ == "__main__":
108103
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
109104
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
105+
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model")
106+
parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video")
107+
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
108+
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
109+
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt")
110+
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')")
110111
parser.add_argument(
111-
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
112-
)
113-
parser.add_argument(
114-
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
115-
)
116-
parser.add_argument(
117-
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
118-
)
119-
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
120-
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
121-
parser.add_argument(
122-
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16', 'bfloat16')"
123-
)
124-
parser.add_argument(
125-
"--quantization_scheme",
126-
type=str,
127-
default="bf16",
128-
choices=["int8", "fp8"],
129-
help="The quantization scheme to use (int8, fp8)",
112+
"--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme"
130113
)
114+
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
115+
parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
116+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
131117

132118
args = parser.parse_args()
133119
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
@@ -140,4 +126,7 @@ def generate_video(
140126
num_videos_per_prompt=args.num_videos_per_prompt,
141127
quantization_scheme=args.quantization_scheme,
142128
dtype=dtype,
129+
num_frames=args.num_frames,
130+
fps=args.fps,
131+
seed=args.seed,
143132
)

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
diffusers>=0.31.0
2-
accelerate>=1.0.1
3-
transformers>=4.46.1
2+
accelerate>=1.1.1
3+
transformers>=4.46.2
44
numpy==1.26.0
55
torch>=2.5.0
66
torchvision>=0.20.0
77
sentencepiece>=0.2.0
88
SwissArmyTransformer>=0.4.12
9-
gradio>=5.4.0
9+
gradio>=5.5.0
1010
imageio>=2.35.1
1111
imageio-ffmpeg>=0.5.1
12-
openai>=1.53.0
12+
openai>=1.54.0
1313
moviepy>=1.0.3
14-
scikit-video>=1.1.11
14+
scikit-video>=1.1.11

0 commit comments

Comments
 (0)