-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
72 lines (60 loc) · 2.81 KB
/
train.py
File metadata and controls
72 lines (60 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
from ignite.metrics import Loss
from challenge.dataset import get_loaders
from challenge.loss import BCEWithLogitsLoss
from challenge.metric import Accuracy, F1Score
from challenge.train import train_model
from challenge.utils import load_model
from commons.utils import init_logger, get_logger, remove_resource_limits
from commons.utils.model import unfreeze
from config import config
def _build_metrics():
    """Return a fresh dict of evaluation metrics keyed by display name.

    A new set is built for every fold so that no metric object (and any
    state it accumulates) is shared between folds.
    """
    return {
        'loss': Loss(BCEWithLogitsLoss()),
        'acc': Accuracy(),
        'F1': F1Score(),
    }


def main():
    """Train one model per cross-validation fold, as configured in ``config``.

    Steps:
      1. Lift resource limits and initialise logging under a run name derived
         from ``config.exp``, ``config.image_size`` and ``config.batch_size``.
      2. Build the data loaders (``config.k_fold`` splits).
      3. For each fold, either load the model from a checkpoint
         (``config.checkpoint``) and unfreeze it, or load it fresh. Then
         train it with ``train_model``, saving under ``config.model_path``.
    """
    remove_resource_limits()

    # A single run name shared by the logger and the saved models, so the
    # two identifiers cannot drift apart.
    run_name = f'{config.exp}_sz{config.image_size}_x{config.batch_size}'
    init_logger(run_name, config.log_path, config.tensorboard_path)
    logger = get_logger()
    logger.info(f'PID: {os.getpid()}')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # get_loaders also returns a "default" set of loaders that this script
    # does not use; only the per-fold ``loaders`` are consumed below.
    _default_loaders, loaders = get_loaders(path=config.data_path,
                                            image_size=config.image_size,
                                            n_splits=config.k_fold,
                                            test_size=config.test_size,
                                            batch_size=config.batch_size,
                                            num_workers=config.num_workers,
                                            external=config.external_data,
                                            use_sampler=config.use_sampler)

    logger.info(f'Model: {config.model}')
    logger.info(f'External data: {config.external_data}')
    logger.info(f'K-fold: {config.k_fold}')
    logger.info(f'Mixed precision: {config.mixed_precision}')
    logger.info(f'Image size: {config.image_size}')
    logger.info(f'Batch size: {config.batch_size}')

    for fold in range(config.k_fold):
        fold_no = fold + 1  # 1-based fold number used in logs and file names

        if config.k_fold > 1:
            logger.info(f'Fold: {fold_no}')
            suffix = f'_f{fold_no}'
        else:
            suffix = ''

        if config.checkpoint:
            # With k-fold training, config.checkpoint is a str.format template
            # that receives the 1-based fold number.
            checkpoint = config.checkpoint.format(fold_no) if config.k_fold > 1 else config.checkpoint
            checkpoint_file = f'checkpoint_{checkpoint}.pth'
            logger.info(f'Use checkpoint: {checkpoint_file}')
            model = load_model(config.model,
                               os.path.join(config.model_path, checkpoint_file),
                               config.mixed_precision)
            # NOTE(review): presumably makes all layers trainable again when
            # resuming from a checkpoint -- confirm in commons.utils.model.
            unfreeze(model)
        else:
            model = load_model(config.model)

        train_model(name=f'{run_name}{suffix}',
                    model=model,
                    data_loaders=loaders[fold],
                    metrics=_build_metrics(),
                    device=device,
                    lr=config.lr,
                    num_epochs=config.num_epochs,
                    cycles_len=config.cycles_len,
                    lr_divs=config.lr_divs,
                    mixed_precision=config.mixed_precision,
                    backup_path=config.model_path)


if __name__ == '__main__':
    main()