4141
4242def fit (
4343 sess , network , train_op , cost , X_train , y_train , x , y_ , acc = None , batch_size = 100 , n_epoch = 100 , print_freq = 5 ,
44- X_val = None , y_val = None , eval_train = True , tensorboard = False , tensorboard_epoch_freq = 5 ,
44+ X_val = None , y_val = None , eval_train = True , tensorboard_dir = None , tensorboard_epoch_freq = 5 ,
4545 tensorboard_weight_histograms = True , tensorboard_graph_vis = True
4646):
4747 """Training a given non time-series network by the given cost function, training data, batch_size, n_epoch etc.
@@ -80,9 +80,8 @@ def fit(
8080 eval_train : boolean
8181 Whether to evaluate the model during training.
8282 If X_val and y_val are not None, it reflects whether to evaluate the model on training data.
83- tensorboard : boolean
84- If True, summary data will be stored to the log/ directory for visualization with tensorboard.
85- See also detailed tensorboard_X settings for specific configurations of features. (default False)
83+ tensorboard_dir : string
84+ path to log dir, if set, summary data will be stored to the tensorboard_dir/ directory for visualization with tensorboard. (default None)
8685 Also runs `tl.layers.initialize_global_variables(sess)` internally in fit() to setup the summary nodes.
8786 tensorboard_epoch_freq : int
8887 How many epochs between storing tensorboard checkpoint for visualization to log/ directory (default 5).
@@ -106,27 +105,27 @@ def fit(
106105
107106 Notes
108107 --------
109- If tensorboard=True , the `global_variables_initializer` will be run inside the fit function
108+ If tensorboard_dir not None , the `global_variables_initializer` will be run inside the fit function
110109 in order to initialize the automatically generated summary nodes used for tensorboard visualization,
111110 thus `tf.global_variables_initializer().run()` before the `fit()` call will be undefined.
112111
113112 """
114113 if X_train .shape [0 ] < batch_size :
115114 raise AssertionError ("Number of training examples should be bigger than the batch size" )
116115
117- if ( tensorboard ) :
116+ if tensorboard_dir is not None :
118117 tl .logging .info ("Setting up tensorboard ..." )
119118 #Set up tensorboard summaries and saver
120- tl .files .exists_or_mkdir ('logs/' )
119+ tl .files .exists_or_mkdir (tensorboard_dir )
121120
122121 #Only write summaries for more recent TensorFlow versions
123122 if hasattr (tf , 'summary' ) and hasattr (tf .summary , 'FileWriter' ):
124123 if tensorboard_graph_vis :
125- train_writer = tf .summary .FileWriter ('logs /train' , sess .graph )
126- val_writer = tf .summary .FileWriter ('logs /validation' , sess .graph )
124+ train_writer = tf .summary .FileWriter (tensorboard_dir + ' /train' , sess .graph )
125+ val_writer = tf .summary .FileWriter (tensorboard_dir + ' /validation' , sess .graph )
127126 else :
128- train_writer = tf .summary .FileWriter ('logs /train' )
129- val_writer = tf .summary .FileWriter ('logs /validation' )
127+ train_writer = tf .summary .FileWriter (tensorboard_dir + ' /train' )
128+ val_writer = tf .summary .FileWriter (tensorboard_dir + ' /validation' )
130129
131130 #Set up summary nodes
132131 if (tensorboard_weight_histograms ):
@@ -142,7 +141,7 @@ def fit(
142141
143142 #Initalize all variables and summaries
144143 tl .layers .initialize_global_variables (sess )
145- tl .logging .info ("Finished! use $ tensorboard --logdir=logs/ to start server" )
144+ tl .logging .info ("Finished! use ` tensorboard --logdir=%s/` to start tensorboard" % tensorboard_dir )
146145
147146 tl .logging .info ("Start training the network ..." )
148147 start_time_begin = time .time ()
@@ -159,7 +158,7 @@ def fit(
159158 n_step += 1
160159 loss_ep = loss_ep / n_step
161160
162- if tensorboard and hasattr (tf , 'summary' ):
161+ if tensorboard_dir is not None and hasattr (tf , 'summary' ):
163162 if epoch + 1 == 1 or (epoch + 1 ) % tensorboard_epoch_freq == 0 :
164163 for X_train_a , y_train_a in tl .iterate .minibatches (X_train , y_train , batch_size , shuffle = True ):
165164 dp_dict = dict_to_one (network .all_drop ) # disable noise layers
0 commit comments