@@ -135,14 +135,14 @@ def sampling_main(args, model_cls):
     sample_func = model.sample
     num_samples = [1]
     force_uc_zero_embeddings = ["txt"]
-
+    T, C = args.sampling_num_frames, args.latent_channels
     with torch.no_grad():
         for text, cnt in tqdm(data_iter):
             if args.image2video:
                 # use with input image shape
-                text, image_path = text.split('@@')
+                text, image_path = text.split("@@")
                 assert os.path.exists(image_path), image_path
-                image = Image.open(image_path).convert('RGB')
+                image = Image.open(image_path).convert("RGB")
                 (img_W, img_H) = image.size

                 def nearest_multiple_of_16(n):
@@ -163,7 +163,7 @@ def nearest_multiple_of_16(n):
                 chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
                 chained_trainsforms.append(TT.ToTensor())
                 transform = TT.Compose(chained_trainsforms)
-                image = transform(image).unsqueeze(0).to('cuda')
+                image = transform(image).unsqueeze(0).to("cuda")
                 image = image * 2.0 - 1.0
                 image = image.unsqueeze(2).to(torch.bfloat16)
                 image = model.encode_first_stage(image, None)
@@ -173,7 +173,7 @@ def nearest_multiple_of_16(n):
                 image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
             else:
                 image_size = args.sampling_image_size
-                T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
+                H, W = image_size[0], image_size[1]
                 F = 8  # 8x downsampled
                 image = None

@@ -183,11 +183,7 @@ def nearest_multiple_of_16(n):
             src = global_rank * mp_size
             torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
             text = text_cast[0]
-            value_dict = {
-                'prompt': text,
-                'negative_prompt': '',
-                'num_frames': torch.tensor(T).unsqueeze(0)
-            }
+            value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}

             batch, batch_uc = get_batch(
                 get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@@ -216,19 +212,15 @@ def nearest_multiple_of_16(n):
             for index in range(args.batch_size):
                 if args.image2video:
                     samples_z = sample_func(
-                        c,
-                        uc=uc,
-                        batch_size=1,
-                        shape=(T, C, H, W),
-                        ofs=torch.tensor([2.0]).to('cuda')
+                        c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
                     )
                 else:
                     samples_z = sample_func(
                         c,
                         uc=uc,
                         batch_size=1,
                         shape=(T, C, H // F, W // F),
-                    ).to('cuda')
+                    ).to("cuda")

                 samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
                 if args.only_save_latents:
@@ -250,11 +242,12 @@ def nearest_multiple_of_16(n):
                     if mpu.get_model_parallel_rank() == 0:
                         save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

-if __name__ == '__main__':
-    if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
-        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
-        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
-        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
+
+if __name__ == "__main__":
+    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
     py_parser = argparse.ArgumentParser(add_help=False)
     known, args_list = py_parser.parse_known_args()
