Commit 17996f1 ("update")

Parent: 5e3e3aa

File tree: 5 files changed (+79, -26 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ models we currently offer, along with their foundational information.
 </tr>
 <tr>
 <td style="text-align: center;">Inference Precision</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
 <td style="text-align: center;"><b>FP16*(Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
 <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
 </tr>
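
The README change above documents the supported inference precisions. As a minimal sketch of what the recommended setting means in practice (the diffusers CogVideoXPipeline API and the checkpoint id are assumptions, not part of this diff):

    import torch
    from diffusers import CogVideoXPipeline

    # Load in the table's recommended precision: BF16 for the 5B-class models,
    # FP16 for the 2B model. "THUDM/CogVideoX-5b" is an illustrative checkpoint id.
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        torch_dtype=torch.bfloat16,  # torch.float16 for the 2B model
    )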

README_ja.md

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
 </tr>
 <tr>
 <td style="text-align: center;">推論精度</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32,FP8*,INT8,INT4非対応</td>
 <td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32,FP8*,INT8,INT4非対応</td>
 <td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32,FP8*,INT8,INT4非対応</td>
 </tr>

README_zh.md

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
 </tr>
 <tr>
 <td style="text-align: center;">推理精度</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32,FP8*,INT8,不支持INT4</td>
 <td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32,FP8*,INT8,不支持INT4</td>
 <td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32,FP8*,INT8,不支持INT4</td>
 </tr>

inference/cli_demo.py

Lines changed: 2 additions & 5 deletions

@@ -103,16 +103,13 @@ def generate_video(
     # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
     # and enable to("cuda")

-    pipe.to("cuda")
-
-    # pipe.enable_sequential_cpu_offload()
-
+    # pipe.to("cuda")
+    pipe.enable_sequential_cpu_offload()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()

     # 4. Generate the video frames based on the prompt.
     # `num_frames` is the Number of frames to generate.
-    # This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames.
     if generate_type == "i2v":
         video_generate = pipe(
             height=height,
tools/convert_weight_sat2hf.py

Lines changed: 74 additions & 18 deletions

@@ -92,6 +92,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",

@@ -146,12 +148,13 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:


 def convert_transformer(
-        ckpt_path: str,
-        num_layers: int,
-        num_attention_heads: int,
-        use_rotary_positional_embeddings: bool,
-        i2v: bool,
-        dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."

@@ -161,11 +164,13 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY):]
+        new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
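
The two new constructor arguments encode the variant dispatch: a CogVideoX 1.5 checkpoint patches along the temporal axis (patch_size_t is set) and gains a 512-dim ofs embedding, while the 1.0 I2V model keeps learned positional embeddings. Condensed into a standalone helper for clarity (an illustration of the logic above, not a function in the script):

    from typing import Optional, Tuple

    def embedding_config(i2v: bool, patch_size_t: Optional[int]) -> Tuple[Optional[int], bool]:
        # CogVideoX1.5-5B-I2V: temporal patching present -> 512-dim ofs embedding.
        ofs_embed_dim = 512 if (i2v and patch_size_t is not None) else None
        # CogVideoX-5B-I2V: no temporal patching -> learned positional embeddings.
        use_learned = i2v and patch_size_t is None
        return ofs_embed_dim, use_learned
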
@@ -175,13 +180,18 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
+
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer


-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
         new_key = key[:]

@@ -199,6 +209,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae


+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 768 // vae_scale_factor_spatial,
+            "sample_width": 1360 // vae_scale_factor_spatial,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
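
Evaluating the new helper by hand shows how the two versions diverge (60 = 480 // 8, 90 = 720 // 8, 96 = 768 // 8, 170 = 1360 // 8):

    >>> get_transformer_init_kwargs("1.0")   # 480x720, 49 frames
    {'patch_size': 2, 'patch_size_t': None, 'patch_bias': True,
     'sample_height': 60, 'sample_width': 90, 'sample_frames': 49}
    >>> get_transformer_init_kwargs("1.5")   # 768x1360, 81 frames
    {'patch_size': 2, 'patch_size_t': 2, 'patch_bias': False,
     'sample_height': 96, 'sample_width': 170, 'sample_frames': 81}
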
@@ -214,6 +252,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48

@@ -226,7 +270,18 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
     return parser.parse_args()

@@ -242,21 +297,27 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
             args.num_attention_heads,
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

-    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+    text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
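
Putting the new flags together, converting a CogVideoX1.5-5B-I2V checkpoint might look like the following. This is a hypothetical invocation assembled only from flags visible in this diff; the paths are placeholders, and the script's remaining options (such as where the converted pipeline is saved) are omitted:

    python tools/convert_weight_sat2hf.py \
        --transformer_ckpt_path /path/to/sat/transformer.pt \
        --vae_ckpt_path /path/to/sat/vae.pt \
        --version 1.5 --i2v \
        --num_layers 42 --num_attention_heads 48 \
        --use_rotary_positional_embeddings \
        --bf16 --typecast_text_encoder
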
@@ -288,11 +349,6 @@ def get_args():
         scheduler=scheduler,
     )

-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)
-
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
