Skip to content

Commit 3dec543

Browse files
committed
Fixed Batch Norm Layer
1 parent 6bc750a commit 3dec543

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorlayer/layers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,8 +1924,9 @@ def mean_var_with_update():
19241924
self.outputs = act( tf.nn.batch_normalization(self.inputs, mean, variance, beta, gamma, epsilon) )
19251925

19261926
# variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) # 8 params in TF12 if zero_debias=True
1927-
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # 2 params beta, gamma
1927+
# variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # 2 params beta, gamma
19281928
# variables = [beta, gamma, moving_mean, moving_variance]
1929+
variables = [beta, gamma]
19291930

19301931
# print(len(variables))
19311932
# for idx, v in enumerate(variables):
@@ -2091,8 +2092,9 @@ def mean_var_with_update():
20912092
)
20922093
self.outputs = act( normed )
20932094
# variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) # 8 params in TF12 if zero_debias=True
2094-
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # 2 params beta, gamma
2095+
# variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # 2 params beta, gamma
20952096
# variables = [beta, gamma, moving_mean, moving_variance]
2097+
variables = [beta, gamma]
20962098

20972099
# print(len(variables))
20982100
# for idx, v in enumerate(variables):

0 commit comments

Comments
 (0)