
Commit c8ec68c

Update convert_cogvideox_to_diffusers.py
1 parent ed8bda9 commit c8ec68c

File tree

1 file changed: +12 -9 lines changed

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 12 additions & 9 deletions
@@ -134,12 +134,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 
 
 def convert_transformer(
-    ckpt_path: str,
-    num_layers: int,
-    num_attention_heads: int,
-    use_rotary_positional_embeddings: bool,
-    i2v: bool,
-    dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -153,7 +153,7 @@ def convert_transformer(
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY) :]
+        new_key = key[len(PREFIX_KEY):]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
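
The whitespace change in the slice is cosmetic; the loop around it is what strips the "model.diffusion_model." prefix and applies the rename map to every checkpoint key. A minimal, self-contained sketch of that pass, using a stand-in rename entry and a dummy state dict rather than the script's full TRANSFORMER_KEYS_RENAME_DICT:

from typing import Any, Dict

PREFIX_KEY = "model.diffusion_model."
RENAME_DICT = {"final_layernorm": "norm_final"}  # illustrative entry, not the script's full map


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Re-key the entry without copying the underlying value.
    state_dict[new_key] = state_dict.pop(old_key)


# Dummy checkpoint; real values would be torch tensors.
state_dict = {"model.diffusion_model.final_layernorm.weight": 0}

for key in list(state_dict.keys()):
    new_key = key[len(PREFIX_KEY):]  # drop the "model.diffusion_model." prefix
    for replace_key, rename_key in RENAME_DICT.items():
        new_key = new_key.replace(replace_key, rename_key)
    update_state_dict_inplace(state_dict, key, new_key)

print(state_dict)  # {'norm_final.weight': 0}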
@@ -241,7 +241,7 @@ def get_args():
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
 
-    text_encoder_id = "google/t5-v1_1-xxl"
+    text_encoder_id = "/share/official_pretrains/hf_home//t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
     # Apparently, the conversion does not work anymore without this :shrug:
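
The one-line change above swaps the "google/t5-v1_1-xxl" Hub repo id for a local directory; from_pretrained accepts either form, so nothing else needs to change. A small sketch of the same loading calls, with a placeholder path and TOKENIZER_MAX_LENGTH assumed to match the 226-token constant the script defines:

from transformers import T5EncoderModel, T5Tokenizer

# Assumed to mirror the script's constant (226 tokens for CogVideoX).
TOKENIZER_MAX_LENGTH = 226

# Either a Hub repo id ("google/t5-v1_1-xxl") or a local directory works here.
# The path below is a placeholder, not the one from the commit.
text_encoder_id = "/path/to/local/t5-v1_1-xxl"

tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id)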
@@ -283,4 +283,7 @@ def get_args():
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
-    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+    # This is necessary for users with insufficient memory,
+    # such as those using Colab and notebooks, as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
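
Besides the added comment, the functional change is max_shard_size="5GB": save_pretrained then writes any weight file larger than about 5 GB as several smaller shards plus an index, which, per the commit comment, can save some memory when the checkpoint is later loaded on constrained hosts such as Colab. A short sketch of the same option on its own; the repo id and output directory are placeholders:

from diffusers import CogVideoXPipeline

# Placeholder repo id and output directory, for illustration only.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b")

# With max_shard_size, large weight files are split into several smaller
# .safetensors shards plus an index file instead of one monolithic file.
pipe.save_pretrained("./cogvideox-2b-local", safe_serialization=True, max_shard_size="5GB")

# Loading back works the same whether or not the checkpoint is sharded.
pipe = CogVideoXPipeline.from_pretrained("./cogvideox-2b-local")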
