Skip to content

Commit badd3d5

Browse files
ndiyJonathan DEKHTIAR
authored andcommitted
Update utils.py (#819)
* Update utils.py update fit function replace tensorboard with tensorboard_dir parameter to enable customized log folder * Update utils.py * Update utils.py * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md
1 parent f8ec986 commit badd3d5

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ To release a new version, please update the changelog as followed:
7272
### Added
7373

7474
### Changed
75+
- remove 'tensorboard' param, replaced by 'tensorboard_dir' in `tensorlayer/utils.py` with customizable tensorboard directory (PR #819)
7576

7677
### Deprecated
7778

@@ -86,6 +87,7 @@ To release a new version, please update the changelog as followed:
8687
- pytest-cov>=2.5,<2.6 => pytest-cov>=2.5,<2.7 (PR #820)
8788

8889
### Contributors
90+
- @ndiy: #819
8991

9092
- @DEKHTIARJonathan: #815 #820
9193

tensorlayer/utils.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
def fit(
4343
sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, n_epoch=100, print_freq=5,
44-
X_val=None, y_val=None, eval_train=True, tensorboard=False, tensorboard_epoch_freq=5,
44+
X_val=None, y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5,
4545
tensorboard_weight_histograms=True, tensorboard_graph_vis=True
4646
):
4747
"""Training a given non time-series network by the given cost function, training data, batch_size, n_epoch etc.
@@ -80,9 +80,8 @@ def fit(
8080
eval_train : boolean
8181
Whether to evaluate the model during training.
8282
If X_val and y_val are not None, it reflects whether to evaluate the model on training data.
83-
tensorboard : boolean
84-
If True, summary data will be stored to the log/ directory for visualization with tensorboard.
85-
See also detailed tensorboard_X settings for specific configurations of features. (default False)
83+
tensorboard_dir : string
84+
path to log dir, if set, summary data will be stored to the tensorboard_dir/ directory for visualization with tensorboard. (default None)
8685
Also runs `tl.layers.initialize_global_variables(sess)` internally in fit() to setup the summary nodes.
8786
tensorboard_epoch_freq : int
8887
How many epochs between storing tensorboard checkpoint for visualization to log/ directory (default 5).
@@ -106,27 +105,27 @@ def fit(
106105
107106
Notes
108107
--------
109-
If tensorboard=True, the `global_variables_initializer` will be run inside the fit function
108+
If tensorboard_dir not None, the `global_variables_initializer` will be run inside the fit function
110109
in order to initialize the automatically generated summary nodes used for tensorboard visualization,
111110
thus `tf.global_variables_initializer().run()` before the `fit()` call will be undefined.
112111
113112
"""
114113
if X_train.shape[0] < batch_size:
115114
raise AssertionError("Number of training examples should be bigger than the batch size")
116115

117-
if (tensorboard):
116+
if tensorboard_dir is not None:
118117
tl.logging.info("Setting up tensorboard ...")
119118
#Set up tensorboard summaries and saver
120-
tl.files.exists_or_mkdir('logs/')
119+
tl.files.exists_or_mkdir(tensorboard_dir)
121120

122121
#Only write summaries for more recent TensorFlow versions
123122
if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
124123
if tensorboard_graph_vis:
125-
train_writer = tf.summary.FileWriter('logs/train', sess.graph)
126-
val_writer = tf.summary.FileWriter('logs/validation', sess.graph)
124+
train_writer = tf.summary.FileWriter(tensorboard_dir + '/train', sess.graph)
125+
val_writer = tf.summary.FileWriter(tensorboard_dir + '/validation', sess.graph)
127126
else:
128-
train_writer = tf.summary.FileWriter('logs/train')
129-
val_writer = tf.summary.FileWriter('logs/validation')
127+
train_writer = tf.summary.FileWriter(tensorboard_dir + '/train')
128+
val_writer = tf.summary.FileWriter(tensorboard_dir + '/validation')
130129

131130
#Set up summary nodes
132131
if (tensorboard_weight_histograms):
@@ -142,7 +141,7 @@ def fit(
142141

143142
#Initalize all variables and summaries
144143
tl.layers.initialize_global_variables(sess)
145-
tl.logging.info("Finished! use $tensorboard --logdir=logs/ to start server")
144+
tl.logging.info("Finished! use `tensorboard --logdir=%s/` to start tensorboard" % tensorboard_dir)
146145

147146
tl.logging.info("Start training the network ...")
148147
start_time_begin = time.time()
@@ -159,7 +158,7 @@ def fit(
159158
n_step += 1
160159
loss_ep = loss_ep / n_step
161160

162-
if tensorboard and hasattr(tf, 'summary'):
161+
if tensorboard_dir is not None and hasattr(tf, 'summary'):
163162
if epoch + 1 == 1 or (epoch + 1) % tensorboard_epoch_freq == 0:
164163
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
165164
dp_dict = dict_to_one(network.all_drop) # disable noise layers

0 commit comments

Comments
 (0)