Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 4b5929e

Browse files
committed
+ batch_norm_ops more params and available for non-convnet
1 parent 5ce9294 commit 4b5929e

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

skflow/ops/batch_norm_ops.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,41 @@
1919
from tensorflow.python import control_flow_ops
2020

2121

22-
def batch_normalize(X, epsilon=1e-5, scale_after_normalization=True):
22+
def batch_normalize(tensor_in, epsilon=1e-5, convnet=True, decay=0.9,
23+
scale_after_normalization=True):
2324
"""Batch Normalization
2425
2526
Args:
26-
X: Input Tensor
27+
tensor_in: input Tensor, 4D shape:
28+
[batch, in_height, in_width, in_depth].
2729
epsilon : A float number to avoid being divided by 0.
30+
decay: decay rate for exponential moving average.
31+
convnet: Whether this is for convolutional net use. If this is True,
32+
moments will sum across axis [0, 1, 2]. Otherwise, only [0].
2833
scale_after_normalization: Whether to scale after normalization.
2934
"""
30-
shape = X.get_shape().as_list()
35+
shape = tensor_in.get_shape().as_list()
3136

3237
with tf.variable_scope("batch_norm"):
3338
gamma = tf.get_variable("gamma", [shape[-1]],
3439
initializer=tf.random_normal_initializer(1., 0.02))
3540
beta = tf.get_variable("beta", [shape[-1]],
3641
initializer=tf.constant_initializer(0.))
37-
ema = tf.train.ExponentialMovingAverage(decay=0.9)
38-
assign_mean, assign_var = tf.nn.moments(X, [0, 1, 2])
42+
ema = tf.train.ExponentialMovingAverage(decay=decay)
43+
if convnet:
44+
assign_mean, assign_var = tf.nn.moments(tensor_in, [0, 1, 2])
45+
else:
46+
assign_mean, assign_var = tf.nn.moments(tensor_in, [0])
3947
ema_assign_op = ema.apply([assign_mean, assign_var])
4048
ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)
4149
def update_mean_var():
50+
"""Internal function that updates mean and variance during training"""
4251
with tf.control_dependencies([ema_assign_op]):
4352
return tf.identity(assign_mean), tf.identity(assign_var)
44-
IS_TRAINING = tf.get_collection("IS_TRAINING") # TODO: this is always empty
53+
IS_TRAINING = tf.get_collection("IS_TRAINING")[-1]
4554
mean, variance = control_flow_ops.cond(IS_TRAINING,
4655
update_mean_var,
4756
lambda: (ema_mean, ema_var))
4857
return tf.nn.batch_norm_with_global_normalization(
49-
X, mean, variance, beta, gamma, epsilon,
58+
tensor_in, mean, variance, beta, gamma, epsilon,
5059
scale_after_normalization=scale_after_normalization)

skflow/ops/conv_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import division, print_function, absolute_import
1717

1818
import tensorflow as tf
19-
from skflow.ops.batch_norm_ops import *
19+
from skflow.ops.batch_norm_ops import batch_normalize
2020

2121

2222
def conv2d(tensor_in, n_filters, filter_shape, strides=None, padding='SAME',

0 commit comments

Comments
 (0)