Skip to content

Commit 0c35ba2

Browse files
authored
Merge pull request #6 from qubvel/feature/train
Train utils
2 parents 9c79ffa + 38b7da6 commit 0c35ba2

File tree

10 files changed

+282
-15
lines changed

10 files changed

+282
-15
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
torchvision==0.2.2
2-
pretrainedmodels==0.7.4
2+
pretrainedmodels==0.7.4
3+
torchnet==0.0.4

segmentation_models_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
from .fpn import FPN
44
from .pspnet import PSPNet
55

6-
from . import encoders
6+
from . import encoders
7+
from . import utils

segmentation_models_pytorch/base/encoder_decoder.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,7 @@ def __init__(self, encoder, decoder, activation):
1818
self.activation = nn.Sigmoid()
1919
else:
2020
raise ValueError('Activation should be "sigmoid" or "softmax"')
21-
22-
if encoder.pretrained:
23-
self.set_preprocessing_params(
24-
input_size=encoder.input_size,
25-
input_space=encoder.input_space,
26-
input_range=encoder.input_range,
27-
mean=encoder.mean,
28-
std=encoder.std,
29-
)
30-
21+
3122
def forward(self, x):
3223
x = self.encoder(x)
3324
x = self.decoder(x)

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_preprocessing_fn(encoder_name, pretrained='imagenet'):
4545
mean = settings[pretrained].get('mean')
4646
std = settings[pretrained].get('std')
4747

48-
def _preprocess_input(x):
49-
return preprocess_input(x, mean=mean, std=std, input_space=input_space, input_range=input_range)
48+
def _preprocess_input(x, **kwargs):
49+
return preprocess_input(x, mean=mean, std=std, input_space=input_space, input_range=input_range, **kwargs)
5050

5151
return _preprocess_input

segmentation_models_pytorch/encoders/_preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33

4-
def preprocess_input(x, mean=None, std=None, input_space='RGB', input_range=None):
4+
def preprocess_input(x, mean=None, std=None, input_space='RGB', input_range=None, **kwargs):
55

66
if input_space == 'BGR':
77
x = x[..., ::-1].copy()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import train
2+
from . import losses
3+
from . import metrics
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
3+
4+
def iou(pr, gt, eps=1e-7, threshold=None, activation='sigmoid'):
5+
"""
6+
Source:
7+
https://github.com/catalyst-team/catalyst/
8+
Args:
9+
pr (torch.Tensor): A list of predicted elements
10+
gt (torch.Tensor): A list of elements that are to be predicted
11+
eps (float): epsilon to avoid zero division
12+
threshold: threshold for outputs binarization
13+
Returns:
14+
float: IoU (Jaccard) score
15+
"""
16+
17+
if activation is None or activation == "none":
18+
activation_fn = lambda x: x
19+
elif activation == "sigmoid":
20+
activation_fn = torch.nn.Sigmoid()
21+
elif activation == "softmax2d":
22+
activation_fn = torch.nn.Softmax2d()
23+
else:
24+
raise NotImplementedError(
25+
"Activation implemented for sigmoid and softmax2d"
26+
)
27+
28+
pr = activation_fn(pr)
29+
30+
if threshold is not None:
31+
pr = (pr > threshold).float()
32+
33+
intersection = torch.sum(gt * pr)
34+
union = torch.sum(gt) + torch.sum(pr) - intersection + eps
35+
return (intersection + eps) / union
36+
37+
jaccard = iou
38+
39+
40+
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
41+
"""
42+
Args:
43+
pr (torch.Tensor): A list of predicted elements
44+
gt (torch.Tensor): A list of elements that are to be predicted
45+
eps (float): epsilon to avoid zero division
46+
threshold: threshold for outputs binarization
47+
Returns:
48+
float: IoU (Jaccard) score
49+
"""
50+
51+
if activation is None or activation == "none":
52+
activation_fn = lambda x: x
53+
elif activation == "sigmoid":
54+
activation_fn = torch.nn.Sigmoid()
55+
elif activation == "softmax2d":
56+
activation_fn = torch.nn.Softmax2d()
57+
else:
58+
raise NotImplementedError(
59+
"Activation implemented for sigmoid and softmax2d"
60+
)
61+
62+
pr = activation_fn(pr)
63+
64+
if threshold is not None:
65+
pr = (pr > threshold).float()
66+
67+
68+
tp = torch.sum(gt * pr)
69+
fp = torch.sum(pr) - tp
70+
fn = torch.sum(gt) - tp
71+
72+
score = ((1 + beta ** 2) * tp + eps) \
73+
/ ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)
74+
75+
return score
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch.nn as nn
2+
from . import functions as F
3+
4+
5+
class JaccardLoss(nn.Module):
6+
__name__ = 'jaccard_loss'
7+
8+
def __init__(self, eps=1e-7, activation='sigmoid'):
9+
super().__init__()
10+
self.activation = activation
11+
self.eps = eps
12+
13+
def forward(self, y_pr, y_gt):
14+
return 1 - F.jaccard(y_pr, y_gt, eps=self.eps, threshold=None, activation=self.activation)
15+
16+
17+
class DiceLoss(nn.Module):
18+
__name__ = 'dice_loss'
19+
20+
def __init__(self, eps=1e-7, activation='sigmoid'):
21+
super().__init__()
22+
self.activation = activation
23+
self.eps = eps
24+
25+
def forward(self, y_pr, y_gt):
26+
return 1 - F.f_score(y_pr, y_gt, beta=1., eps=self.eps, threshold=None, activation=self.activation)
27+
28+
29+
class BCEJaccardLoss(JaccardLoss):
30+
__name__ = 'bce_jaccard_loss'
31+
32+
def __init__(self, eps=1e-7, activation='sigmoid'):
33+
super().__init__(eps, activation)
34+
self.bce = nn.BCEWithLogitsLoss(reduction='mean')
35+
36+
def forward(self, y_pr, y_gt):
37+
jaccard = super().forward(y_pr, y_gt)
38+
bce = self.bce(y_pr, y_gt)
39+
return jaccard + bce
40+
41+
42+
class BCEDiceLoss(DiceLoss):
43+
__name__ = 'bce_dice_loss'
44+
45+
def __init__(self, eps=1e-7, activation='sigmoid'):
46+
super().__init__(eps, activation)
47+
self.bce = nn.BCEWithLogitsLoss(reduction='mean')
48+
49+
def forward(self, y_pr, y_gt):
50+
dice = super().forward(y_pr, y_gt)
51+
bce = self.bce(y_pr, y_gt)
52+
return dice + bce
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch.nn as nn
2+
from . import functions as F
3+
4+
5+
class IoUMetric(nn.Module):
6+
7+
__name__ = 'iou'
8+
9+
def __init__(self, eps=1e-7, threshold=0.5, activation='sigmoid'):
10+
super().__init__()
11+
self.activation = activation
12+
self.eps = eps
13+
self.threshold = threshold
14+
15+
def forward(self, y_pr, y_gt):
16+
return F.iou(y_pr, y_gt, self.eps, self.threshold, self.activation)
17+
18+
19+
class FscoreMetric(nn.Module):
20+
21+
__name__ = 'f-score'
22+
23+
def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation='sigmoid'):
24+
super().__init__()
25+
self.activation = activation
26+
self.eps = eps
27+
self.threshold = threshold
28+
self.beta = beta
29+
30+
def forward(self, y_pr, y_gt):
31+
return F.f_score(y_pr, y_gt, self.beta, self.eps, self.threshold, self.activation)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import sys
2+
import torch
3+
from tqdm import tqdm as tqdm
4+
from torchnet.meter import AverageValueMeter
5+
6+
7+
class Epoch:
8+
9+
def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True):
10+
self.model = model
11+
self.loss = loss
12+
self.metrics = metrics
13+
self.stage_name = stage_name
14+
self.verbose = verbose
15+
self.device = device
16+
17+
self._to_device()
18+
19+
def _to_device(self):
20+
self.model.to(self.device)
21+
self.loss.to(self.device)
22+
for metric in self.metrics:
23+
metric.to(self.device)
24+
25+
def _format_logs(self, logs):
26+
str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
27+
s = ', '.join(str_logs)
28+
return s
29+
30+
def batch_update(self, x, y):
31+
raise NotImplementedError
32+
33+
def on_epoch_start(self):
34+
pass
35+
36+
def run(self, dataloader):
37+
38+
self.on_epoch_start()
39+
40+
logs = {}
41+
loss_meter = AverageValueMeter()
42+
metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}
43+
44+
with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
45+
for x, y in iterator:
46+
x, y = x.to(self.device), y.to(self.device)
47+
loss, y_pred = self.batch_update(x, y)
48+
49+
# update loss logs
50+
loss_value = loss.cpu().detach().numpy()
51+
loss_meter.add(loss_value)
52+
loss_logs = {self.loss.__name__: loss_meter.mean}
53+
logs.update(loss_logs)
54+
55+
# update metrics logs
56+
for metric_fn in self.metrics:
57+
metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
58+
metrics_meters[metric_fn.__name__].add(metric_value)
59+
metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
60+
logs.update(metrics_logs)
61+
62+
if self.verbose:
63+
s = self._format_logs(logs)
64+
iterator.set_postfix_str(s)
65+
66+
return logs
67+
68+
69+
class TrainEpoch(Epoch):
70+
71+
def __init__(self, model, loss, metrics, optimizer, device='cpu', verbose=True):
72+
super().__init__(
73+
model=model,
74+
loss=loss,
75+
metrics=metrics,
76+
stage_name='train',
77+
device=device,
78+
verbose=verbose,
79+
)
80+
self.optimizer = optimizer
81+
82+
def on_epoch_start(self):
83+
self.model.train()
84+
85+
def batch_update(self, x, y):
86+
self.optimizer.zero_grad()
87+
prediction = self.model.forward(x)
88+
loss = self.loss(prediction, y)
89+
loss.backward()
90+
self.optimizer.step()
91+
return loss, prediction
92+
93+
94+
class ValidEpoch(Epoch):
95+
96+
def __init__(self, model, loss, metrics, device='cpu', verbose=True):
97+
super().__init__(
98+
model=model,
99+
loss=loss,
100+
metrics=metrics,
101+
stage_name='valid',
102+
device=device,
103+
verbose=verbose,
104+
)
105+
106+
def on_epoch_start(self):
107+
self.model.eval()
108+
109+
def batch_update(self, x, y):
110+
with torch.no_grad():
111+
prediction = self.model.forward(x)
112+
loss = self.loss(prediction, y)
113+
return loss, prediction

0 commit comments

Comments
 (0)