Skip to content

Commit 62c57bb

Browse files
committed
1) release merge_network 2) disable BiRNN return_last
1 parent 66f81f2 commit 62c57bb

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

docs/modules/layers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ Layer list
340340
clear_layers_name
341341
initialize_rnn_state
342342
list_remove_repeat
343+
merge_networks
343344

344345

345346
Name Scope and Sharing Parameters
@@ -768,3 +769,7 @@ Initialize RNN state
768769
Remove repeated items in a list
769770
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
770771
.. autofunction:: list_remove_repeat
772+
773+
Merge network attributes
774+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
775+
.. autofunction:: merge_networks

tensorlayer/layers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,36 @@ def list_remove_repeat(l=None):
222222
[l2.append(i) for i in l if not i in l2]
223223
return l2
224224

225+
def merge_networks(layers=None):
    """Merge the parameters, layers and dropout probabilities of several
    networks into the first network in the list, and return it.

    Parameters
    ----------
    layers : list of :class:`Layer` instances
        The networks to merge. All ``all_params``, ``all_layers`` and
        ``all_drop`` attributes are merged (in list order) into the first
        layer in the list.

    Returns
    -------
    :class:`Layer`
        The first layer in ``layers``, with its ``all_params``,
        ``all_layers`` and ``all_drop`` attributes replaced by the
        merged collections.

    Raises
    ------
    ValueError
        If ``layers`` is empty or ``None``.

    Examples
    ---------
    >>> n1 = ...
    >>> n2 = ...
    >>> n = merge_networks([n1, n2])
    """
    # Avoid the mutable-default-argument pitfall: accept None and validate.
    if not layers:
        raise ValueError("layers should be a non-empty list of Layer instances")

    layer = layers[0]

    all_params = []
    all_layers = []
    all_drop = {}
    for l in layers:
        all_params.extend(l.all_params)
        all_layers.extend(l.all_layers)
        all_drop.update(l.all_drop)

    # Rebind fresh containers so the merged network does not share
    # list/dict objects with any temporary used above.
    layer.all_params = list(all_params)
    layer.all_layers = list(all_layers)
    layer.all_drop = dict(all_drop)

    return layer
254+
225255
def initialize_global_variables(sess=None):
226256
"""Excute ``sess.run(tf.global_variables_initializer())`` for TF12+ or
227257
sess.run(tf.initialize_all_variables()) for TF11.
@@ -4242,6 +4272,7 @@ def __init__(
42424272
)
42434273

42444274
if return_last:
4275+
raise Exception("Do not support return_last at the moment.")
42454276
self.outputs = outputs[-1]
42464277
else:
42474278
self.outputs = outputs
@@ -4880,6 +4911,7 @@ def __init__(
48804911
outputs = tf.concat(2, outputs)
48814912
if return_last:
48824913
# [batch_size, 2 * n_hidden]
4914+
raise Exception("Do not support return_last at the moment")
48834915
self.outputs = advanced_indexing_op(outputs, sequence_length)
48844916
else:
48854917
# [batch_size, n_step(max), 2 * n_hidden]

0 commit comments

Comments
 (0)