run.py (forked from isl-org/MiDaS)

"""Compute depth maps for images in the input folder.
"""
import glob
import os
import time

import cv2
import torch
from torchvision.transforms import Compose

import utils
from models.midas_net import MidasNet
from models.transforms import Resize, NormalizeImage, PrepareForNet
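
# utils, models.midas_net, and models.transforms are local modules that ship
# with the MiDaS repository
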

def run(input_path, output_path, model_path):
    """Run MonoDepthNN to compute depth maps.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
print("initialize")
# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: %s" % device)
# load network
model = MidasNet(model_path, non_negative=True)
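
    # input transform: fit the image within 384x384 (resize_method="upper_bound")
    # while keeping the aspect ratio, snap both sides to multiples of 32,
    # normalize with ImageNet mean/std, and convert to the layout the network expects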
    transform = Compose(
        [
            Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ]
    )
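
    # move the network to the selected device and switch to evaluation mode,
    # so layers such as batch norm use their inference behavior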
    model.to(device)
    model.eval()

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")
    for ind, img_name in enumerate(img_names):
        start = time.time()
        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = utils.read_image(img_name)
        img_input = transform({"image": img})["image"]
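
        # the forward pass runs without gradient tracking; the raw prediction
        # (at network input resolution) is upsampled back to the original image
        # size with bicubic interpolation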
        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            prediction = model.forward(sample)
            prediction = (
                torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=img.shape[:2],
                    mode="bicubic",
                    align_corners=False,
                )
                .squeeze()
                .cpu()
                .numpy()
            )
print(f"{img_name} took {start - time.time()} s")

        # output: write the prediction under the input's base name (extension
        # stripped); bits=2 requests 16-bit output from utils.write_depth
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        utils.write_depth(filename, prediction, bits=2)
print("finished")

if __name__ == "__main__":
    # set paths
    INPUT_PATH = "input"
    OUTPUT_PATH = "output"
    MODEL_PATH = "model.pt"

    # set torch options
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # compute depth maps
    run(INPUT_PATH, OUTPUT_PATH, MODEL_PATH)
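
# usage, inferred from the defaults above: place images in ./input, put the
# pretrained MiDaS weights at ./model.pt, and run `python run.py`; depth maps
# are written to ./output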