Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit e27aa15

Browse files
committed
Add usage of batch norm in conv test and fix usage of is_training collection
1 parent 5db4eb1 commit e27aa15

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

skflow/ops/batch_norm_ops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from __future__ import division, print_function, absolute_import
1717

1818
import tensorflow as tf
19-
from tensorflow.python import control_flow_ops
2019

2120

2221
def batch_normalize(tensor_in, epsilon=1e-5, convnet=True, decay=0.9,
@@ -50,10 +49,9 @@ def update_mean_var():
5049
"""Internal function that updates mean and variance during training"""
5150
with tf.control_dependencies([ema_assign_op]):
5251
return tf.identity(assign_mean), tf.identity(assign_var)
53-
IS_TRAINING = tf.get_collection("IS_TRAINING")[-1]
54-
mean, variance = control_flow_ops.cond(IS_TRAINING,
55-
update_mean_var,
56-
lambda: (ema_mean, ema_var))
52+
is_training = tf.squeeze(tf.get_collection("IS_TRAINING"))
53+
mean, variance = tf.python.control_flow_ops.cond(
54+
is_training, update_mean_var, lambda: (ema_mean, ema_var))
5755
return tf.nn.batch_norm_with_global_normalization(
5856
tensor_in, mean, variance, beta, gamma, epsilon,
5957
scale_after_normalization=scale_after_normalization)

skflow/ops/tests/test_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ def test_conv2d(self):
6868
filter_shape = (5, 5)
6969
vals = np.random.randn(batch_size, input_shape[0], input_shape[1], 1)
7070
with self.test_session() as sess:
71+
tf.add_to_collection("IS_TRAINING", True)
7172
tensor_in = tf.placeholder(tf.float32, [batch_size, input_shape[0],
7273
input_shape[1], 1])
73-
res = ops.conv2d(tensor_in, n_filters, filter_shape)
74+
res = ops.conv2d(
75+
tensor_in, n_filters, filter_shape, batch_norm=True)
7476
sess.run(tf.initialize_all_variables())
7577
conv = sess.run(res, feed_dict={tensor_in.name: vals})
7678
self.assertEqual(conv.shape, (batch_size, input_shape[0],

0 commit comments

Comments
 (0)