 import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
-
+from huggingface_hub import hf_hub_download, snapshot_download
 logger = logging.getLogger(__name__)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"


@@ -18,8 +19,8 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, 0, pw - w, ph - h)
-    return F.pad(img, padding)
+    padding = (0, pw - w, 0, ph - h)
+    return F.pad(img, padding), padding


 def make_inference(model, I0, I1, upscale_amount, n):
@@ -36,30 +37,56 @@ def make_inference(model, I0, I1, upscale_amount, n):

 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-
+    print(f"samples dtype:{samples.dtype}")
+    print(f"samples shape:{samples.shape}")
     output = []
+    pbar = utils.ProgressBar(samples.shape[0], desc="RIFE inference")
     # [f, c, h, w]
     for b in range(samples.shape[0]):
         frame = samples[b : b + 1]
         _, _, h, w = frame.shape
+
         I0 = samples[b : b + 1]
         I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
-        I1 = pad_image(I1, upscale_amount)
+
+        I0, padding = pad_image(I0, upscale_amount)
+        I0 = I0.to(torch.float)
+        I1, _ = pad_image(I1, upscale_amount)
+        I1 = I1.to(torch.float)
+
         # [c, h, w]
         I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
         I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)

         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

         if ssim > 0.996:
-            I1 = I0
-            I1 = pad_image(I1, upscale_amount)
+            I1 = samples[b : b + 1]
+            # print(f'upscale_amount:{upscale_amount}')
+            # print(f'ssim:{upscale_amount}')
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1 shape:{I1.shape}')
+            I1, padding = pad_image(I1, upscale_amount)
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1 shape:{I1.shape}')
             I1 = make_inference(model, I0, I1, upscale_amount, 1)
-
-            I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
-            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            frame = I1[0]
+
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1[0] shape:{I1[0].shape}')
             I1 = I1[0]
+
+            # print(f'I1[0] unpadded shape:{I1.shape}')
+            I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
+            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+            if padding[3] > 0 and padding[1] > 0:
+                frame = I1[:, :, :-padding[3], :-padding[1]]
+            elif padding[3] > 0:
+                frame = I1[:, :, :-padding[3], :]
+            elif padding[1] > 0:
+                frame = I1[:, :, :, :-padding[1]]
+            else:
+                frame = I1

         tmp_output = []
         if ssim < 0.2:
@@ -69,10 +96,17 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
         else:
             tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []

-        frame = pad_image(frame, upscale_amount)
-        tmp_output = [frame] + tmp_output
-        for i, frame in enumerate(tmp_output):
-            output.append(frame.to(output_device))
+        frame, _ = pad_image(frame, upscale_amount)
+        # print(f'frame shape:{frame.shape}')
+
+        frame = F.interpolate(frame, size=(h, w))
+        output.append(frame.to(output_device))
+        for i, tmp_frame in enumerate(tmp_output):
+            # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
+            tmp_frame = F.interpolate(tmp_frame, size=(h, w))
+            output.append(tmp_frame.to(output_device))
+        pbar.update(1)
     return output


@@ -94,14 +128,26 @@ def frame_generator(video_capture):


 def rife_inference_with_path(model, video_path):
+    # Open the video file
     video_capture = cv2.VideoCapture(video_path)
-    tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
+    fps = video_capture.get(cv2.CAP_PROP_FPS)  # Get the frames per second
+    tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))  # Total frames in the video
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
-    for frame in pt_frame:
+    # Read the video frame by frame
+    while video_capture.isOpened():
+        ret, frame = video_capture.read()
+        if not ret:
+            break
+
+        # BGR to RGB
+        frame_rgb = frame[..., ::-1]
+        frame_rgb = frame_rgb.copy()
+        tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
         pt_frame_data.append(
-            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
-        )
+            tensor.permute(2, 0, 1)
+        )  # to [c, h, w]

     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
@@ -122,8 +168,17 @@ def rife_inference_with_latents(model, latents):
     for i in range(latents.size(0)):
         # [f, c, w, h]
         latent = latents[i]
+
         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
         rife_results.append(pt_image)

     return torch.stack(rife_results)
+
+
+# if __name__ == "__main__":
+#     snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
+#     model = load_rife_model("model_rife")
+
+#     video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4")
+#     print(video_path)
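
As a side note (not part of the commit), here is a minimal standalone sketch of the padding convention the reworked pad_image relies on: torch.nn.functional.pad with a 4-tuple pads the last two dimensions as (left, right, top, bottom), so padding[1] is the right pad and padding[3] is the bottom pad, and the crop in ssim_interpolation_rife recovers the original spatial size. The shapes below are arbitrary illustration values.

# Sketch only: checks the pad/crop round-trip used in the diff above; shapes are arbitrary.
import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 45, 70)     # [n, c, h, w], h and w not multiples of 32
_, _, h, w = img.shape
tmp = 32
ph = ((h - 1) // tmp + 1) * tmp     # 64
pw = ((w - 1) // tmp + 1) * tmp     # 96
padding = (0, pw - w, 0, ph - h)    # (left, right, top, bottom) over the last two dims
padded = F.pad(img, padding)
assert padded.shape[-2:] == (ph, pw)
# Both pads are positive here, so the first crop branch from the diff applies.
restored = padded[:, :, :-padding[3], :-padding[1]]
assert torch.equal(restored, img)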