File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -81,6 +81,8 @@ class ElementwiseLayer(Layer):
8181 combine_fn : a TensorFlow element-wise combine function
8282 e.g. AND is ``tf.minimum`` ; OR is ``tf.maximum`` ; ADD is ``tf.add`` ; MUL is ``tf.multiply`` and so on.
8383 See `TensorFlow Math API <https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html#math>`__ .
84+ act : activation function
85+ The activation function of this layer.
8486 name : str
8587 A unique layer name.
8688
@@ -102,19 +104,21 @@ def __init__(
102104 self ,
103105 layers ,
104106 combine_fn = tf .minimum ,
107+ act = None ,
105108 name = 'elementwise_layer' ,
106109 ):
107110 Layer .__init__ (self , name = name )
108111
109112 logging .info ("ElementwiseLayer %s: size:%s fn:%s" % (self .name , layers [0 ].outputs .get_shape (), combine_fn .__name__ ))
110113
111114 self .outputs = layers [0 ].outputs
112- # logging.info(self.outputs._shape, type(self.outputs._shape))
115+
113116 for l in layers [1 :]:
114- if str (self .outputs .get_shape ()) != str (l .outputs .get_shape ()):
115- raise Exception ("Hint: the input shapes should be the same. %s != %s" % (self .outputs .get_shape (), str (l .outputs .get_shape ())))
116117 self .outputs = combine_fn (self .outputs , l .outputs , name = name )
117118
119+ if act :
120+ self .outputs = act (self .outputs )
121+
118122 self .all_layers = list (layers [0 ].all_layers )
119123 self .all_params = list (layers [0 ].all_params )
120124 self .all_drop = dict (layers [0 ].all_drop )
You can’t perform that action at this time.
0 commit comments