forked from Eco-Sphere/cache-dit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_wan_2.2.py
More file actions
158 lines (139 loc) · 4.42 KB
/
run_wan_2.2.py
File metadata and controls
158 lines (139 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import sys
sys.path.append("..")
import time
import torch
import diffusers
from diffusers import WanPipeline, AutoencoderKLWan, WanTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import (
UniPCMultistepScheduler,
)
from utils import get_args, GiB, strify, cachify
import cache_dit
args = get_args()
print(args)

height, width = 480, 832

# Shard the pipeline across GPUs with accelerate's "balanced" device map only
# when several GPUs are present and `GiB()` (local utils helper, presumably
# per-device memory in GiB -- TODO confirm) reports at most 48 GiB.
use_balanced_device_map = torch.cuda.device_count() > 1 and GiB() <= 48

pipe = WanPipeline.from_pretrained(
    os.environ.get(
        "WAN_2_2_DIR",
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    ),
    torch_dtype=torch.bfloat16,
    # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
    device_map="balanced" if use_balanced_device_map else None,
)

# Without a device map, from_pretrained leaves every component on the CPU.
# Move the pipeline to the GPU explicitly so the benchmark does not silently
# run (and get timed) on the CPU.
# NOTE(review): on a single GPU too small for both 14B experts this will OOM;
# `pipe.enable_model_cpu_offload()` may be the better choice there -- confirm.
if not use_balanced_device_map and torch.cuda.is_available():
    pipe.to("cuda")
# Replace the pipeline's scheduler with UniPC multistep, reusing the existing
# scheduler config but overriding its flow shift: 3.0 for 480p videos,
# 5.0 otherwise (e.g. 720p).
if getattr(pipe, "scheduler", None) is not None:
    if height == 480:
        flow_shift = 3.0
    else:
        flow_shift = 5.0
    pipe.scheduler = UniPCMultistepScheduler.from_config(
        pipe.scheduler.config, flow_shift=flow_shift
    )
if args.cache:
    # Imported lazily so cache_dit's adapter classes are only pulled in when
    # caching is requested.
    from cache_dit import (
        BlockAdapter,
        DBCacheConfig,
        ForwardPattern,
        ParamsModifier,
    )

    # Both transformers are registered with a single BlockAdapter; every list
    # argument below holds one entry per transformer, in this same order.
    experts = [pipe.transformer, pipe.transformer_2]

    # Per-transformer cache budgets. The first entry is the high-noise
    # transformer, which only runs ~30% of the steps: 4 warmup / 8 cached
    # steps, versus 2 warmup / 20 cached steps for the second one.
    high_noise_modifier = ParamsModifier(
        cache_config=DBCacheConfig().reset(max_warmup_steps=4, max_cached_steps=8),
    )
    low_noise_modifier = ParamsModifier(
        cache_config=DBCacheConfig().reset(max_warmup_steps=2, max_cached_steps=20),
    )

    adapter = BlockAdapter(
        pipe=pipe,
        transformer=experts,
        blocks=[expert.blocks for expert in experts],
        forward_pattern=[ForwardPattern.Pattern_2] * len(experts),
        params_modifiers=[high_noise_modifier, low_noise_modifier],
        has_separate_cfg=True,
    )
    cachify(args, adapter)
# Wan currently requires installing diffusers from source.
assert isinstance(pipe.vae, AutoencoderKLWan)  # enable type check for IDE


def _release_tuple(version: str) -> tuple:
    """Return the numeric (major, minor, micro) release of a version string.

    Leading digits of each dot-separated component are used, so "0.35.0.dev0"
    becomes (0, 35, 0) and "0.36.0rc1" becomes (0, 36, 0). Parsing stops at
    the first component without leading digits; missing parts are padded
    with 0.
    """
    release = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        release.append(int(digits))
    release.extend([0] * (3 - len(release)))
    return tuple(release)


# Compare numerically: a plain string comparison would rank "0.4.0" above
# "0.34.0" and "0.100.0" below it.
if _release_tuple(diffusers.__version__) >= (0, 34, 0):
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
else:
    print(
        "Wan pipeline requires diffusers version >= 0.34.0 "
        "for vae tiling and slicing, please install diffusers "
        "from source."
    )
# Narrow both transformer types (also helps IDE type checking).
assert isinstance(pipe.transformer, WanTransformer3DModel)
assert isinstance(pipe.transformer_2, WanTransformer3DModel)

if args.quantize:
    quant_type = args.quantize_type
    assert isinstance(quant_type, str)

    # Weight-only schemes (type names ending in "wo") are also applied to the
    # first (high-noise) transformer.
    if quant_type.endswith("wo"):
        pipe.transformer = cache_dit.quantize(
            pipe.transformer, quant_type=quant_type
        )

    # Any scheme, including activation quantization (default: FP8 DQ), is
    # applied to the low-noise transformer only, to avoid a non-trivial
    # precision downgrade on the high-noise one.
    pipe.transformer_2 = cache_dit.quantize(
        pipe.transformer_2, quant_type=quant_type
    )
# Compile when requested, and always after quantization: apply cache_dit's
# compile settings, then compile each transformer's repeated blocks with
# fullgraph=True (high-noise first, then low-noise).
if args.compile or args.quantize:
    cache_dit.set_compile_configs()
    for expert in (pipe.transformer, pipe.transformer_2):
        expert.compile_repeated_blocks(fullgraph=True)
def _generate_video():
    """Run one seeded text-to-video generation and return its frames.

    The warm-up and the timed run share this helper, so both issue exactly
    the same pipeline call.
    """
    return pipe(
        prompt=(
            "An astronaut dancing vigorously on the moon with earth "
            "flying past in the background, hyperrealistic"
        ),
        negative_prompt="",
        height=height,
        width=width,
        num_frames=81,
        num_inference_steps=50,
        # A fresh, identically seeded generator per call keeps runs reproducible.
        generator=torch.Generator("cpu").manual_seed(0),
    ).frames[0]


# Warm-up run (e.g. triggers torch.compile when enabled); excluded from timing.
_generate_video()

# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
video = _generate_video()
end = time.perf_counter()

cache_dit.summary(pipe, details=True)

time_cost = end - start
save_path = f"wan2.2.{strify(args, pipe)}.mp4"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving video to {save_path}")
export_to_video(video, save_path, fps=16)