Commit 95b3d0c

DEBUG BatchNormLayer simplify
1 parent 68bea03

1 file changed (+176 lines, -9 lines)

tensorlayer/layers.py

Lines changed: 176 additions & 9 deletions
@@ -1762,13 +1762,9 @@ def __init__(
         from tensorflow.python.ops import control_flow_ops
 
         with tf.variable_scope(name) as vs:
-            # if use_bias:
-            #     bias = _get_variable('bias', params_shape,
-            #                          initializer=tf.zeros_initializer)
-            #     return self.inputs + bias
-
             axis = list(range(len(x_shape) - 1))
 
+            ## 1. beta, gamma
             # beta = _get_variable('beta',
             #                      params_shape,
             #                      initializer=beta_init)
@@ -1789,6 +1785,7 @@ def __init__(
                                    initializer=gamma_init, trainable=is_train,
                                    )#restore=restore)
 
+            ## 2. moving variables during training (not updated by gradient!)
             # trainable=False means: it prevents TF from updating this variable
             # from the gradient; we have to update it from the mean computed
             # from each batch during training
@@ -1816,6 +1813,7 @@ def __init__(
                                               initializer=tf.constant_initializer(1.),
                                               trainable=False,)# restore=restore)
 
+            ## 3.
             # These ops will only be performed when training.
             mean, variance = tf.nn.moments(self.inputs, axis)
             try:    # TF12
@@ -1831,13 +1829,17 @@ def __init__(
                     moving_variance, variance, decay)
                 # print("TF11 moving")
 
-            # tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
-            # tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
-
             def mean_var_with_update():
                 with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                     return tf.identity(mean), tf.identity(variance)
 
+            # ema = tf.train.ExponentialMovingAverage(decay=decay)  # Akara
+            # def mean_var_with_update():
+            #     ema_apply_op = ema.apply([moving_mean, moving_variance])
+            #     with tf.control_dependencies([ema_apply_op]):
+            #         return tf.identity(mean), tf.identity(variance)
+
+            ## 4. behaviour for training and testing
             # if not is_train:      # test  : mean=0, std=1
             # # if is_train:        # train : mean=0, std=1
             #     is_train = tf.cast(tf.ones([]), tf.bool)
@@ -1855,8 +1857,173 @@ def mean_var_with_update():
                 mean, var = mean_var_with_update()
                 self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
             else:
+                # self.outputs = act( tf.nn.batch_normalization(self.inputs, ema.average(mean), ema.average(variance), beta, gamma, epsilon) ) # Akara
                 self.outputs = act( tf.nn.batch_normalization(self.inputs, moving_mean, moving_variance, beta, gamma, epsilon) )
-
+
+            # variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)  # 8 params in TF12 if zero_debias=True
+            variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)  # 2 params: beta, gamma
+            # variables = [beta, gamma, moving_mean, moving_variance]
+
+            # print(len(variables))
+            # for idx, v in enumerate(variables):
+            #     print("  var {:3}: {:15}   {}".format(idx, str(v.get_shape()), v))
+            # exit()
+
+        self.all_layers = list(layer.all_layers)
+        self.all_params = list(layer.all_params)
+        self.all_drop = dict(layer.all_drop)
+        self.all_layers.extend( [self.outputs] )
+        self.all_params.extend( variables )
+        # self.all_params.extend( [beta, gamma] )
+
+
+class BatchNormLayer5(Layer):
+    """
+    The :class:`BatchNormLayer` class is a normalization layer, see ``tf.nn.batch_normalization`` and ``tf.nn.moments``.
+
+    Batch normalization on fully-connected or convolutional maps.
+
+    Parameters
+    -----------
+    layer : a :class:`Layer` instance
+        The `Layer` class feeding into this layer.
+    decay : float
+        A decay factor for ExponentialMovingAverage.
+    epsilon : float
+        A small float number to avoid dividing by 0.
+    act : activation function
+    is_train : boolean
+        Whether the layer is used for training or inference.
+    beta_init : beta initializer
+        The initializer for initializing beta.
+    gamma_init : gamma initializer
+        The initializer for initializing gamma.
+    name : a string or None
+        An optional name to attach to this layer.
+
+    References
+    ----------
+    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`_
+    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`_
+    """
+    def __init__(
+        self,
+        layer = None,
+        decay = 0.999,
+        epsilon = 0.00001,
+        act = tf.identity,
+        is_train = False,
+        beta_init = tf.zeros_initializer,
+        # gamma_init = tf.ones_initializer,
+        gamma_init = tf.random_normal_initializer(mean=1.0, stddev=0.002),
+        name ='batchnorm_layer',
+    ):
+        Layer.__init__(self, name=name)
+        self.inputs = layer.outputs
+        print("  tensorlayer:Instantiate BatchNormLayer %s: decay: %f, epsilon: %f, act: %s, is_train: %s" %
+              (self.name, decay, epsilon, act.__name__, is_train))
+        x_shape = self.inputs.get_shape()
+        params_shape = x_shape[-1:]
+
+        from tensorflow.python.training import moving_averages
+        from tensorflow.python.ops import control_flow_ops
+
+        with tf.variable_scope(name) as vs:
+            axis = list(range(len(x_shape) - 1))
+
+            ## 1. beta, gamma
+            beta = tf.get_variable('beta', shape=params_shape,
+                                   initializer=beta_init,
+                                   trainable=is_train)#, restore=restore)
+
+            gamma = tf.get_variable('gamma', shape=params_shape,
+                                    initializer=gamma_init, trainable=is_train,
+                                    )#restore=restore)
+
+            ## 2. moving variables during training (not updated by gradient!)
+            moving_mean = tf.get_variable('moving_mean',
+                                          params_shape,
+                                          initializer=tf.zeros_initializer,
+                                          trainable=False,)# restore=restore)
+            moving_variance = tf.get_variable('moving_variance',
+                                              params_shape,
+                                              initializer=tf.constant_initializer(1.),
+                                              trainable=False,)# restore=restore)
+
+            ## 3.
+            # These ops will only be performed when training.
+            def mean_var_with_update():
+                batch_mean, batch_var = tf.nn.moments(self.inputs, axis)
+                try:    # TF12
+                    update_moving_mean = moving_averages.assign_moving_average(
+                                    moving_mean, batch_mean, decay, zero_debias=False)     # if zero_debias=True, has bias
+                    update_moving_variance = moving_averages.assign_moving_average(
+                                    moving_variance, batch_var, decay, zero_debias=False)  # if zero_debias=True, has bias
+                    # print("TF12 moving")
+                except Exception as e:  # TF11
+                    update_moving_mean = moving_averages.assign_moving_average(
+                                    moving_mean, batch_mean, decay)
+                    update_moving_variance = moving_averages.assign_moving_average(
+                                    moving_variance, batch_var, decay)
+                    # print("TF11 moving")
+
+                # def mean_var_with_update():
+                with tf.control_dependencies([update_moving_mean, update_moving_variance]):
+                    # return tf.identity(update_moving_mean), tf.identity(update_moving_variance)
+                    return tf.identity(batch_mean), tf.identity(batch_var)
+
+            # ema = tf.train.ExponentialMovingAverage(decay=decay)  # Akara
+            # def mean_var_with_update():
+            #     ema_apply_op = ema.apply([batch_mean, batch_var])
+            #     with tf.control_dependencies([ema_apply_op]):
+            #         return tf.identity(batch_mean), tf.identity(batch_var)
+
+            ## 4. behaviour for training and testing
+            # if not is_train:      # test  : mean=0, std=1
+            # # if is_train:        # train : mean=0, std=1
+            #     is_train = tf.cast(tf.ones([]), tf.bool)
+            # else:
+            #     is_train = tf.cast(tf.zeros([]), tf.bool)
+            #
+            # # mean, var = control_flow_ops.cond(
+            # mean, var = tf.cond(
+            #     # is_train, lambda: (mean, variance),      # when training, (x - mean(x)) / var(x)
+            #     is_train, mean_var_with_update,
+            #     lambda: (moving_mean, moving_variance))    # when doing inference, (x - 0) / 1
+            #
+            # self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
+            # if not is_train:
+            #     mean, var = mean_var_with_update()
+            #     self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
+            # else:
+            #     # self.outputs = act( tf.nn.batch_normalization(self.inputs, ema.average(mean), ema.average(variance), beta, gamma, epsilon) ) # Akara
+            #     self.outputs = act( tf.nn.batch_normalization(self.inputs, moving_mean, moving_variance, beta, gamma, epsilon) )
+
+            # if not is_train:
+            #     is_train = tf.cast(tf.ones([]), tf.bool)
+            # else:
+            #     is_train = tf.cast(tf.zeros([]), tf.bool)
+            #
+            # mean, var = tf.cond(
+            #     is_train,
+            #     mean_var_with_update,
+            #     lambda: (moving_mean, moving_variance))
+
+            if not is_train:
+                mean, var = mean_var_with_update()#(update_moving_mean, update_moving_variance)
+            else:
+                mean, var = (moving_mean, moving_variance)
+
+            normed = tf.nn.batch_normalization(
+                x=self.inputs,
+                mean=mean,
+                variance=var,
+                offset=beta,
+                scale=gamma,
+                variance_epsilon=epsilon,
+                name="tf_bn"
+            )
+            self.outputs = act( normed )
 
             # variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)  # 8 params in TF12 if zero_debias=True
             variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)  # 2 params: beta, gamma
             # variables = [beta, gamma, moving_mean, moving_variance]
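Note on the update rule this commit relies on: moving_averages.assign_moving_average(var, value, decay) moves the stored statistic toward the batch value as var <- decay * var + (1 - decay) * value. A minimal NumPy sketch of just that arithmetic (the names running_mean and batch_means are illustrative, not from this commit):

    import numpy as np

    decay = 0.999                # the layer's default decay
    running_mean = np.zeros(3)   # mirrors the zero-initialized moving_mean
    # hypothetical per-batch means, as tf.nn.moments would produce them
    batch_means = [np.array([0.5, 1.0, -0.2])] * 1000

    for m in batch_means:
        # the same update assign_moving_average performs:
        # var -= (1 - decay) * (var - value)
        running_mean = decay * running_mean + (1 - decay) * m

    print(running_mean)  # ~63% of the way to [0.5, 1.0, -0.2] after 1000 steps

With decay = 0.999 the running statistics need on the order of thousands of batches to settle, which is why the moving averages, rather than per-batch moments, are used when the moving-average branch is taken.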

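For completeness, a hypothetical usage sketch against the constructor shown in the diff (the surrounding network and layer names are made up; only the BatchNormLayer arguments come from the commit). Because is_train is a plain Python bool rather than a tf.bool tensor, the choice between fresh batch statistics and the stored moving averages is fixed when the graph is built, so training and inference graphs must be constructed separately, typically with variable reuse:

    import tensorflow as tf
    import tensorlayer as tl

    x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    net = tl.layers.InputLayer(x, name='input')
    net = tl.layers.DenseLayer(net, n_units=256, name='dense1')
    # is_train selects the normalization branch here, at graph-definition time
    net = tl.layers.BatchNormLayer(net, decay=0.999, epsilon=1e-5,
                                   act=tf.nn.relu, is_train=True, name='bn1')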