Skip to content

Commit 7935c24

Browse files
Merge pull request #480 from PVSemk/fba-matting
FBA Matting
2 parents 2706dbe + 91c757e commit 7935c24

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1148
-0
lines changed

FBAMatting/LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2020 Marco Forte
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

FBAMatting/README.md

Lines changed: 42 additions & 0 deletions

FBAMatting/dataloader.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import os
2+
3+
import cv2
4+
import numpy as np
5+
from torch.utils.data import Dataset
6+
7+
8+
class PredDataset(Dataset):
    """Dataset of image / trimap pairs read from two folders.

    Every regular file found in ``img_dir`` is treated as an image. Its
    trimap is looked up in ``trimap_dir`` under the same base name with a
    ``.png`` extension.

    Each item is a dict with the keys ``"image"``, ``"trimap"`` and
    ``"name"`` (the image file name).
    """

    def __init__(self, img_dir, trimap_dir):
        self.img_dir, self.trimap_dir = img_dir, trimap_dir
        # os.listdir returns entries in arbitrary order; sort them so the
        # processing order is reproducible between runs.
        self.img_names = sorted(
            entry
            for entry in os.listdir(self.img_dir)
            if os.path.isfile(os.path.join(self.img_dir, entry))
        )

    def __len__(self):
        return len(self.img_names)

    @staticmethod
    def _trimap_name(img_name):
        """Return the trimap file name for ``img_name`` (same stem, ``.png``)."""
        # splitext copes with extensions of any length ("jpeg", "tiff", ...),
        # unlike slicing off the last three characters.
        return os.path.splitext(img_name)[0] + ".png"

    def __getitem__(self, idx):
        # An out-of-range idx raises IndexError here, which is what lets the
        # dataset be iterated through the sequence protocol.
        img_name = self.img_names[idx]
        trimap_name = self._trimap_name(img_name)

        image = read_image(os.path.join(self.img_dir, img_name))
        trimap = read_trimap(os.path.join(self.trimap_dir, trimap_name))

        return {"image": image, "trimap": trimap, "name": img_name}
33+
34+
35+
def read_image(name):
    """Read an image file and return it with reversed channel order.

    The pixel values are divided by 255.0 and the last axis is reversed
    (OpenCV loads colour images as BGR, so the result is RGB).

    Raises:
        FileNotFoundError: if OpenCV could not read ``name``.
    """
    bgr = cv2.imread(name)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an opaque TypeError below.
        raise FileNotFoundError("Could not read image: {}".format(name))
    return (bgr / 255.0)[:, :, ::-1]
37+
38+
39+
def read_trimap(name):
    """Read a grayscale trimap file and return it as a two-channel mask.

    The file is loaded with flag 0 (single channel) and divided by 255.0.
    In the returned ``(h, w, 2)`` array, channel 0 is 1 where the scaled
    value equals 0 and channel 1 is 1 where it equals 1; every other entry
    is 0.

    Raises:
        FileNotFoundError: if OpenCV could not read ``name``.
    """
    gray = cv2.imread(name, 0)
    if gray is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("Could not read trimap: {}".format(name))

    trimap_im = gray / 255.0
    h, w = trimap_im.shape
    trimap = np.zeros((h, w, 2))
    trimap[trimap_im == 1, 1] = 1
    trimap[trimap_im == 0, 0] = 1
    return trimap

FBAMatting/demo.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import argparse
2+
import os
3+
4+
import cv2
5+
import numpy as np
6+
import torch
7+
from dataloader import PredDataset
8+
from networks.models import build_model
9+
from networks.transforms import (
10+
groupnorm_normalise_image,
11+
trimap_transform,
12+
)
13+
from tqdm import tqdm
14+
15+
16+
def np_to_torch(x):
    """Convert a 3-D numpy array into a 4-D float tensor.

    The last axis is moved to the front and a leading axis of size 1 is
    added, so an ``(a, b, c)`` array becomes a ``(1, c, a, b)`` tensor.
    """
    channels_first = torch.from_numpy(x).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0)
    return batched.float()
18+
19+
20+
def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    """Resize ``x`` by ``scale``, rounding each side up to a multiple of 8.

    ``scale_type`` is forwarded to ``cv2.resize`` as the interpolation flag.
    """
    src_h, src_w = x.shape[:2]
    dst_h, dst_w = (int(np.ceil(scale * side / 8) * 8) for side in (src_h, src_w))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(x, (dst_w, dst_h), interpolation=scale_type)
27+
28+
29+
def swap_bg(image, alpha):
    """Blend ``image`` over a solid backdrop whose channel 1 is 255.

    ``alpha`` is a 2-D weight map; the result is
    ``alpha * image + (1 - alpha) * backdrop``, clipped to [0, 255] and
    returned as ``uint8``.
    """
    backdrop = np.zeros(image.shape, dtype=np.float32)
    backdrop[:, :, 1] = 255

    # Add a trailing axis so the weights broadcast across the channels.
    weight = alpha[:, :, np.newaxis]
    foreground = image.astype(np.float32)

    blended = weight * foreground + (1 - weight) * backdrop
    return np.clip(blended, 0, 255).astype(np.uint8)
38+
39+
40+
def predict_fba_folder(model, args):
    """Run ``pred`` on every image in ``args.image_dir`` and save the results.

    Images are paired with trimaps from ``args.trimap_dir``. For each image
    four files are written to ``args.output_dir`` (created if missing):
    ``<stem>_fg.png``, ``<stem>_bg.png``, ``<stem>_alpha.png`` and
    ``<stem>_swapped_bg.png``, where ``<stem>`` is the image file name
    without its extension.
    """
    save_dir = args.output_dir
    os.makedirs(save_dir, exist_ok=True)

    dataset_test = PredDataset(args.image_dir, args.trimap_dir)

    # Passing the dataset itself (rather than a bare iterator) lets tqdm read
    # its length and display a total.
    for item_dict in tqdm(dataset_test):
        image_np = item_dict["image"]
        trimap_np = item_dict["trimap"]

        fg, bg, alpha = pred(image_np, trimap_np, model, args)

        # splitext handles extensions of any length; slicing off the last
        # four characters mangled names such as "photo.jpeg".
        stem = os.path.splitext(item_dict["name"])[0]

        # fg / bg come back in RGB order; reverse the channels for OpenCV.
        fg_bgr = fg[:, :, ::-1] * 255
        bg_bgr = bg[:, :, ::-1] * 255

        cv2.imwrite(os.path.join(save_dir, stem + "_fg.png"), fg_bgr)
        cv2.imwrite(os.path.join(save_dir, stem + "_bg.png"), bg_bgr)
        cv2.imwrite(os.path.join(save_dir, stem + "_alpha.png"), alpha * 255)

        example_swap_bg = swap_bg(fg_bgr, alpha)
        cv2.imwrite(
            os.path.join(save_dir, stem + "_swapped_bg.png"), example_swap_bg,
        )
69+
70+
71+
def pred(image_np: np.ndarray, trimap_np: np.ndarray, model, args) -> tuple:
    """ Predict alpha, foreground and background.
        Parameters:
        image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
        model -- callable network; invoked with the image, the trimap and their transformed versions
        args -- namespace providing ``device``, the torch device tensors are moved to
        Returns a tuple ``(fg, bg, alpha)``:
        fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    """
    h, w = trimap_np.shape[:2]

    # Sides are rounded up to a multiple of 8 by scale_input.
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    with torch.no_grad():

        image_torch = np_to_torch(image_scale_np).to(args.device)
        trimap_torch = np_to_torch(trimap_scale_np).to(args.device)

        trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np)).to(
            args.device,
        )
        image_transformed_torch = groupnorm_normalise_image(
            image_torch.clone(), format="nchw",
        )

        output = model(
            image_torch,
            trimap_torch,
            image_transformed_torch,
            trimap_transformed_torch,
        )

        # Resize back to the original size. The interpolation flag must be
        # passed by keyword: the third positional parameter of cv2.resize is
        # ``dst``, so a positional flag is not used as the interpolation mode.
        output = cv2.resize(
            output[0].cpu().numpy().transpose((1, 2, 0)),
            (w, h),
            interpolation=cv2.INTER_LANCZOS4,
        )
    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Where the trimap is certain, override the prediction with the known
    # alpha value and copy the input pixels into fg / bg.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]
    return fg, bg, alpha
117+
118+
119+
def main():
    """Parse command-line options, build the model and run folder inference."""
    parser = argparse.ArgumentParser()
    # Model related arguments
    parser.add_argument("--encoder", default="resnet50_GN_WS", help="Encoder model")
    parser.add_argument("--decoder", default="fba_decoder", help="Decoder model")
    parser.add_argument(
        "--weights", default="FBA.pth", help="Model weights option passed to build_model",
    )
    # Input / output locations
    parser.add_argument(
        "--image_dir", default="./examples/images", help="Directory with input images",
    )
    parser.add_argument(
        "--trimap_dir",
        default="./examples/trimaps",
        help="Directory with trimaps (same base name as the image, .png extension)",
    )
    parser.add_argument(
        "--output_dir",
        default="./examples/predictions",
        help="Directory where predictions are written",
    )
    parser.add_argument("--device", default="cpu", help="Device for inference on")

    args = parser.parse_args()
    model = build_model(args).to(args.device)
    # Switch to evaluation mode before running inference.
    model.eval()
    predict_fba_folder(model, args)


if __name__ == "__main__":
    main()
500 KB
1.81 MB
1.49 MB
1.62 MB
1.01 MB
329 KB

0 commit comments

Comments
 (0)