
Commit 1fc77f9

Merge pull request #78 from JoelKronander/tensorboard_visualization_in_fit_function
Added support for automatic tensorboard visualization in fit()
2 parents: 584f3c5 + dfd8731


tensorlayer/utils.py

Lines changed: 78 additions & 2 deletions
@@ -1,12 +1,16 @@
 #! /usr/bin/python
 # -*- coding: utf8 -*-
 import tensorflow as tf
+import tensorlayer as tl
 from . import iterate
 import numpy as np
 import time
+import math
 
 
-def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True):
+def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100,
+        n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True,
+        tensorboard=False, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, tensorboard_graph_vis=True):
     """Train a given non-time-series network by the given cost function, training data, batch_size, n_epoch etc.
 
     Parameters
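The new keyword arguments all default to the previous behaviour (tensorboard=False), so existing callers are unaffected. Below is a hedged usage sketch of the extended signature; every variable in it (sess, network, train_op, cost, acc, x, y_, and the data arrays) is assumed to have been built beforehand, e.g. as in tutorial_mnist_simple.py, and only the tensorboard_* keywords are new in this commit.

# Hedged sketch: sess, network, train_op, cost, acc, x, y_, X_train, y_train,
# X_val, y_val are assumed to exist (see tutorial_mnist_simple.py).
tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
             acc=acc, batch_size=500, n_epoch=200, print_freq=5,
             X_val=X_val, y_val=y_val, eval_train=False,
             tensorboard=True,                    # write summaries under logs/
             tensorboard_epoch_freq=5,            # log every 5 epochs
             tensorboard_weight_histograms=True,  # histograms of network.all_params
             tensorboard_graph_vis=True)          # also store the graph
# afterwards, in a shell:
#   tensorboard --logdir=logs/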
@@ -39,17 +43,69 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
         the target of validation data
     eval_train : boolean
         if X_val and y_val are not None, it reflects whether to evaluate the training data
-
+    tensorboard : boolean
+        if True, summary data will be stored to the logs/ directory for visualization with TensorBoard.
+        See also the detailed tensorboard_X settings for configuration of specific features. (default False)
+        Also runs tl.layers.initialize_global_variables(sess) internally in fit() to set up the summary nodes (see Note).
+    tensorboard_epoch_freq : int
+        how many epochs between storing a tensorboard checkpoint for visualization to the logs/ directory (default 5)
+    tensorboard_weight_histograms : boolean
+        if True, updates the tensorboard data in the logs/ directory for visualization
+        of the weight histograms every tensorboard_epoch_freq epochs (default True)
+    tensorboard_graph_vis : boolean
+        if True, stores the graph in the tensorboard summaries saved to logs/ (default True)
 
     Examples
     --------
     >>> see tutorial_mnist_simple.py
     >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
     ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
     ...            X_val=X_val, y_val=y_val, eval_train=False)
+    >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
+    ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
+    ...            X_val=X_val, y_val=y_val, eval_train=False,
+    ...            tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True)
+
+    Note
+    --------
+    If tensorboard=True, the global_variables_initializer will be run inside the fit function
+    in order to initialize the automatically generated summary nodes used for tensorboard visualization;
+    the effect of calling tf.global_variables_initializer().run() before the fit() call is therefore undefined.
     """
     assert X_train.shape[0] >= batch_size, "Number of training examples should be bigger than the batch size"
+
+    if tensorboard:
+        print("Setting up tensorboard ...")
+        # Set up tensorboard summaries and saver
+        tl.files.exists_or_mkdir('logs/')
+
+        # Only write summaries for more recent TensorFlow versions
+        if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
+            if tensorboard_graph_vis:
+                train_writer = tf.summary.FileWriter('logs/train', sess.graph)
+                val_writer = tf.summary.FileWriter('logs/validation', sess.graph)
+            else:
+                train_writer = tf.summary.FileWriter('logs/train')
+                val_writer = tf.summary.FileWriter('logs/validation')
+
+        # Set up summary nodes
+        if tensorboard_weight_histograms:
+            for param in network.all_params:
+                if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'):
+                    print('Param name ', param.name)
+                    tf.summary.histogram(param.name, param)
+
+        if hasattr(tf, 'summary') and hasattr(tf.summary, 'scalar'):
+            tf.summary.scalar('cost', cost)
+
+        merged = tf.summary.merge_all()
+
+        # Initialize all variables and summaries
+        tl.layers.initialize_global_variables(sess)
+        print("Finished! use $tensorboard --logdir=logs/ to start server")
+
     print("Start training the network ...")
     start_time_begin = time.time()
+    tensorboard_train_index, tensorboard_val_index = 0, 0
     for epoch in range(n_epoch):
         start_time = time.time()
         loss_ep = 0; n_step = 0
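For readers unfamiliar with the TF 1.x summary API used above, here is a small self-contained sketch of the same setup pattern (a FileWriter that also stores the graph, a weight histogram, a scalar cost, and merge_all). The variable w and the quadratic cost are placeholders standing in for network.all_params and the real cost tensor, and plain tf.global_variables_initializer() stands in for tl.layers.initialize_global_variables(sess).

import tensorflow as tf

# Placeholder graph: one variable and a toy scalar cost stand in for
# network.all_params and the real cost used by fit().
w = tf.Variable(tf.zeros([10]), name='w')
cost = tf.reduce_sum(tf.square(w))

tf.summary.histogram('w', w)      # weight histogram (one per parameter in fit())
tf.summary.scalar('cost', cost)   # scalar curve for the cost
merged = tf.summary.merge_all()   # single op that evaluates every registered summary

with tf.Session() as sess:
    # Passing sess.graph stores the graph for the Graphs tab, as tensorboard_graph_vis does.
    train_writer = tf.summary.FileWriter('logs/train', sess.graph)
    # Initialize after the summary ops are created, mirroring the Note in the docstring.
    sess.run(tf.global_variables_initializer())
    summary = sess.run(merged)
    train_writer.add_summary(summary, 0)  # 0 is the global step shown on TensorBoard's x-axis
    train_writer.close()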
@@ -62,6 +118,26 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
             n_step += 1
         loss_ep = loss_ep / n_step
 
+        if tensorboard and hasattr(tf, 'summary'):
+            if epoch + 1 == 1 or (epoch + 1) % tensorboard_epoch_freq == 0:
+                for X_train_a, y_train_a in iterate.minibatches(
+                                        X_train, y_train, batch_size, shuffle=True):
+                    dp_dict = dict_to_one(network.all_drop)    # disable noise layers
+                    feed_dict = {x: X_train_a, y_: y_train_a}
+                    feed_dict.update(dp_dict)
+                    result = sess.run(merged, feed_dict=feed_dict)
+                    train_writer.add_summary(result, tensorboard_train_index)
+                    tensorboard_train_index += 1
+
+                for X_val_a, y_val_a in iterate.minibatches(
+                                        X_val, y_val, batch_size, shuffle=True):
+                    dp_dict = dict_to_one(network.all_drop)    # disable noise layers
+                    feed_dict = {x: X_val_a, y_: y_val_a}
+                    feed_dict.update(dp_dict)
+                    result = sess.run(merged, feed_dict=feed_dict)
+                    val_writer.add_summary(result, tensorboard_val_index)
+                    tensorboard_val_index += 1
+
         if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
             if (X_val is not None) and (y_val is not None):
                 print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
