Skip to content

Commit 0c394c1

Browse files
committed
[LAYER] lambda layer
1 parent e98b30b commit 0c394c1

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

docs/modules/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ Layer list
279279
FlattenLayer
280280
ConcatLayer
281281
ReshapeLayer
282+
LambdaLayer
282283
ElementwiseLayer
283284
SlimNetsLayer
284285
PReluLayer
@@ -433,6 +434,11 @@ Reshape layer
433434

434435
.. autoclass:: ReshapeLayer
435436

437+
Lambda layer
438+
^^^^^^^^^^^^^^^
439+
440+
.. autoclass:: LambdaLayer
441+
436442
Logic layer
437443
-------------
438444
.. autoclass:: ElementwiseLayer

tensorlayer/layers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,49 @@ def __init__(
25422542
self.all_drop = dict(layer.all_drop)
25432543
self.all_layers.extend( [self.outputs] )
25442544

2545+
2546+
2547+
class LambdaLayer(Layer):
2548+
"""
2549+
The :class:`LambdaLayer` class is a layer which is able to use the provided function.
2550+
2551+
Parameters
2552+
----------
2553+
layer : a :class:`Layer` instance
2554+
The `Layer` class feeding into this layer.
2555+
fn : a function
2556+
The function that applies to the outputs of previous layer.
2557+
name : a string or None
2558+
An optional name to attach to this layer.
2559+
2560+
Examples
2561+
---------
2562+
>>> x = tf.placeholder(tf.float32, shape=[None, 1], name='x')
2563+
>>> network = tl.layers.InputLayer(x, name='input_layer')
2564+
>>> network = LambdaLayer(network, lambda x: 2*x, name='lambda_layer')
2565+
>>> y = network.outputs
2566+
>>> sess = tf.InteractiveSession()
2567+
>>> out = sess.run(y, feed_dict={x : [[1],[2]]})
2568+
... [[2],[4]]
2569+
"""
2570+
def __init__(
2571+
self,
2572+
layer = None,
2573+
fn = None,
2574+
name = 'lambda_layer',
2575+
):
2576+
Layer.__init__(self, name=name)
2577+
self.inputs = layer.outputs
2578+
2579+
print(" tensorlayer:Instantiate LambdaLayer %s" % self.name)
2580+
with tf.variable_scope(name) as vs:
2581+
self.outputs = fn(self.inputs)
2582+
2583+
self.all_layers = list(layer.all_layers)
2584+
self.all_params = list(layer.all_params)
2585+
self.all_drop = dict(layer.all_drop)
2586+
self.all_layers.extend( [self.outputs] )
2587+
25452588
## Logic layer
25462589
class ElementwiseLayer(Layer):
25472590
"""

0 commit comments

Comments
 (0)