
Commit 25e65ab

Update the example using tensorboard.
1 parent 5d5b511 commit 25e65ab

4 files changed: +138 -8 lines

examples/basic_tutorials/mnist_mlp.py

Lines changed: 6 additions & 6 deletions
@@ -35,11 +35,11 @@ class CustomModel(Module):
 
     def __init__(self):
         super(CustomModel, self).__init__()
-        self.dropout1 = Dropout(p=0.8)
+        self.dropout1 = Dropout(p=0.2)
         self.linear1 = Linear(out_features=800, act=tlx.ReLU, in_features=784)
-        self.dropout2 = Dropout(p=0.8)
+        self.dropout2 = Dropout(p=0.2)
         self.linear2 = Linear(out_features=800, act=tlx.ReLU, in_features=800)
-        self.dropout3 = Dropout(p=0.8)
+        self.dropout3 = Dropout(p=0.2)
         self.linear3 = Linear(out_features=10, act=tlx.ReLU, in_features=800)
 
     def forward(self, x, foo=None):

@@ -90,12 +90,12 @@ def forward(self, x, foo=None):
 #
 # def __init__(self):
 #     super(CustomModel, self).__init__()
-#     self.dropout1 = Dropout(keep=0.8)
+#     self.dropout1 = Dropout(p=0.2)
 #     self.linear1 = Linear(out_features=800, in_features=784)
 #     self.batchnorm = BatchNorm1d(act=tlx.ReLU, num_features=800)
-#     self.dropout2 = Dropout(keep=0.8)
+#     self.dropout2 = Dropout(p=0.2)
 #     self.linear2 = Linear(out_features=800, act=tlx.ReLU, in_features=800)
-#     self.dropout3 = Dropout(keep=0.8)
+#     self.dropout3 = Dropout(p=0.2)
 #     self.linear3 = Linear(out_features=10, act=tlx.ReLU, in_features=800)
 #
 # def forward(self, x, foo=None):
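Both hunks reflect the same API migration: Dropout's old keep argument was the probability of keeping a unit, while the new p argument is the probability of dropping one (as in torch.nn.Dropout), so keep=0.8 corresponds to p=0.2. A minimal sketch of the new semantics; the TensorFlow backend and the inverted-dropout rescaling noted in the last comment are assumptions, not something this commit shows:

import os
os.environ['TL_BACKEND'] = 'tensorflow'

import numpy as np
import tensorlayerx as tlx
from tensorlayerx.nn import Dropout

# p is the probability of DROPPING each unit, so p=0.2 zeroes
# roughly 20% of activations; the old API expressed the same thing
# as keep=0.8 (an 80% chance of keeping each unit).
dropout = Dropout(p=0.2)
dropout.set_train()  # dropout only randomizes in training mode

x = tlx.ops.convert_to_tensor(np.ones((4, 8), dtype=np.float32))
y = dropout(x)  # survivors are rescaled by 1/(1-p) under inverted dropout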
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+# This example shows how TensorLayerX uses TensorBoard to monitor training.
+# TensorLayerX relies on tensorboardX for monitoring, so tensorboardX must be installed, and the installed version of tensorboardX needs to match your TensorFlow version.
+# tensorboardX repo: https://github.com/lanpa/tensorboardX/blob/master/README.md
+
+# Usage, step by step:
+
+# Step 1: install tensorboardX.
+# pip install tensorboardX
+# or build from source:
+# pip install 'git+https://github.com/lanpa/tensorboardX'
+# You can optionally install crc32c to speed it up:
+# pip install crc32c
+# Starting from tensorboardX 2.1, you need to install soundfile for the add_audio() function (200x speedup):
+# pip install soundfile
+
+# Step 2: create a SummaryWriter object. The log will be saved in 'runs/mlp'.
+# writer = SummaryWriter('runs/mlp')
+
+# Step 3: use add_scalar to record scalar values.
+# writer.add_scalar('train acc', train_acc / n_iter, train_batch)
+
+# Step 4: start TensorBoard on the command line.
+# tensorboard --logdir=<your_log_dir>
+# e.g. tensorboard --logdir=runs/mlp
+
+# Step 5: view the results in a browser.
+# Open http://localhost:6006 in the browser.
+
+
+import os
+os.environ['TL_BACKEND'] = 'tensorflow'
+
+import numpy as np
+import time
+
+import tensorflow as tf
+import tensorlayerx as tlx
+from tensorlayerx.nn import Module
+from tensorlayerx.nn import Linear, Dropout, BatchNorm1d
+from tensorboardX import SummaryWriter
+
+
+X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 784))
+
+
+class CustomModel(Module):
+
+    def __init__(self):
+        super(CustomModel, self).__init__()
+        self.dropout1 = Dropout(p=0.2)
+        self.linear1 = Linear(out_features=800, in_features=784)
+        self.batchnorm = BatchNorm1d(act=tlx.ReLU, num_features=800)
+        self.dropout2 = Dropout(p=0.2)
+        self.linear2 = Linear(out_features=800, act=tlx.ReLU, in_features=800)
+        self.dropout3 = Dropout(p=0.2)
+        self.linear3 = Linear(out_features=10, act=tlx.ReLU, in_features=800)
+
+    def forward(self, x, foo=None):
+        z = self.dropout1(x)
+        z = self.linear1(z)
+        z = self.batchnorm(z)
+        z = self.dropout2(z)
+        z = self.linear2(z)
+        z = self.dropout3(z)
+        out = self.linear3(z)
+        if foo is not None:
+            out = tlx.relu(out)
+        return out
+
+
+MLP = CustomModel()
+n_epoch = 50
+batch_size = 500
+print_freq = 1
+train_weights = MLP.trainable_weights
+optimizer = tlx.optimizers.Adam(learning_rate=0.0001)
+train_batch = 0
+test_batch = 0
+
+writer = SummaryWriter('runs/mlp')
+
+for epoch in range(n_epoch):  ## iterate over the dataset n_epoch times
+    start_time = time.time()
+    ## iterate over the entire training set once (shuffle the data while training)
+    for X_batch, y_batch in tlx.utils.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
+        MLP.set_train()  # enable dropout
+        with tf.GradientTape() as tape:
+            ## compute outputs
+            _logits = MLP(X_batch)
+            ## compute loss and update the model
+            _loss = tlx.losses.softmax_cross_entropy_with_logits(_logits, y_batch)
+        grad = tape.gradient(_loss, train_weights)
+        optimizer.apply_gradients(zip(grad, train_weights))
+
+    ## evaluate the model on the training and validation sets every print_freq epochs
+    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
+        print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
+        train_loss, train_acc, n_iter = 0, 0, 0
+        for X_batch, y_batch in tlx.utils.iterate.minibatches(X_train, y_train, batch_size, shuffle=False):
+            train_batch += 1
+            _logits = MLP(X_batch)
+            train_loss += tlx.losses.softmax_cross_entropy_with_logits(_logits, y_batch)
+            train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+            n_iter += 1
+
+        print("   train loss: {}".format(train_loss / n_iter))
+        print("   train acc: {}".format(train_acc / n_iter))
+
+        writer.add_scalar('train loss', tlx.ops.convert_to_numpy(train_loss / n_iter), train_batch)
+        writer.add_scalar('train acc', train_acc / n_iter, train_batch)
+
+        val_loss, val_acc, n_iter = 0, 0, 0
+        for X_batch, y_batch in tlx.utils.iterate.minibatches(X_val, y_val, batch_size, shuffle=False):
+            test_batch += 1
+            _logits = MLP(X_batch)  # is_train=False, disable dropout
+            val_loss += tlx.losses.softmax_cross_entropy_with_logits(_logits, y_batch)
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+            n_iter += 1
+
+        print("   val loss: {}".format(val_loss / n_iter))
+        print("   val acc: {}".format(val_acc / n_iter))
+
+        writer.add_scalar('val loss', tlx.ops.convert_to_numpy(val_loss / n_iter), test_batch)
+        writer.add_scalar('val acc', val_acc / n_iter, test_batch)
+
+writer.export_scalars_to_json("./all_scalars.json")
+writer.close()
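Besides the live TensorBoard view, the script's last two lines export every logged scalar to ./all_scalars.json before closing the writer, so a run can be inspected without TensorBoard at all. A small sketch of reading that export with only the standard library; the layout (each key a run/tag pair mapping to [wall_time, step, value] triples) follows tensorboardX's documented JSON format, which is worth verifying against your installed version:

import json

# Load the file written by writer.export_scalars_to_json("./all_scalars.json").
with open("./all_scalars.json") as f:
    scalars = json.load(f)

# Each value is the list of records appended by add_scalar(),
# one [wall_time, step, value] triple per call.
for tag, records in scalars.items():
    wall_time, step, value = records[-1]
    print("{}: {} points, last value {}".format(tag, len(records), value))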

tensorlayerx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 # distributed = LazyImport("tensorlayerx.utils.distributed")
 # nlp = LazyImport("tensorlayerx.text.nlp")
 prepro = LazyImport("tensorlayerx.utils.prepro")
-# utils = LazyImport("tensorlayerx.utils")
+utils = LazyImport("tensorlayerx.utils")
 visualize = LazyImport("tensorlayerx.utils.visualize")
 
 # alias
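Uncommenting this LazyImport is what lets the new example call tlx.utils.iterate.minibatches through the top-level package: tensorlayerx.utils is now resolved lazily on first attribute access rather than left unavailable. A minimal sketch of that access path, using random arrays as stand-in data (the shapes are illustrative assumptions):

import numpy as np
import tensorlayerx as tlx

X = np.random.rand(1000, 784).astype(np.float32)  # stand-in for X_train
y = np.random.randint(0, 10, size=(1000,))        # stand-in for y_train

# tlx.utils is a LazyImport: tensorlayerx.utils is only imported the
# first time an attribute such as iterate is actually touched.
for X_batch, y_batch in tlx.utils.iterate.minibatches(X, y, batch_size=500, shuffle=True):
    print(X_batch.shape, y_batch.shape)  # (500, 784) (500,)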

tensorlayerx/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # from .array_ops import *
 # from .db import *
 # from .distributed import *
-# from .iterate import *
+from .iterate import *
 from .prepro import *
 # from .rein import *
 # from .utils import *
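With the wildcard import enabled, whatever iterate exports (minibatches in particular) is also re-exported at the tensorlayerx.utils package level, so it can be imported without naming the submodule. A tiny sketch under the same stand-in assumptions as above:

import numpy as np
from tensorlayerx.utils import minibatches  # re-exported by the wildcard import

X = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.arange(10)
for X_batch, y_batch in minibatches(X, y, batch_size=5, shuffle=False):
    print(X_batch.shape, y_batch.shape)  # (5, 2) (5,)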
