Skip to content

Commit 68bea03

Browse files
committed
DEBUG BatchNormLayer -- simplify
1 parent 9e85f24 commit 68bea03

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

tensorlayer/layers.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,7 +1724,7 @@ def __init__(
17241724
decay = 0.999,
17251725
epsilon = 0.00001,
17261726
act = tf.identity,
1727-
is_train = None,
1727+
is_train = False,
17281728
beta_init = tf.zeros_initializer,
17291729
# gamma_init = tf.ones_initializer,
17301730
gamma_init = tf.random_normal_initializer(mean=1.0, stddev=0.002),
@@ -1838,20 +1838,25 @@ def mean_var_with_update():
18381838
with tf.control_dependencies([update_moving_mean, update_moving_variance]):
18391839
return tf.identity(mean), tf.identity(variance)
18401840

1841-
if not is_train: # test : mean=0, std=1
1842-
# if is_train: # train : mean=0, std=1
1843-
is_train = tf.cast(tf.ones([]), tf.bool)
1841+
# if not is_train: # test : mean=0, std=1
1842+
# # if is_train: # train : mean=0, std=1
1843+
# is_train = tf.cast(tf.ones([]), tf.bool)
1844+
# else:
1845+
# is_train = tf.cast(tf.zeros([]), tf.bool)
1846+
#
1847+
# # mean, var = control_flow_ops.cond(
1848+
# mean, var = tf.cond(
1849+
# # is_train, lambda: (mean, variance), # when training, (x-mean(x))/var(x)
1850+
# is_train, mean_var_with_update,
1851+
# lambda: (moving_mean, moving_variance)) # when inferencing, (x-0)/1
1852+
#
1853+
# self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
1854+
if not is_train:
1855+
mean, var = mean_var_with_update()
1856+
self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
18441857
else:
1845-
is_train = tf.cast(tf.zeros([]), tf.bool)
1846-
1847-
# mean, var = control_flow_ops.cond(
1848-
mean, var = tf.cond(
1849-
# is_train, lambda: (mean, variance), # when training, (x-mean(x))/var(x)
1850-
is_train, mean_var_with_update,
1851-
lambda: (moving_mean, moving_variance)) # when inferencing, (x-0)/1
1852-
1853-
self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
1854-
#x.set_shape(inputs.get_shape()) ??
1858+
self.outputs = act( tf.nn.batch_normalization(self.inputs, moving_mean, moving_variance, beta, gamma, epsilon) )
1859+
18551860
# variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) # 8 params in TF12 if zero_debias=True
18561861
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # 2 params beta, gamma
18571862
# variables = [beta, gamma, moving_mean, moving_variance]

0 commit comments

Comments
 (0)