Skip to content

Commit 9e85f24

Browse files
committed
[DEBUG] Update BatchNormLayer for TF12, --> tf.cond
1 parent 1024ee1 commit 9e85f24

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorlayer/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1844,7 +1844,8 @@ def mean_var_with_update():
18441844
else:
18451845
is_train = tf.cast(tf.zeros([]), tf.bool)
18461846

1847-
mean, var = control_flow_ops.cond(
1847+
# mean, var = control_flow_ops.cond(
1848+
mean, var = tf.cond(
18481849
# is_train, lambda: (mean, variance), # when training, (x-mean(x))/var(x)
18491850
is_train, mean_var_with_update,
18501851
lambda: (moving_mean, moving_variance)) # when inferencing, (x-0)/1

0 commit comments

Comments
 (0)