Skip to content

Commit 88b06ba

Browse files
committed
[layer] fix LayerNormLayer with variables
1 parent 0ede665 commit 88b06ba

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

tensorlayer/layers.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3571,23 +3571,26 @@ def __init__(self,
35713571
self.inputs = layer.outputs
35723572
print(" [TL] LayerNormLayer %s: act:%s" %
35733573
(self.name, act.__name__))
3574-
self.outputs = tf.contrib.layers.layer_norm(self.inputs,
3575-
center=center,
3576-
scale=scale,
3577-
activation_fn=act,
3578-
reuse=reuse,
3579-
variables_collections=variables_collections,
3580-
outputs_collections=outputs_collections,
3581-
trainable=trainable,
3582-
begin_norm_axis=begin_norm_axis,
3583-
begin_params_axis=begin_params_axis,
3584-
scope=name
3585-
)
3574+
with tf.variable_scope(name) as vs:
3575+
self.outputs = tf.contrib.layers.layer_norm(self.inputs,
3576+
center=center,
3577+
scale=scale,
3578+
activation_fn=act,
3579+
reuse=reuse,
3580+
variables_collections=variables_collections,
3581+
outputs_collections=outputs_collections,
3582+
trainable=trainable,
3583+
begin_norm_axis=begin_norm_axis,
3584+
begin_params_axis=begin_params_axis,
3585+
scope=None,
3586+
)
3587+
variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
3588+
35863589
self.all_layers = list(layer.all_layers)
35873590
self.all_params = list(layer.all_params)
35883591
self.all_drop = dict(layer.all_drop)
35893592
self.all_layers.extend( [self.outputs] )
3590-
3593+
self.all_params.extend( variables )
35913594

35923595
## Pooling layer
35933596
class PoolLayer(Layer):

0 commit comments

Comments
 (0)