@@ -134,12 +134,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 
 
 def convert_transformer(
-    ckpt_path: str,
-    num_layers: int,
-    num_attention_heads: int,
-    use_rotary_positional_embeddings: bool,
-    i2v: bool,
-    dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -153,7 +153,7 @@ def convert_transformer(
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY) :]
+        new_key = key[len(PREFIX_KEY):]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
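As context for the hunk above: the loop strips the `model.diffusion_model.` prefix from every checkpoint key, applies the rename map, and moves each tensor to its new key in place. A minimal, self-contained sketch of that pattern, assuming toy inputs; the rename entries and the `rename_transformer_keys` wrapper are illustrative, not the script's actual contents:

```python
from typing import Any, Dict

# Sketch of the renaming pattern used in convert_transformer. The two rename
# entries below are hypothetical placeholders, not the script's real
# TRANSFORMER_KEYS_RENAME_DICT.
PREFIX_KEY = "model.diffusion_model."
TRANSFORMER_KEYS_RENAME_DICT = {
    "final_layernorm": "norm_final",      # hypothetical mapping
    "mixins.patch_embed": "patch_embed",  # hypothetical mapping
}


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Move the value to the new key without copying the underlying tensor.
    state_dict[new_key] = state_dict.pop(old_key)


def rename_transformer_keys(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Iterate over a snapshot of the keys, since the dict is mutated inside the loop.
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)
    return original_state_dict
```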
@@ -241,7 +241,7 @@ def get_args():
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
 
-    text_encoder_id = "google/t5-v1_1-xxl"
+    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
     # Apparently, the conversion does not work anymore without this :shrug:
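The change above swaps the Hub repo id for a hard-coded local directory. `from_pretrained` accepts either form, so a hedged sketch of the two options looks like this (the extra kwargs used in the script, `model_max_length` and `cache_dir`, are omitted for brevity):

```python
from transformers import T5EncoderModel, T5Tokenizer

# Option 1: download from the Hugging Face Hub.
text_encoder_id = "google/t5-v1_1-xxl"
# Option 2: point at a local copy of the same model, as this commit does.
# text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"

tokenizer = T5Tokenizer.from_pretrained(text_encoder_id)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id)
```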
@@ -283,4 +283,7 @@ def get_args():
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
-    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+    # This is necessary for users with insufficient memory, such as those using Colab or other notebooks,
+    # as it reduces the memory needed to load the model.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
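The newly added `max_shard_size="5GB"` splits each component's weights into multiple safetensors shards instead of a single large file, which is what lowers the memory needed to load the converted checkpoint, as the comment notes. Loading the sharded output works the same as before; a minimal sketch, with the output path as a placeholder for `args.output_path`:

```python
import torch
from diffusers import DiffusionPipeline

# from_pretrained discovers the shards through the *.safetensors.index.json
# files that save_pretrained writes next to them and reassembles the weights.
pipe = DiffusionPipeline.from_pretrained(
    "path/to/converted/output",   # placeholder for args.output_path
    torch_dtype=torch.bfloat16,   # or torch.float16, matching the conversion dtype
)
```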