99import skvideo .io
1010from rife .RIFE_HDv3 import Model
1111from huggingface_hub import hf_hub_download , snapshot_download
12-
1312logger = logging .getLogger (__name__ )
1413
1514device = "cuda" if torch .cuda .is_available () else "cpu"
@@ -20,8 +19,7 @@ def pad_image(img, scale):
2019 tmp = max (32 , int (32 / scale ))
2120 ph = ((h - 1 ) // tmp + 1 ) * tmp
2221 pw = ((w - 1 ) // tmp + 1 ) * tmp
23- padding = (0 , pw - w , 0 , ph - h )
24-
22+ padding = (0 , pw - w , 0 , ph - h )
2523 return F .pad (img , padding ), padding
2624
2725
@@ -47,15 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
4745 for b in range (samples .shape [0 ]):
4846 frame = samples [b : b + 1 ]
4947 _ , _ , h , w = frame .shape
50-
48+
5149 I0 = samples [b : b + 1 ]
5250 I1 = samples [b + 1 : b + 2 ] if b + 2 < samples .shape [0 ] else samples [- 1 :]
53-
51+
5452 I0 , padding = pad_image (I0 , upscale_amount )
5553 I0 = I0 .to (torch .float )
5654 I1 , _ = pad_image (I1 , upscale_amount )
5755 I1 = I1 .to (torch .float )
58-
56+
5957 # [c, h, w]
6058 I0_small = F .interpolate (I0 , (32 , 32 ), mode = "bilinear" , align_corners = False )
6159 I1_small = F .interpolate (I1 , (32 , 32 ), mode = "bilinear" , align_corners = False )
@@ -72,15 +70,23 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
7270 # print(f'I0 shape:{I0.shape}')
7371 # print(f'I1 shape:{I1.shape}')
7472 I1 = make_inference (model , I0 , I1 , upscale_amount , 1 )
75-
73+
7674 # print(f'I0 shape:{I0.shape}')
77- # print(f'I1[0] shape:{I1[0].shape}')
75+ # print(f'I1[0] shape:{I1[0].shape}')
7876 I1 = I1 [0 ]
79-
80- # print(f'I1[0] unpadded shape:{I1.shape}')
77+
78+ # print(f'I1[0] unpadded shape:{I1.shape}')
8179 I1_small = F .interpolate (I1 , (32 , 32 ), mode = "bilinear" , align_corners = False )
8280 ssim = ssim_matlab (I0_small [:, :3 ], I1_small [:, :3 ])
83- frame = I1 [padding [0 ] :, padding [2 ] :, : - padding [3 ], padding [1 ] :]
81+ if padding [3 ] > 0 and padding [1 ] > 0 :
82+
83+ frame = I1 [:, :, : - padding [3 ],:- padding [1 ]]
84+ elif padding [3 ] > 0 :
85+ frame = I1 [:, :, : - padding [3 ],:]
86+ elif padding [1 ] > 0 :
87+ frame = I1 [:, :, :,:- padding [1 ]]
88+ else :
89+ frame = I1
8490
8591 tmp_output = []
8692 if ssim < 0.2 :
@@ -95,7 +101,8 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
95101
96102 frame = F .interpolate (frame , size = (h , w ))
97103 output .append (frame .to (output_device ))
98- for i , tmp_frame in enumerate (tmp_output ):
104+ for i , tmp_frame in enumerate (tmp_output ):
105+
99106 # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
100107 tmp_frame = F .interpolate (tmp_frame , size = (h , w ))
101108 output .append (tmp_frame .to (output_device ))
@@ -138,7 +145,9 @@ def rife_inference_with_path(model, video_path):
138145 frame_rgb = frame [..., ::- 1 ]
139146 frame_rgb = frame_rgb .copy ()
140147 tensor = torch .from_numpy (frame_rgb ).float ().to ("cpu" , non_blocking = True ).float () / 255.0
141- pt_frame_data .append (tensor .permute (2 , 0 , 1 )) # to [c, h, w,]
148+ pt_frame_data .append (
149+ tensor .permute (2 , 0 , 1 )
150+ ) # to [c, h, w,]
142151
143152 pt_frame = torch .from_numpy (np .stack (pt_frame_data ))
144153 pt_frame = pt_frame .to (device )
@@ -167,9 +176,9 @@ def rife_inference_with_latents(model, latents):
167176 return torch .stack (rife_results )
168177
169178
170- if __name__ == "__main__" :
171- snapshot_download (repo_id = "AlexWortega/RIFE" , local_dir = "model_rife" )
172- model = load_rife_model ("model_rife" )
173-
174- video_path = rife_inference_with_path (model , "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1 .mp4" )
175- print (video_path )
179+ # if __name__ == "__main__":
180+ # snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
181+ # model = load_rife_model("model_rife")
182+
183+ # video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720 .mp4")
184+ # print(video_path)
0 commit comments