
Commit 60c2e62

[layers] Release TimeDistributed

1 parent: 4768d55

File tree: 3 files changed, +85 -1 lines changed

  docs/modules/layers.rst
  tensorlayer/layers.py
  tensorlayer/ops.py

docs/modules/layers.rst

Lines changed: 9 additions & 0 deletions
@@ -295,6 +295,8 @@ Layer list
    BatchNormLayer
    LocalResponseNormLayer
 
+   TimeDistributedLayer
+
    RNNLayer
    BiRNNLayer
    advanced_indexing_op
@@ -497,6 +499,13 @@ Local Response Normalization
 .. autoclass:: LocalResponseNormLayer
 
 
+Time distributed layer
+------------------------
+
+.. autoclass:: TimeDistributedLayer
+
+
+
 Fixed Length Recurrent layer
 -------------------------------
 All recurrent layers can implement any type of RNN cell by feeding different cell function (LSTM, GRU etc).

tensorlayer/layers.py

Lines changed: 76 additions & 0 deletions
@@ -10,6 +10,7 @@
 from . import files
 from . import cost
 from . import iterate
+from . import ops
 import numpy as np
 from six.moves import xrange
 import random
@@ -2788,6 +2789,81 @@ def __init__(
         self.all_drop = dict(layer.all_drop)
         self.all_layers.extend( [self.outputs] )
 
+## TimeDistributedLayer
+class TimeDistributedLayer(Layer):
+    """
+    The :class:`TimeDistributedLayer` class applies a given layer to every timestep of the input tensor.
+    For example, with :class:`DenseLayer` as the ``layer_class``, an input of shape [batch_size, length, dim]
+    becomes an output of shape [batch_size, length, new_dim].
+
+    Parameters
+    ----------
+    layer : a :class:`Layer` instance
+        The `Layer` class feeding into this layer, with shape [batch_size, length, dim].
+    layer_class : a :class:`Layer` class to apply at every timestep
+    args : dictionary
+        The arguments for the ``layer_class``.
+    name : a string or None
+        An optional name to attach to this layer.
+
+    Examples
+    --------
+    >>> batch_size = 32
+    >>> timestep = 20
+    >>> input_dim = 100
+    >>> x = tf.placeholder(dtype=tf.float32, shape=[batch_size, timestep, input_dim], name="encode_seqs")
+    >>> net = InputLayer(x, name='input')
+    >>> net = TimeDistributedLayer(net, layer_class=DenseLayer, args={'n_units':50, 'name':'dense'}, name='time_dense')
+    ... [TL] InputLayer input: (32, 20, 100)
+    ... [TL] TimeDistributedLayer time_dense: layer_class:DenseLayer
+    >>> print(net.outputs.get_shape())
+    ... (32, 20, 50)
+    >>> net.print_params(False)
+    ... param 0: (100, 50) time_dense/dense/W:0
+    ... param 1: (50,) time_dense/dense/b:0
+    ... num of params: 5050
+    """
+    def __init__(
+        self,
+        layer = None,
+        layer_class = None,
+        args = {},
+        name = 'time_distributed',
+    ):
+        Layer.__init__(self, name=name)
+        self.inputs = layer.outputs
+        print("  [TL] TimeDistributedLayer %s: layer_class:%s args:%s" %
+              (self.name, layer_class.__name__, args))
+
+        if not args: args = dict()
+        assert isinstance(args, dict), "'args' must be a dict."
+
+        # A list of per-step tensors is stacked back into [batch_size, length, dim].
+        if not isinstance(self.inputs, tf.Tensor):
+            self.inputs = tf.transpose(tf.stack(self.inputs), [1, 0, 2])
+
+        input_shape = self.inputs.get_shape()
+        timestep = int(input_shape[1])
+        x = tf.unstack(self.inputs, axis=1)
+
+        # reuse=True after the first step, so every timestep shares one set of weights.
+        with ops.suppress_stdout():
+            for i in range(0, timestep):
+                with tf.variable_scope(name, reuse=(False if i==0 else True)) as vs:
+                    set_name_reuse((False if i==0 else True))
+                    net = layer_class(InputLayer(x[i], name=args.get('name', name)+str(i)), **args)
+                    x[i] = net.outputs
+                variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+
+        self.outputs = tf.stack(x, axis=1, name=name)
+
+        self.all_layers = list(layer.all_layers)
+        self.all_params = list(layer.all_params)
+        self.all_drop = dict(layer.all_drop)
+        self.all_layers.extend( [self.outputs] )
+        self.all_params.extend( variables )
+
+
 ## Recurrent layer
 class RNNLayer(Layer):
     """

tensorlayer/ops.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 import os
 import sys
 from sys import platform as _platform
-from .layers import set_keep
 
 
 def exit_tf(sess=None):
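
Context for this one-line removal: layers.py now imports ops (see above) so that TimeDistributedLayer can wrap its construction loop in ops.suppress_stdout(), and ops.py must therefore stop importing from layers to avoid a circular import. For reference, a context manager along the following lines implements the stdout suppression that the loop relies on; this is a sketch of the idea, not necessarily TensorLayer's exact code:

import os
import sys
from contextlib import contextmanager

@contextmanager
def suppress_stdout():
    # Temporarily redirect stdout to the null device so that building
    # the sub-layer once per timestep does not print once per step.
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout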
