@@ -1724,7 +1724,7 @@ def __init__(
17241724 decay = 0.999 ,
17251725 epsilon = 0.00001 ,
17261726 act = tf .identity ,
1727- is_train = None ,
1727+ is_train = False ,
17281728 beta_init = tf .zeros_initializer ,
17291729 # gamma_init = tf.ones_initializer,
17301730 gamma_init = tf .random_normal_initializer (mean = 1.0 , stddev = 0.002 ),
@@ -1838,20 +1838,25 @@ def mean_var_with_update():
18381838 with tf .control_dependencies ([update_moving_mean , update_moving_variance ]):
18391839 return tf .identity (mean ), tf .identity (variance )
18401840
1841- if not is_train : # test : mean=0, std=1
1842- # if is_train: # train : mean=0, std=1
1843- is_train = tf .cast (tf .ones ([]), tf .bool )
1841+ # if not is_train: # test : mean=0, std=1
1842+ # # if is_train: # train : mean=0, std=1
1843+ # is_train = tf.cast(tf.ones([]), tf.bool)
1844+ # else:
1845+ # is_train = tf.cast(tf.zeros([]), tf.bool)
1846+ #
1847+ # # mean, var = control_flow_ops.cond(
1848+ # mean, var = tf.cond(
1849+ # # is_train, lambda: (mean, variance), # when training, (x-mean(x))/var(x)
1850+ # is_train, mean_var_with_update,
1851+ # lambda: (moving_mean, moving_variance)) # when inferencing, (x-0)/1
1852+ #
1853+ # self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, var, beta, gamma, epsilon) )
1854+ if not is_train :
1855+ mean , var = mean_var_with_update ()
1856+ self .outputs = act ( tf .nn .batch_normalization (self .inputs , mean , var , beta , gamma , epsilon ) )
18441857 else :
1845- is_train = tf .cast (tf .zeros ([]), tf .bool )
1846-
1847- # mean, var = control_flow_ops.cond(
1848- mean , var = tf .cond (
1849- # is_train, lambda: (mean, variance), # when training, (x-mean(x))/var(x)
1850- is_train , mean_var_with_update ,
1851- lambda : (moving_mean , moving_variance )) # when inferencing, (x-0)/1
1852-
1853- self .outputs = act ( tf .nn .batch_normalization (self .inputs , mean , var , beta , gamma , epsilon ) )
1854- #x.set_shape(inputs.get_shape()) ??
1858+ self .outputs = act ( tf .nn .batch_normalization (self .inputs , moving_mean , moving_variance , beta , gamma , epsilon ) )
1859+
18551860 # variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) # 8 params in TF12 if zero_debias=True
18561861 variables = tf .get_collection (tf .GraphKeys .TRAINABLE_VARIABLES , scope = vs .name ) # 2 params beta, gamma
18571862 # variables = [beta, gamma, moving_mean, moving_variance]
0 commit comments