forked from hse-aml/intro-to-dl
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathkeras_utils.py
More file actions
58 lines (47 loc) · 1.89 KB
/
keras_utils.py
File metadata and controls
58 lines (47 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import keras
import tqdm
from collections import defaultdict
import numpy as np
from keras.models import save_model
class TqdmProgressCallback(keras.callbacks.Callback):
    """Keras callback that shows per-epoch progress in a tqdm notebook bar.

    A fresh ``tqdm.tqdm_notebook`` bar is opened at the start of every epoch
    and closed at its end. The bar's description lists the running mean of
    each tracked metric value collected so far in that epoch.

    Progress units are chosen once per epoch from ``self.params``:

    * ``'steps'`` (one unit per batch) when present and not ``None``;
    * otherwise ``'samples'`` (units taken from each batch's ``logs['size']``)
      when present and not ``None``;
    * otherwise one unit per batch, with no known total.
    """

    # Bookkeeping entries that may appear in batch ``logs`` but are not
    # metrics; skipped only when ``self.params`` provides no 'metrics' list.
    _NON_METRIC_KEYS = ('batch', 'size')

    def on_train_begin(self, logs=None):
        """Record the total epoch count used in the "Epoch i/N" header."""
        self.epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        """Print the epoch header and open a new progress bar.

        Args:
            epoch: Epoch index as passed by Keras; printed as ``epoch + 1``.
            logs: Unused.
        """
        print('Epoch %d/%d' % (epoch + 1, self.epochs))
        # Check the *value* of 'steps', not just key presence: params may
        # carry 'steps': None when no per-epoch step count is known, in which
        # case 'samples' should size the bar instead. Missing keys must not
        # raise KeyError either.
        steps = self.params.get('steps')
        samples = self.params.get('samples')
        if steps is not None:
            self.use_steps = True
            self.target = steps
        elif samples is not None:
            self.use_steps = False
            self.target = samples
        else:
            # Nothing to size the bar with: count batches, total unknown.
            self.use_steps = True
            self.target = None
        self.prog_bar = tqdm.tqdm_notebook(total=self.target)
        # Metric name -> list of values logged so far during this epoch.
        self.log_values_by_metric = defaultdict(list)

    def _set_prog_bar_desc(self, logs):
        """Append tracked metric values from ``logs`` and refresh the bar text.

        Args:
            logs: Mapping of metric name to value for the current batch/epoch.
        """
        metric_names = self.params.get('metrics')
        if metric_names is None:
            # Some Keras versions may omit 'metrics' from params; fall back to
            # whatever non-bookkeeping keys were actually logged.
            metric_names = [k for k in logs if k not in self._NON_METRIC_KEYS]
        for k in metric_names:
            if k in logs:
                self.log_values_by_metric[k].append(logs[k])
        desc = "; ".join(
            "{0}: {1:.4f}".format(k, np.mean(values))
            for k, values in self.log_values_by_metric.items()
        )
        self.prog_bar.set_description(desc)

    def on_batch_end(self, batch, logs=None):
        """Advance the bar for a finished batch and update its description.

        The bar advances by 1 in step mode, otherwise by ``logs['size']``
        (0 when that key is absent).
        """
        logs = logs or {}
        if self.use_steps:
            self.prog_bar.update(1)
        else:
            batch_size = logs.get('size', 0)
            self.prog_bar.update(batch_size)
        self._set_prog_bar_desc(logs)

    def on_epoch_end(self, epoch, logs=None):
        """Fold the epoch-end logs into the description and close the bar."""
        logs = logs or {}
        self._set_prog_bar_desc(logs)
        # Workaround to get the final description displayed. NOTE(review): this
        # also advances the counter one unit beyond the per-epoch target.
        self.prog_bar.update(1)
        self.prog_bar.close()
class ModelSaveCallback(keras.callbacks.Callback):
    """Keras callback that writes the whole model to disk after every epoch.

    The destination is built by calling ``str.format`` on the path template
    given to the constructor, with the epoch index received in
    ``on_epoch_end`` as the single positional argument. A template such as
    ``"model_{0:03d}.hdf5"`` therefore yields one file per epoch.
    """

    def __init__(self, file_name):
        """Store ``file_name``, the path template filled in per epoch."""
        super(ModelSaveCallback, self).__init__()
        # Template whose ``{}`` / ``{0}`` placeholder receives the epoch index.
        self.file_name = file_name

    def on_epoch_end(self, epoch, logs=None):
        """Save ``self.model`` via ``save_model`` and report the written path.

        Args:
            epoch: Epoch index as passed by Keras; substituted into the template.
            logs: Unused.
        """
        destination = self.file_name.format(epoch)
        save_model(self.model, destination)
        message = "Model saved in {}".format(destination)
        print(message)