 import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
+from huggingface_hub import hf_hub_download, snapshot_download

 logger = logging.getLogger(__name__)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"


@@ -18,8 +20,9 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, 0, pw - w, ph - h)
-    return F.pad(img, padding)
+    padding = (0, pw - w, 0, ph - h)
+
+    return F.pad(img, padding), padding


 def make_inference(model, I0, I1, upscale_amount, n):
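Note on the padding change above: for a 4-D input, F.pad reads a 4-tuple as (left, right, top, bottom), so the new tuple (0, pw - w, 0, ph - h) pads only the right and bottom edges, and returning it lets callers undo the padding later. A minimal sketch of the intended round trip (the names img and restored are illustrative, not part of the diff):

    padded, padding = pad_image(img, scale)            # img: [n, c, h, w]
    _, _, ph, pw = padded.shape
    restored = padded[:, :, : ph - padding[3], : pw - padding[1]]
    assert restored.shape == img.shape                 # bottom/right padding removed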
@@ -36,31 +39,49 @@ def make_inference(model, I0, I1, upscale_amount, n):
 
 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-
+    print(f"samples dtype:{samples.dtype}")
+    print(f"samples shape:{samples.shape}")
     output = []
+    pbar = utils.ProgressBar(samples.shape[0], desc="RIFE inference")
     # [f, c, h, w]
     for b in range(samples.shape[0]):
         frame = samples[b : b + 1]
         _, _, h, w = frame.shape
+
         I0 = samples[b : b + 1]
         I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
-        I1 = pad_image(I1, upscale_amount)
+
+        # pad both frames to the RIFE block size and upcast to float32 before inference
+        I0, padding = pad_image(I0, upscale_amount)
+        I0 = I0.to(torch.float)
+        I1, _ = pad_image(I1, upscale_amount)
+        I1 = I1.to(torch.float)
+
         # [c, h, w]
         I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
         I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
 
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
 
         if ssim > 0.996:
-            I1 = I0
-            I1 = pad_image(I1, upscale_amount)
+            I1 = samples[b : b + 1]
+            I1, padding = pad_image(I1, upscale_amount)
             I1 = make_inference(model, I0, I1, upscale_amount, 1)
 
-            I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
-            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            frame = I1[0]
             I1 = I1[0]
 
+            I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
+            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+            # crop away the bottom/right padding added by pad_image (safe when the padding is zero)
+            frame = I1[:, :, : I1.shape[2] - padding[3], : I1.shape[3] - padding[1]]
+
         tmp_output = []
         if ssim < 0.2:
             for i in range((2**exp) - 1):
@@ -69,10 +90,16 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
         else:
             tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
 
-        frame = pad_image(frame, upscale_amount)
-        tmp_output = [frame] + tmp_output
-        for i, frame in enumerate(tmp_output):
-            output.append(frame.to(output_device))
+        frame, _ = pad_image(frame, upscale_amount)
+        # resize the key frame and every interpolated frame back to the input resolution
+        frame = F.interpolate(frame, size=(h, w))
+        output.append(frame.to(output_device))
+        for i, tmp_frame in enumerate(tmp_output):
+            tmp_frame = F.interpolate(tmp_frame, size=(h, w))
+            output.append(tmp_frame.to(output_device))
+        pbar.update(1)
     return output


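A minimal usage sketch for the updated ssim_interpolation_rife (the input tensor below is made up, and model is assumed to be a RIFE model loaded elsewhere in this file via load_rife_model):

    samples = torch.rand(8, 3, 480, 720)          # fake clip: 8 RGB frames in [0, 1], layout [f, c, h, w]
    frames = ssim_interpolation_rife(model, samples, exp=1)
    # each input frame is emitted followed by 2**exp - 1 interpolated frames, all resized
    # back to 480x720, so len(frames) == 16 and each entry has shape [1, 3, 480, 720]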
@@ -94,14 +121,24 @@ def frame_generator(video_capture):
 
 
 def rife_inference_with_path(model, video_path):
+    # Open the video file
     video_capture = cv2.VideoCapture(video_path)
-    tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
+    fps = video_capture.get(cv2.CAP_PROP_FPS)  # Get the frames per second
+    tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))  # Total frames in the video
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
-    for frame in pt_frame:
-        pt_frame_data.append(
-            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
-        )
+    # Read frames sequentially until the stream ends
+    while video_capture.isOpened():
+        ret, frame = video_capture.read()
+
+        if not ret:
+            break
+
+        # BGR (OpenCV's default channel order) to RGB
+        frame_rgb = frame[..., ::-1]
+        frame_rgb = frame_rgb.copy()
+        tensor = torch.from_numpy(frame_rgb).to("cpu", non_blocking=True).float() / 255.0
+        pt_frame_data.append(tensor.permute(2, 0, 1))  # [h, w, c] -> [c, h, w]
 
     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
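Side note on the new frame loop: frame[..., ::-1] produces a negatively-strided view, which torch.from_numpy rejects, hence the explicit .copy(). A purely illustrative equivalent uses cv2.cvtColor, which already returns a contiguous RGB array:

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(frame_rgb).float() / 255.0   # [h, w, c] in [0, 1]
    pt_frame_data.append(tensor.permute(2, 0, 1))          # [c, h, w]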
@@ -122,8 +159,17 @@ def rife_inference_with_latents(model, latents):
     for i in range(latents.size(0)):
         # [f, c, h, w]
         latent = latents[i]
+
         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # back to [f, c, h, w]
         rife_results.append(pt_image)
 
     return torch.stack(rife_results)
+
+
+if __name__ == "__main__":
+    snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
+    model = load_rife_model("model_rife")
+
+    video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4")
+    print(video_path)
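The test clip path in the __main__ block is specific to the author's machine. A small, illustrative alternative (the command-line handling is an assumption, not part of this commit) reads the clip path from argv and falls back to that default:

    import sys

    clip_path = sys.argv[1] if len(sys.argv) > 1 else "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4"
    video_path = rife_inference_with_path(model, clip_path)
    print(video_path)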