Skip to content

Commit e4b8356

Browse files
committed
Basic train utils
1 parent 4d506a8 commit e4b8356

File tree

4 files changed

+106
-25
lines changed

4 files changed

+106
-25
lines changed
Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import torch
22

33

4-
def iou(outputs, targets, eps=1e-7, threshold=None, activation='sigmoid'):
4+
def iou(pr, gt, eps=1e-7, threshold=None, activation='sigmoid'):
55
"""
66
Args:
7-
outputs (torch.Tensor): A list of predicted elements
8-
targets (torch.Tensor): A list of elements that are to be predicted
7+
pr (torch.Tensor): A list of predicted elements
8+
gt (torch.Tensor): A list of elements that are to be predicted
99
eps (float): epsilon to avoid zero division
1010
threshold: threshold for outputs binarization
1111
Returns:
@@ -20,16 +20,54 @@ def iou(outputs, targets, eps=1e-7, threshold=None, activation='sigmoid'):
2020
activation_fn = torch.nn.Softmax2d()
2121
else:
2222
raise NotImplementedError(
23-
"Dice is only implemented for sigmoid and softmax2d"
23+
"Activation implemented for sigmoid and softmax2d"
2424
)
2525

26-
outputs = activation_fn(outputs)
26+
pr = activation_fn(pr)
2727

2828
if threshold is not None:
29-
outputs = (outputs > threshold).float()
29+
pr = (pr > threshold).float()
3030

31-
intersection = torch.sum(targets * outputs)
32-
union = torch.sum(targets) + torch.sum(outputs) - intersection + eps
31+
intersection = torch.sum(gt * pr)
32+
union = torch.sum(gt) + torch.sum(pr) - intersection + eps
3333
return (intersection + eps) / union
3434

35-
jaccard = iou
35+
jaccard = iou
36+
37+
38+
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
39+
"""
40+
Args:
41+
pr (torch.Tensor): A list of predicted elements
42+
gt (torch.Tensor): A list of elements that are to be predicted
43+
eps (float): epsilon to avoid zero division
44+
threshold: threshold for outputs binarization
45+
Returns:
46+
float: IoU (Jaccard) score
47+
"""
48+
49+
if activation is None or activation == "none":
50+
activation_fn = lambda x: x
51+
elif activation == "sigmoid":
52+
activation_fn = torch.nn.Sigmoid()
53+
elif activation == "softmax2d":
54+
activation_fn = torch.nn.Softmax2d()
55+
else:
56+
raise NotImplementedError(
57+
"Activation implemented for sigmoid and softmax2d"
58+
)
59+
60+
pr = activation_fn(pr)
61+
62+
if threshold is not None:
63+
pr = (pr > threshold).float()
64+
65+
66+
tp = torch.sum(gt * pr)
67+
fp = torch.sum(pr) - tp
68+
fn = torch.sum(gt) - tp
69+
70+
score = ((1 + beta ** 2) * tp + eps) \
71+
/ ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)
72+
73+
return score

segmentation_models_pytorch/utils/losses.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44

55
class JaccardLoss(nn.Module):
6-
76
__name__ = 'jaccard_loss'
87

98
def __init__(self, eps=1e-7, activation='sigmoid'):
@@ -12,11 +11,22 @@ def __init__(self, eps=1e-7, activation='sigmoid'):
1211
self.eps = eps
1312

1413
def forward(self, y_pr, y_gt):
15-
return F.iou(y_pr, y_gt, self.eps, threshold=None, activation=self.activation)
14+
return 1 - F.jaccard(y_pr, y_gt, eps=self.eps, threshold=None, activation=self.activation)
1615

1716

18-
class BCEJaccardLoss(JaccardLoss):
17+
class DiceLoss(nn.Module):
18+
__name__ = 'dice_loss'
19+
20+
def __init__(self, eps=1e-7, activation='sigmoid'):
21+
super().__init__()
22+
self.activation = activation
23+
self.eps = eps
24+
25+
def forward(self, y_pr, y_gt):
26+
return 1 - F.f_score(y_pr, y_gt, beta=1., eps=self.eps, threshold=None, activation=self.activation)
27+
1928

29+
class BCEJaccardLoss(JaccardLoss):
2030
__name__ = 'bce_jaccard_loss'
2131

2232
def __init__(self, eps=1e-7, activation='sigmoid'):
@@ -26,4 +36,17 @@ def __init__(self, eps=1e-7, activation='sigmoid'):
2636
def forward(self, y_pr, y_gt):
2737
jaccard = super().forward(y_pr, y_gt)
2838
bce = self.bce(y_pr, y_gt)
29-
return jaccard + bce
39+
return jaccard + bce
40+
41+
42+
class BCEDiceLoss(DiceLoss):
43+
__name__ = 'bce_dice_loss'
44+
45+
def __init__(self, eps=1e-7, activation='sigmoid'):
46+
super().__init__(eps, activation)
47+
self.bce = nn.BCEWithLogitsLoss(reduction='mean')
48+
49+
def forward(self, y_pr, y_gt):
50+
dice = super().forward(y_pr, y_gt)
51+
bce = self.bce(y_pr, y_gt)
52+
return dice + bce

segmentation_models_pytorch/utils/metrics.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,18 @@ def __init__(self, eps=1e-7, threshold=0.5, activation='sigmoid'):
1414

1515
def forward(self, y_pr, y_gt):
1616
return F.iou(y_pr, y_gt, self.eps, self.threshold, self.activation)
17+
18+
19+
class FscoreMetric(nn.Module):
20+
21+
__name__ = 'f-score'
22+
23+
def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation='sigmoid'):
24+
super().__init__()
25+
self.activation = activation
26+
self.eps = eps
27+
self.threshold = threshold
28+
self.beta = beta
29+
30+
def forward(self, y_pr, y_gt):
31+
return F.f_score(y_pr, y_gt, self.beta, self.eps, self.threshold, self.activation)

segmentation_models_pytorch/utils/train.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,16 @@ def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True)
1414
self.verbose = verbose
1515
self.device = device
1616

17+
self._to_device()
18+
19+
def _to_device(self):
20+
self.model.to(self.device)
21+
self.loss.to(self.device)
22+
for metric in self.metrics:
23+
metric.to(self.device)
24+
1725
def _format_logs(self, logs):
18-
str_logs = ['{}_{} - {:.4}'.format(self.stage_name, k, v) for k, v in logs.items()]
26+
str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
1927
s = ', '.join(str_logs)
2028
return s
2129

@@ -25,33 +33,30 @@ def batch_update(self, x, y):
2533
def on_epoch_start(self):
2634
pass
2735

28-
def run(self, dataloder):
36+
def run(self, dataloader):
2937

3038
self.on_epoch_start()
3139

3240
logs = {}
3341
loss_meter = AverageValueMeter()
34-
metrics_meters = {m.__name__: AverageValueMeter() for m in self.metrics}
42+
metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}
3543

36-
with tqdm(dataloder, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
44+
with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
3745
for x, y in iterator:
3846
x, y = x.to(self.device), y.to(self.device)
3947
loss, y_pred = self.batch_update(x, y)
4048

41-
# update loss meter
49+
# update loss logs
4250
loss_value = loss.cpu().detach().numpy()
4351
loss_meter.add(loss_value)
52+
loss_logs = {self.loss.__name__: loss_meter.mean}
53+
logs.update(loss_logs)
4454

45-
# update metrics meters
55+
# update metrics logs
4656
for metric_fn in self.metrics:
4757
metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
4858
metrics_meters[metric_fn.__name__].add(metric_value)
49-
50-
# create_logs
51-
loss_logs = {'loss': loss_meter.mean}
5259
metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
53-
54-
logs.update(loss_logs)
5560
logs.update(metrics_logs)
5661

5762
if self.verbose:
@@ -105,4 +110,4 @@ def batch_update(self, x, y):
105110
with torch.no_grad():
106111
prediction = self.model.forward(x)
107112
loss = self.loss(prediction, y)
108-
return loss, prediction
113+
return loss, prediction

0 commit comments

Comments
 (0)