Skip to content

Commit d785da7

Browse files
zsdonghaoluomai
authored andcommitted
Remove unnecessary check in ElementwiseLayer (#377)
* update ElementwiseLayer for #376 * elementwise supports activation * Update merge.py * Update merge.py
1 parent 3a87b6c commit d785da7

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tensorlayer/layers/merge.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class ElementwiseLayer(Layer):
8181
combine_fn : a TensorFlow element-wise combine function
8282
e.g. AND is ``tf.minimum`` ; OR is ``tf.maximum`` ; ADD is ``tf.add`` ; MUL is ``tf.multiply`` and so on.
8383
See `TensorFlow Math API <https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html#math>`__ .
84+
act : activation function
85+
The activation function of this layer.
8486
name : str
8587
A unique layer name.
8688
@@ -102,19 +104,21 @@ def __init__(
102104
self,
103105
layers,
104106
combine_fn=tf.minimum,
107+
act=None,
105108
name='elementwise_layer',
106109
):
107110
Layer.__init__(self, name=name)
108111

109112
logging.info("ElementwiseLayer %s: size:%s fn:%s" % (self.name, layers[0].outputs.get_shape(), combine_fn.__name__))
110113

111114
self.outputs = layers[0].outputs
112-
# logging.info(self.outputs._shape, type(self.outputs._shape))
115+
113116
for l in layers[1:]:
114-
if str(self.outputs.get_shape()) != str(l.outputs.get_shape()):
115-
raise Exception("Hint: the input shapes should be the same. %s != %s" % (self.outputs.get_shape(), str(l.outputs.get_shape())))
116117
self.outputs = combine_fn(self.outputs, l.outputs, name=name)
117118

119+
if act:
120+
self.outputs = act(self.outputs)
121+
118122
self.all_layers = list(layers[0].all_layers)
119123
self.all_params = list(layers[0].all_params)
120124
self.all_drop = dict(layers[0].all_drop)

0 commit comments

Comments
 (0)