11#! /usr/bin/python
22# -*- coding: utf8 -*-
33import tensorflow as tf
4+ import tensorlayer as tl
45from . import iterate
56import numpy as np
67import time
8+ import math
79
810
9- def fit (sess , network , train_op , cost , X_train , y_train , x , y_ , acc = None , batch_size = 100 , n_epoch = 100 , print_freq = 5 , X_val = None , y_val = None , eval_train = True ):
11+ def fit (sess , network , train_op , cost , X_train , y_train , x , y_ , acc = None , batch_size = 100 ,
12+ n_epoch = 100 , print_freq = 5 , X_val = None , y_val = None , eval_train = True ,
13+ tensorboard = False , tensorboard_epoch_freq = 5 , tensorboard_weight_histograms = True , tensorboard_graph_vis = True ):
1014 """Traing a given non time-series network by the given cost function, training data, batch_size, n_epoch etc.
1115
1216 Parameters
@@ -39,17 +43,69 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
3943 the target of validation data
4044 eval_train : boolean
4145 if X_val and y_val are not None, it refects whether to evaluate the training data
42-
46+ tensorboard : boolean
47+ if True summary data will be stored to the log/ direcory for visualization with tensorboard.
48+ See also detailed tensorboard_X settings for specific configurations of features. (default False)
49+ Also runs tl.layers.initialize_global_variables(sess) internally in fit() to setup the summary nodes, see Note:
50+ tensorboard_epoch_freq : int
51+ how many epochs between storing tensorboard checkpoint for visualization to log/ directory (default 5)
52+ tensorboard_weight_histograms : boolean
53+ if True updates tensorboard data in the logs/ directory for visulaization
54+ of the weight histograms every tensorboard_epoch_freq epoch (default True)
55+ tensorboard_graph_vis : boolean
56+ if True stores the graph in the tensorboard summaries saved to log/ (default True)
4357 Examples
4458 --------
4559 >>> see tutorial_mnist_simple.py
4660 >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
4761 ... acc=acc, batch_size=500, n_epoch=200, print_freq=5,
4862 ... X_val=X_val, y_val=y_val, eval_train=False)
63+ >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
64+ ... acc=acc, batch_size=500, n_epoch=200, print_freq=5,
65+ ... X_val=X_val, y_val=y_val, eval_train=False,
66+ ... tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True)
67+
68+ Note
69+ --------
70+ If tensorboard=True, the global_variables_initializer will be run inside the fit function
71+ in order to initalize the automatically generated summary nodes used for tensorboard visualization,
72+ thus tf.global_variables_initializer().run() before the fit() call will be undefined.
4973 """
5074 assert X_train .shape [0 ] >= batch_size , "Number of training examples should be bigger than the batch size"
75+
76+ if (tensorboard ):
77+ print ("Setting up tensorboard ..." )
78+ #Set up tensorboard summaries and saver
79+ tl .files .exists_or_mkdir ('logs/' )
80+
81+ #Only write summaries for more recent TensorFlow versions
82+ if hasattr (tf , 'summary' ) and hasattr (tf .summary , 'FileWriter' ):
83+ if tensorboard_graph_vis :
84+ train_writer = tf .summary .FileWriter ('logs/train' ,sess .graph )
85+ val_writer = tf .summary .FileWriter ('logs/validation' ,sess .graph )
86+ else :
87+ train_writer = tf .summary .FileWriter ('logs/train' )
88+ val_writer = tf .summary .FileWriter ('logs/validation' )
89+
90+ #Set up summary nodes
91+ if (tensorboard_weight_histograms ):
92+ for param in network .all_params :
93+ if hasattr (tf , 'summary' ) and hasattr (tf .summary , 'histogram' ):
94+ print ('Param name ' , param .name )
95+ tf .summary .histogram (param .name , param )
96+
97+ if hasattr (tf , 'summary' ) and hasattr (tf .summary , 'histogram' ):
98+ tf .summary .scalar ('cost' , cost )
99+
100+ merged = tf .summary .merge_all ()
101+
102+ #Initalize all variables and summaries
103+ tl .layers .initialize_global_variables (sess )
104+ print ("Finished! use $tensorboard --logdir=logs/ to start server" )
105+
51106 print ("Start training the network ..." )
52107 start_time_begin = time .time ()
108+ tensorboard_train_index , tensorboard_val_index = 0 , 0
53109 for epoch in range (n_epoch ):
54110 start_time = time .time ()
55111 loss_ep = 0 ; n_step = 0
@@ -62,6 +118,26 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
62118 n_step += 1
63119 loss_ep = loss_ep / n_step
64120
121+ if tensorboard and hasattr (tf , 'summary' ):
122+ if epoch + 1 == 1 or (epoch + 1 ) % tensorboard_epoch_freq == 0 :
123+ for X_train_a , y_train_a in iterate .minibatches (
124+ X_train , y_train , batch_size , shuffle = True ):
125+ dp_dict = dict_to_one ( network .all_drop ) # disable noise layers
126+ feed_dict = {x : X_train_a , y_ : y_train_a }
127+ feed_dict .update (dp_dict )
128+ result = sess .run (merged , feed_dict = feed_dict )
129+ train_writer .add_summary (result , tensorboard_train_index )
130+ tensorboard_train_index += 1
131+
132+ for X_val_a , y_val_a in iterate .minibatches (
133+ X_val , y_val , batch_size , shuffle = True ):
134+ dp_dict = dict_to_one ( network .all_drop ) # disable noise layers
135+ feed_dict = {x : X_val_a , y_ : y_val_a }
136+ feed_dict .update (dp_dict )
137+ result = sess .run (merged , feed_dict = feed_dict )
138+ val_writer .add_summary (result , tensorboard_val_index )
139+ tensorboard_val_index += 1
140+
65141 if epoch + 1 == 1 or (epoch + 1 ) % print_freq == 0 :
66142 if (X_val is not None ) and (y_val is not None ):
67143 print ("Epoch %d of %d took %fs" % (epoch + 1 , n_epoch , time .time () - start_time ))
0 commit comments