-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathstep2_score_instances_with_SAM.py
More file actions
118 lines (83 loc) · 4.1 KB
/
step2_score_instances_with_SAM.py
File metadata and controls
118 lines (83 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (C) 2024 * Ltd. All rights reserved.
# author: Sanghyun Jo <shjo.april@gmail.com>
import cmapy
import numpy as np
import sanghyunjo as shjo
from torch.nn import functional as F
from core import evaluators
from core.sam2 import SAM2
if __name__ == '__main__':
args = shjo.Parser(
{
'image': './data/MoNuSeg/train/image/',
'pred': './submissions/Ours+PSM@CRF(G=2)/train_instance/',
'mask': './submissions/Ours+PSM@CRF(G=2)/train_sam/',
'sam': './weights/sam2.1_hiera_l.pt',
'threshold': -1.,
'strategy': 'mean',
}
)
model_sam = SAM2(args.sam)
mask_dir = shjo.makedir(args.mask)
semantic_dir = shjo.makedir(args.mask.replace('_sam', '_sam2semantic'))
score_dir = shjo.makedir(args.mask.replace('_sam', '_sam2score'))
colors = shjo.get_colors() # RGB color
colors[0] = [0, 0, 0] # background
colors[1] = [32, 167, 132] # foreground (e.g., cell)
# colors = colors[:, ::-1]
for pred_path in shjo.progress(sorted(shjo.listdir(args.pred + '*.png'))):
pred_path = pred_path.replace('\\', '/')
mask_name = shjo.basename(pred_path)
image_id = mask_name.replace('.png', '')
pred_mask = shjo.imread(args.pred + image_id + '.png')
pred_mask = pred_mask.astype(np.int64)
pred_mask = pred_mask[:, :, 0] * 256 + pred_mask[:, :, 1]
cv_image = shjo.imread(shjo.listdir(args.image + image_id + '.*')[0])
model_sam.set_image(cv_image)
scores = []
masks = []
for index in shjo.progress(np.unique(pred_mask), image_id):
if index == 0: continue
pred_ins_mask = (pred_mask == index).astype(np.uint8)
cy, cx = map(lambda x: int(np.average(x)), np.where(pred_ins_mask > 0))
sam_ins_mask = model_sam.predict(point_coords=[(cx, cy)], point_labels=[1])
IoU = evaluators.get_IoU(pred_ins_mask, sam_ins_mask)
scores.append(IoU)
masks.append(sam_ins_mask > 0)
scores = np.asarray(scores)
masks = np.asarray(masks)
sorted_indices = np.argsort(scores) # [::-1]
sorted_masks = masks[sorted_indices]
sorted_scores = scores[sorted_indices]
if args.threshold == -1:
if args.strategy == 'mean':
threshold = scores.mean()
else:
threshold = min(scores.mean() + scores.std(), 0.9)
else:
threshold = args.threshold
sorted_masks = sorted_masks[sorted_scores > threshold]
instance_mask = np.zeros_like(pred_mask).astype(np.int64)
instance_score = np.zeros_like(pred_mask).astype(np.float32)
for instance_id, (mask, score) in enumerate(zip(sorted_masks, sorted_scores)):
instance_mask[mask > 0] = instance_id + 1
instance_score[mask > 0] = score
# shjo.write_image(score_dir + image_id + '.png', (instance_score * 255).astype(np.uint8), cmapy.cmap('seismic', rgb_order=True)[:, 0, :])
shjo.imwrite(score_dir + image_id + '.png', shjo.colorize((instance_score * 255).astype(np.uint8)))
h, w = instance_mask.shape
reliable_mask = np.zeros((h, w, 3), dtype=np.uint8)
reliable_mask[pred_mask > 0, :] = 255
for index in np.unique(instance_mask):
if index == 0: continue # ignore background
reliable_mask[instance_mask == index, 0] = index // 256 # B channel
reliable_mask[instance_mask == index, 1] = index % 256 # G channel
reliable_mask[instance_mask == index, 2] = 0 # R channel
shjo.imwrite(args.mask + image_id + '.png', reliable_mask)
# instance to semantic
pred_mask = reliable_mask.astype(np.int64)
ignore_mask = pred_mask.sum(axis=2) == (255*3)
pred_mask[ignore_mask, :] = 0
pred_mask = pred_mask[:, :, 0] * 256 + pred_mask[:, :, 1]
pred_mask = (pred_mask > 0).astype(np.uint8)
pred_mask[ignore_mask] = 255
shjo.imwrite(semantic_dir + image_id + '.png', pred_mask, colors)