
Commit 111756a

Merge pull request #358 from glide-the/rile
bug fix rife
2 parents a9a5546 + f0098c0 commit 111756a

File tree: 1 file changed (+74, -19 lines)

inference/gradio_composite_demo/rife_model.py

Lines changed: 74 additions & 19 deletions
@@ -8,8 +8,9 @@
 import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
-
+from huggingface_hub import hf_hub_download, snapshot_download
 logger = logging.getLogger(__name__)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 

@@ -18,8 +19,8 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, 0, pw - w, ph - h)
-    return F.pad(img, padding)
+    padding = (0, pw - w, 0, ph - h)
+    return F.pad(img, padding), padding
 
 
 def make_inference(model, I0, I1, upscale_amount, n):
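
Note on the pad_image change: torch.nn.functional.pad reads a 4-value padding tuple as (left, right, top, bottom) over the last two dimensions, so the old tuple (0, 0, pw - w, ph - h) left the width untouched and applied the width delta to the height. The new tuple pads the right and bottom edges, and the function now also returns the padding so callers can crop the frame back. A minimal sketch of that behaviour (the 30x50 frame size is illustrative, not from the commit):

import torch
import torch.nn.functional as F

img = torch.zeros(1, 3, 30, 50)      # h=30, w=50
ph, pw = 32, 64                      # target sizes, rounded up to multiples of 32
padding = (0, pw - 50, 0, ph - 30)   # (left, right, top, bottom): pad right by 14, bottom by 2
padded = F.pad(img, padding)
assert padded.shape == (1, 3, 32, 64)

# Crop back to the original size with the padding the new pad_image returns:
unpadded = padded[:, :, : -padding[3], : -padding[1]]
assert unpadded.shape == (1, 3, 30, 50)
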
@@ -36,30 +37,56 @@ def make_inference(model, I0, I1, upscale_amount, n):
 
 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-
+    print(f"samples dtype:{samples.dtype}")
+    print(f"samples shape:{samples.shape}")
     output = []
+    pbar = utils.ProgressBar(samples.shape[0], desc="RIFE inference")
     # [f, c, h, w]
     for b in range(samples.shape[0]):
         frame = samples[b : b + 1]
         _, _, h, w = frame.shape
+
         I0 = samples[b : b + 1]
         I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
-        I1 = pad_image(I1, upscale_amount)
+
+        I0, padding = pad_image(I0, upscale_amount)
+        I0 = I0.to(torch.float)
+        I1, _ = pad_image(I1, upscale_amount)
+        I1 = I1.to(torch.float)
+
         # [c, h, w]
         I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
         I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
 
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
 
         if ssim > 0.996:
-            I1 = I0
-            I1 = pad_image(I1, upscale_amount)
+            I1 = samples[b : b + 1]
+            # print(f'upscale_amount:{upscale_amount}')
+            # print(f'ssim:{upscale_amount}')
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1 shape:{I1.shape}')
+            I1, padding = pad_image(I1, upscale_amount)
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1 shape:{I1.shape}')
             I1 = make_inference(model, I0, I1, upscale_amount, 1)
-
-            I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
-            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            frame = I1[0]
+
+            # print(f'I0 shape:{I0.shape}')
+            # print(f'I1[0] shape:{I1[0].shape}')
             I1 = I1[0]
+
+            # print(f'I1[0] unpadded shape:{I1.shape}')
+            I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
+            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+            if padding[3] > 0 and padding[1] > 0:
+
+                frame = I1[:, :, : -padding[3], : -padding[1]]
+            elif padding[3] > 0:
+                frame = I1[:, :, : -padding[3], :]
+            elif padding[1] > 0:
+                frame = I1[:, :, :, : -padding[1]]
+            else:
+                frame = I1
 
         tmp_output = []
         if ssim < 0.2:
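
The crop at the end of the ssim > 0.996 branch undoes exactly that padding. Written as a small standalone helper (a sketch, not part of the commit; crop_padding is a hypothetical name), the logic is:

import torch

def crop_padding(frame: torch.Tensor, padding: tuple) -> torch.Tensor:
    """Remove the (left, right, top, bottom) padding added by pad_image.

    Only the right (padding[1]) and bottom (padding[3]) entries can be non-zero here.
    """
    if padding[3] > 0 and padding[1] > 0:
        return frame[:, :, : -padding[3], : -padding[1]]
    elif padding[3] > 0:
        return frame[:, :, : -padding[3], :]
    elif padding[1] > 0:
        return frame[:, :, :, : -padding[1]]
    return frame
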
@@ -69,10 +96,17 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
         else:
             tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
 
-        frame = pad_image(frame, upscale_amount)
-        tmp_output = [frame] + tmp_output
-        for i, frame in enumerate(tmp_output):
-            output.append(frame.to(output_device))
+        frame, _ = pad_image(frame, upscale_amount)
+        # print(f'frame shape:{frame.shape}')
+
+        frame = F.interpolate(frame, size=(h, w))
+        output.append(frame.to(output_device))
+        for i, tmp_frame in enumerate(tmp_output):
+
+            # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
+            tmp_frame = F.interpolate(tmp_frame, size=(h, w))
+            output.append(tmp_frame.to(output_device))
+        pbar.update(1)
     return output
 
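
For the frames appended to output, the new code returns to the original resolution with F.interpolate(frame, size=(h, w)), which resamples the padded frame rather than cropping it; with only size given, F.interpolate uses its default mode="nearest". A small sketch of the resize (sizes are illustrative):

import torch
import torch.nn.functional as F

h, w = 30, 50                                  # original frame size recorded before padding
frame = torch.rand(1, 3, 32, 64)               # padded frame coming out of the model
restored = F.interpolate(frame, size=(h, w))   # nearest-neighbour resize back to (h, w)
assert restored.shape == (1, 3, 30, 50)
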

@@ -94,14 +128,26 @@ def frame_generator(video_capture):
 
 
 def rife_inference_with_path(model, video_path):
+    # Open the video file
     video_capture = cv2.VideoCapture(video_path)
-    tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
+    fps = video_capture.get(cv2.CAP_PROP_FPS)  # Get the frames per second
+    tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))  # Total frames in the video
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
-    for frame in pt_frame:
+    # Cyclic reading of the video frames
+    while video_capture.isOpened():
+        ret, frame = video_capture.read()
+
+        if not ret:
+            break
+
+        # BGR to RGB
+        frame_rgb = frame[..., ::-1]
+        frame_rgb = frame_rgb.copy()
+        tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
         pt_frame_data.append(
-            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
-        )
+            tensor.permute(2, 0, 1)
+        )  # to [c, h, w]
 
     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
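
In the rewritten reader, frame[..., ::-1] flips BGR to RGB but yields a negatively strided NumPy view, which torch.from_numpy cannot wrap; that is why the .copy() follows. A minimal sketch of the conversion (the 4x4 frame is a stand-in for a real cv2 frame):

import numpy as np
import torch

bgr = np.zeros((4, 4, 3), dtype=np.uint8)   # stand-in for a frame returned by cv2
rgb_view = bgr[..., ::-1]                   # BGR -> RGB view with a negative stride
rgb = rgb_view.copy()                       # contiguous copy that torch.from_numpy accepts
tensor = torch.from_numpy(rgb).float() / 255.0
print(tensor.shape)                         # torch.Size([4, 4, 3])
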
@@ -122,8 +168,17 @@ def rife_inference_with_latents(model, latents):
     for i in range(latents.size(0)):
         # [f, c, w, h]
         latent = latents[i]
+
         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
         rife_results.append(pt_image)
 
     return torch.stack(rife_results)
+
+
+# if __name__ == "__main__":
+#     snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
+#     model = load_rife_model("model_rife")
+
+#     video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4")
+#     print(video_path)
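
The commented-out block added at the end sketches how the module is meant to be exercised end to end. Uncommented, and assuming that load_rife_model is defined elsewhere in rife_model.py as the comment implies and that the script is run from inference/gradio_composite_demo, a hedged usage sketch (the input path is a placeholder) would look like:

from huggingface_hub import snapshot_download
from rife_model import load_rife_model, rife_inference_with_path

snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")  # fetch RIFE weights
model = load_rife_model("model_rife")
video_path = rife_inference_with_path(model, "input.mp4")  # placeholder input video
print(video_path)
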
