Skip to content

Commit c99a6a2

Browse files
committed
batch norm support float16
1 parent fc7f62a commit c99a6a2

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

tensorlayer/layers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,7 +1900,7 @@ def __init__(
19001900
):
19011901
if tf.__version__ < "1.4":
19021902
raise Exception("Deformable CNN layer requires tensrflow 1.4 or higher version")
1903-
1903+
19041904
Layer.__init__(self, name=name)
19051905
self.inputs = layer.outputs
19061906
self.offset_layer = offset_layer
@@ -3099,6 +3099,7 @@ class BatchNormLayer(Layer):
30993099
The initializer for initializing beta
31003100
gamma_init : gamma initializer
31013101
The initializer for initializing gamma
3102+
dtype : tf.float32 (default) or tf.float16
31023103
name : a string or None
31033104
An optional name to attach to this layer.
31043105
@@ -3116,6 +3117,7 @@ def __init__(
31163117
is_train = False,
31173118
beta_init = tf.zeros_initializer,
31183119
gamma_init = tf.random_normal_initializer(mean=1.0, stddev=0.002), # tf.ones_initializer,
3120+
dtype = tf.float32,
31193121
name ='batchnorm_layer',
31203122
):
31213123
Layer.__init__(self, name=name)
@@ -3136,10 +3138,13 @@ def __init__(
31363138
beta_init = beta_init()
31373139
beta = tf.get_variable('beta', shape=params_shape,
31383140
initializer=beta_init,
3141+
dtype=dtype,
31393142
trainable=is_train)#, restore=restore)
31403143

31413144
gamma = tf.get_variable('gamma', shape=params_shape,
3142-
initializer=gamma_init, trainable=is_train,
3145+
initializer=gamma_init,
3146+
dtype=dtype,
3147+
trainable=is_train,
31433148
)#restore=restore)
31443149

31453150
## 2.
@@ -3150,10 +3155,12 @@ def __init__(
31503155
moving_mean = tf.get_variable('moving_mean',
31513156
params_shape,
31523157
initializer=moving_mean_init,
3153-
trainable=False,)# restore=restore)
3158+
dtype=dtype,
3159+
trainable=False)# restore=restore)
31543160
moving_variance = tf.get_variable('moving_variance',
31553161
params_shape,
31563162
initializer=tf.constant_initializer(1.),
3163+
dtype=dtype,
31573164
trainable=False,)# restore=restore)
31583165

31593166
## 3.

tensorlayer/prepro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def crop_multi(x, wrg, hrg, is_random=False, row_index=0, col_index=1, channel_i
281281
return np.asarray(results)
282282

283283
# flip
284-
def flip_axis(x, axis, is_random=False):
284+
def flip_axis(x, axis=1, is_random=False):
285285
"""Flip the axis of an image, such as flip left and right, up and down, randomly or non-randomly,
286286
287287
Parameters

0 commit comments

Comments
 (0)