|
19 | 19 | from tensorflow.python import control_flow_ops |
20 | 20 |
|
21 | 21 |
|
def batch_normalize(tensor_in, epsilon=1e-5, convnet=True, decay=0.9,
                    scale_after_normalization=True):
    """Batch Normalization

    Normalizes `tensor_in` with the moments of the current batch when the
    "IS_TRAINING" flag is true, and with exponential moving averages of
    those moments otherwise. The moving averages are refreshed every time
    the training branch is evaluated.

    Args:
        tensor_in: input Tensor, 4D shape:
                   [batch, in_height, in_width, in_depth].
        epsilon: A float number to avoid being divided by 0.
        convnet: Whether this is for convolutional net use. If this is True,
                 moments will sum across axis [0, 1, 2]. Otherwise, only [0].
        decay: decay rate for exponential moving average.
        scale_after_normalization: Whether to scale after normalization.

    Returns:
        The Tensor returned by
        `tf.nn.batch_norm_with_global_normalization` for `tensor_in`.

    Raises:
        IndexError: if the "IS_TRAINING" graph collection is empty, i.e.
            no training flag was registered before calling this function.
    """
    # Look the flag up before creating any variables so that a missing flag
    # fails early. The exception type matches what the bare `[-1]` lookup
    # raised, but the message now says what is actually wrong.
    training_flags = tf.get_collection("IS_TRAINING")
    if not training_flags:
        raise IndexError(
            'batch_normalize requires a boolean tensor in the "IS_TRAINING" '
            'graph collection, but the collection is empty. Register one '
            'with tf.add_to_collection("IS_TRAINING", flag) first.')
    # The most recently added flag wins.
    is_training = training_flags[-1]

    shape = tensor_in.get_shape().as_list()
    # NOTE(review): with convnet=False the moments are reduced over axis 0
    # only; confirm the normalization op below accepts the resulting
    # mean/variance shapes for the inputs this path is used with.
    moment_axes = [0, 1, 2] if convnet else [0]

    with tf.variable_scope("batch_norm"):
        # One scale (gamma) and one offset (beta) per entry of the last
        # dimension of the input.
        gamma = tf.get_variable("gamma", [shape[-1]],
                                initializer=tf.random_normal_initializer(1., 0.02))
        beta = tf.get_variable("beta", [shape[-1]],
                               initializer=tf.constant_initializer(0.))
        ema = tf.train.ExponentialMovingAverage(decay=decay)
        assign_mean, assign_var = tf.nn.moments(tensor_in, moment_axes)
        ema_assign_op = ema.apply([assign_mean, assign_var])
        ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)

        def update_mean_var():
            """Internal function that updates mean and variance during training"""
            # The identity ops are created under the control dependency, so
            # reading the batch moments forces the moving-average update.
            with tf.control_dependencies([ema_assign_op]):
                return tf.identity(assign_mean), tf.identity(assign_var)

        # Training: batch moments (plus moving-average update).
        # Otherwise: the accumulated moving averages.
        mean, variance = control_flow_ops.cond(is_training,
                                               update_mean_var,
                                               lambda: (ema_mean, ema_var))
        return tf.nn.batch_norm_with_global_normalization(
            tensor_in, mean, variance, beta, gamma, epsilon,
            scale_after_normalization=scale_after_normalization)
0 commit comments