@@ -1900,7 +1900,7 @@ def __init__(
19001900 ):
19011901 if tf .__version__ < "1.4" :
19021902 raise Exception ("Deformable CNN layer requires tensrflow 1.4 or higher version" )
1903-
1903+
19041904 Layer .__init__ (self , name = name )
19051905 self .inputs = layer .outputs
19061906 self .offset_layer = offset_layer
@@ -3099,6 +3099,7 @@ class BatchNormLayer(Layer):
30993099 The initializer for initializing beta
31003100 gamma_init : gamma initializer
31013101 The initializer for initializing gamma
3102+ dtype : tf.float32 (default) or tf.float16
31023103 name : a string or None
31033104 An optional name to attach to this layer.
31043105
@@ -3116,6 +3117,7 @@ def __init__(
31163117 is_train = False ,
31173118 beta_init = tf .zeros_initializer ,
31183119 gamma_init = tf .random_normal_initializer (mean = 1.0 , stddev = 0.002 ), # tf.ones_initializer,
3120+ dtype = tf .float32 ,
31193121 name = 'batchnorm_layer' ,
31203122 ):
31213123 Layer .__init__ (self , name = name )
@@ -3136,10 +3138,13 @@ def __init__(
31363138 beta_init = beta_init ()
31373139 beta = tf .get_variable ('beta' , shape = params_shape ,
31383140 initializer = beta_init ,
3141+ dtype = dtype ,
31393142 trainable = is_train )#, restore=restore)
31403143
31413144 gamma = tf .get_variable ('gamma' , shape = params_shape ,
3142- initializer = gamma_init , trainable = is_train ,
3145+ initializer = gamma_init ,
3146+ dtype = dtype ,
3147+ trainable = is_train ,
31433148 )#restore=restore)
31443149
31453150 ## 2.
@@ -3150,10 +3155,12 @@ def __init__(
31503155 moving_mean = tf .get_variable ('moving_mean' ,
31513156 params_shape ,
31523157 initializer = moving_mean_init ,
3153- trainable = False ,)# restore=restore)
3158+ dtype = dtype ,
3159+ trainable = False )# restore=restore)
31543160 moving_variance = tf .get_variable ('moving_variance' ,
31553161 params_shape ,
31563162 initializer = tf .constant_initializer (1. ),
3163+ dtype = dtype ,
31573164 trainable = False ,)# restore=restore)
31583165
31593166 ## 3.
0 commit comments