Skip to content

Commit 3228797

Browse files
committed
Initial train example
1 parent 071e18d commit 3228797

File tree

6 files changed

+193
-1
lines changed

6 files changed

+193
-1
lines changed

segmentation_models_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
from .fpn import FPN
44
from .pspnet import PSPNet
55

6-
from . import encoders
6+
from . import encoders
7+
from . import utils
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import train
2+
from . import losses
3+
from . import metrics
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
3+
4+
def iou(outputs, targets, eps=1e-7, threshold=None, activation='sigmoid'):
5+
"""
6+
Args:
7+
outputs (torch.Tensor): A list of predicted elements
8+
targets (torch.Tensor): A list of elements that are to be predicted
9+
eps (float): epsilon to avoid zero division
10+
threshold: threshold for outputs binarization
11+
Returns:
12+
float: IoU (Jaccard) score
13+
"""
14+
15+
if activation is None or activation == "none":
16+
activation_fn = lambda x: x
17+
elif activation == "sigmoid":
18+
activation_fn = torch.nn.Sigmoid()
19+
elif activation == "softmax2d":
20+
activation_fn = torch.nn.Softmax2d()
21+
else:
22+
raise NotImplementedError(
23+
"Dice is only implemented for sigmoid and softmax2d"
24+
)
25+
26+
outputs = activation_fn(outputs)
27+
28+
if threshold is not None:
29+
outputs = (outputs > threshold).float()
30+
31+
intersection = torch.sum(targets * outputs)
32+
union = torch.sum(targets) + torch.sum(outputs) - intersection + eps
33+
return (intersection + eps) / union
34+
35+
jaccard = iou
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch.nn as nn
2+
from . import functions as F
3+
4+
5+
class JaccardLoss(nn.Module):
6+
7+
__name__ = 'jaccard_loss'
8+
9+
def __init__(self, eps=1e-7, activation='sigmoid'):
10+
super().__init__()
11+
self.activation = activation
12+
self.eps = eps
13+
14+
def forward(self, y_pr, y_gt):
15+
return F.iou(y_pr, y_gt, self.eps, threshold=None, activation=self.activation)
16+
17+
18+
class BCEJaccardLoss(JaccardLoss):
19+
20+
__name__ = 'bce_jaccard_loss'
21+
22+
def __init__(self, eps=1e-7, activation='sigmoid'):
23+
super().__init__(eps, activation)
24+
self.bce = nn.BCEWithLogitsLoss(reduction='mean')
25+
26+
def forward(self, y_pr, y_gt):
27+
jaccard = super().forward(y_pr, y_gt)
28+
bce = self.bce(y_pr, y_gt)
29+
return jaccard + bce
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch.nn as nn
2+
from . import functions as F
3+
4+
5+
class IoUMetric(nn.Module):
6+
7+
__name__ = 'iou'
8+
9+
def __init__(self, eps=1e-7, threshold=0.5, activation='sigmoid'):
10+
super().__init__()
11+
self.activation = activation
12+
self.eps = eps
13+
self.threshold = threshold
14+
15+
def forward(self, y_pr, y_gt):
16+
return F.iou(y_pr, y_gt, self.eps, self.threshold, self.activation)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import sys
2+
import torch
3+
from tqdm import tqdm as tqdm
4+
from torchnet.meter import AverageValueMeter
5+
6+
7+
class Epoch:
8+
9+
def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True):
10+
self.model = model
11+
self.loss = loss
12+
self.metrics = metrics
13+
self.stage_name = stage_name
14+
self.verbose = verbose
15+
self.device = device
16+
17+
def _format_logs(self, logs):
18+
str_logs = ['{}_{} - {:.4}'.format(self.stage_name, k, v) for k, v in logs.items()]
19+
s = ', '.join(str_logs)
20+
return s
21+
22+
def batch_update(self, x, y):
23+
raise NotImplementedError
24+
25+
def on_epoch_start(self):
26+
pass
27+
28+
def run(self, dataloder):
29+
30+
self.on_epoch_start()
31+
32+
logs = {}
33+
loss_meter = AverageValueMeter()
34+
metrics_meters = {m.__name__: AverageValueMeter() for m in self.metrics}
35+
36+
with tqdm(dataloder, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
37+
for x, y in iterator:
38+
x, y = x.to(self.device), y.to(self.device)
39+
loss, y_pred = self.batch_update(x, y)
40+
41+
# update loss meter
42+
loss_value = loss.cpu().detach().numpy()
43+
loss_meter.add(loss_value)
44+
45+
# update metrics meters
46+
for metric_fn in self.metrics:
47+
metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
48+
metrics_meters[metric_fn.__name__].add(metric_value)
49+
50+
# create_logs
51+
loss_logs = {'loss': loss_meter.mean}
52+
metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
53+
54+
logs.update(loss_logs)
55+
logs.update(metrics_logs)
56+
57+
if self.verbose:
58+
s = self._format_logs(logs)
59+
iterator.set_postfix_str(s)
60+
61+
return logs
62+
63+
64+
class TrainEpoch(Epoch):
65+
66+
def __init__(self, model, loss, metrics, optimizer, device='cpu', verbose=True):
67+
super().__init__(
68+
model=model,
69+
loss=loss,
70+
metrics=metrics,
71+
stage_name='train',
72+
device=device,
73+
verbose=verbose,
74+
)
75+
self.optimizer = optimizer
76+
77+
def on_epoch_start(self):
78+
self.model.train()
79+
80+
def batch_update(self, x, y):
81+
self.optimizer.zero_grad()
82+
prediction = self.model.forward(x)
83+
loss = self.loss(prediction, y)
84+
loss.backward()
85+
self.optimizer.step()
86+
return loss, prediction
87+
88+
89+
class ValidEpoch(Epoch):
90+
91+
def __init__(self, model, loss, metrics, device='cpu', verbose=True):
92+
super().__init__(
93+
model=model,
94+
loss=loss,
95+
metrics=metrics,
96+
stage_name='valid',
97+
device=device,
98+
verbose=verbose,
99+
)
100+
101+
def on_epoch_start(self):
102+
self.model.eval()
103+
104+
def batch_update(self, x, y):
105+
with torch.no_grad():
106+
prediction = self.model.forward(x)
107+
loss = self.loss(prediction, y)
108+
return loss, prediction

0 commit comments

Comments
 (0)