forked from shentong-hbu/TriNerSeg
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path predict3d.py
More file actions
93 lines (76 loc) · 2.89 KB
/
predict3d.py
File metadata and controls
93 lines (76 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn.functional as F
import numpy as np
import os
import glob
from tqdm import tqdm
import SimpleITK as sitk
import torchio
import time
# Select GPU 1 by default, but respect a CUDA_VISIBLE_DEVICES value the user
# already exported (the previous assignment silently overwrote it).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "1")
DATABASE = 'TRINE/'

# Script configuration:
#   'root'      - not read anywhere else in this script; kept for compatibility.
#   'test_path' - test-set root; expects 'images/' and 'labels/' sub-folders.
#   'pred_path' - root directory for saved predictions.
args = {
    'root': '/',
    'test_path': './dataset/' + DATABASE + 'test/',
    'pred_path': 'assets/' + 'SegResults/',
}
# exist_ok removes the race between an existence check and the creation.
os.makedirs(args['pred_path'], exist_ok=True)
def load_3dV2(test_path=None):
    """Return sorted lists of test image and label volume paths.

    Args:
        test_path: test-set root containing ``images/`` and ``labels/``
            sub-folders. Defaults to ``args['test_path']``.

    Returns:
        ``(test_images, test_labels)``: two lists of ``*.nii.gz`` paths, each
        sorted lexicographically. A missing folder yields an empty list.
        Pairing relies on images and labels sharing sorted order; the counts
        are not checked here.
    """
    if test_path is None:
        test_path = args['test_path']
    test_images = sorted(glob.glob(os.path.join(test_path, "images", "*.nii.gz")))
    test_labels = sorted(glob.glob(os.path.join(test_path, "labels", "*.nii.gz")))
    return test_images, test_labels
def load_net(checkpoint_path='./checkpoint/FinerRes2CSNet.pth'):
    """Load and print a whole pickled network saved with ``torch.save(model)``.

    Args:
        checkpoint_path: path of the checkpoint file. It must contain the full
            module object, not just a state_dict.

    Returns:
        The unpickled network object.

    SECURITY: this unpickles arbitrary Python objects. Only load checkpoints
    from a trusted source.
    """
    import inspect

    load_kwargs = {}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module. Opt out explicitly where the keyword exists; older torch
    # versions without the keyword already behave this way.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net = torch.load(checkpoint_path, **load_kwargs)
    print(net)
    return net
def save_prediction(pred, filename='', spacing=None, origin=None, direction=None, save_dir=None):
    """Write the arg-max label map of a network output as ``<filename>.mha``.

    Labels are scaled so that the largest label present maps to 255. An
    all-background prediction is written as zeros.

    Args:
        pred: network output tensor. Class scores are taken along dim 1, and
            the batch dim (dim 0) is assumed to have size 1 (TODO confirm
            against the model).
        filename: output file stem; ``.mha`` is appended.
        spacing, origin, direction: optional SimpleITK geometry copied onto
            the written image. Each is applied independently, only when given.
        save_dir: output directory. Defaults to ``args['pred_path'] + 'pred/'``.
    """
    pred = torch.argmax(pred, dim=1)
    save_path = args['pred_path'] + 'pred/' if save_dir is None else save_dir
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
        print("Make dirs success!")
    mask = pred.detach().cpu().numpy()
    max_label = np.max(mask)
    # Guard the normalisation: an empty (all-zero) prediction would otherwise
    # give 0/0 = NaN, whose cast to uint8 is undefined.
    if max_label > 0:
        mask = mask / max_label
    mask = (mask * 255).astype(np.uint8)
    mask = mask.squeeze(0)  # drop the batch dimension (CE-loss style output)
    # Reverse the spatial axes. SimpleITK's GetImageFromArray reads numpy
    # arrays as (z, y, x); presumably the tensor is (x, y, z) as torchio loads it.
    mask = np.transpose(mask, axes=(2, 1, 0))
    image = sitk.GetImageFromArray(mask)
    if spacing is not None:
        image.SetSpacing(spacing)
    if origin is not None:
        image.SetOrigin(origin)
    if direction is not None:
        image.SetDirection(direction)
    sitk.WriteImage(image, os.path.join(save_path, filename + ".mha"))
def predict():
    """Segment every test volume with the trained network and save the results.

    For each ``*.nii.gz`` file under ``<test_path>/images`` this:
      1. rescales intensities with ``torchio.RescaleIntensity`` (default settings),
      2. runs the network on the GPU without gradient tracking,
      3. writes the network's second output via ``save_prediction`` as
         ``<case>_pred.mha``, copying spacing/origin/direction from the input.

    Ground-truth labels are not needed for inference, so they are not read.
    This also lets the script run on a test set without label files.
    """
    net = load_net()
    images, _labels = load_3dV2()
    print(len(images))
    net.eval()  # inference behaviour for layers such as dropout / batch-norm
    suffix = '.nii.gz'
    with torch.no_grad():
        for image_path in tqdm(images):
            # Case id = file name without ".nii.gz" (guaranteed by the glob).
            # os.path.basename is portable, unlike splitting on '/'.
            index = os.path.basename(image_path)[:-len(suffix)]
            transform = torchio.RescaleIntensity()
            subject = torchio.Subject(image=torchio.ScalarImage(image_path))
            transformed = transform(subject)
            image = transformed['image'][torchio.DATA].numpy().astype(np.float32)
            # Re-read the raw volume only for its geometry, which is copied
            # onto the saved prediction.
            reference = sitk.ReadImage(image_path)
            spacing = reference.GetSpacing()
            origin = reference.GetOrigin()
            direction = reference.GetDirection()
            # Add a batch dimension and move the input to the GPU.
            image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).cuda()
            # The network returns two outputs; only the second one is saved.
            _coarse, output = net(image)
            save_prediction(output, filename=index + '_pred',
                            spacing=spacing, origin=origin, direction=direction)
# Script entry point: run inference over the whole test set when executed directly.
if __name__ == '__main__':
    predict()