
Commit 427322e

new prepro for sequences/ update Seq2Seq for Chatbot example
1 parent 7c62857 commit 427322e

3 files changed: 127 additions & 14 deletions

docs/modules/prepro.rst

Lines changed: 17 additions & 0 deletions
@@ -59,8 +59,11 @@ Some of the code in this package are borrowed from Keras.
    erosion

    pad_sequences
+   remove_pad_sequences
    process_sequences
    sequences_add_start_id
+   sequences_add_end_id
+   sequences_add_end_id_after_pad
    sequences_get_mask

    distorted_images
@@ -185,6 +188,11 @@ Padding
 ^^^^^^^^^
 .. autofunction:: pad_sequences

+Remove Padding
+^^^^^^^^^^^^^^^^^
+.. autofunction:: remove_pad_sequences
+
+
 Process
 ^^^^^^^^^
 .. autofunction:: process_sequences
@@ -193,6 +201,15 @@ Add Start ID
 ^^^^^^^^^^^^^^^
 .. autofunction:: sequences_add_start_id

+
+Add End ID
+^^^^^^^^^^^^^^^
+.. autofunction:: sequences_add_end_id
+
+
+Add End ID after pad
+^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: sequences_add_end_id_after_pad
+
 Get Mask
 ^^^^^^^^^
 .. autofunction:: sequences_get_mask

tensorlayer/layers.py

Lines changed: 27 additions & 14 deletions
@@ -4864,10 +4864,12 @@ def __init__(
 # Seq2seq
 class Seq2Seq(Layer):
     """
-    The :class:`Seq2Seq` class is a simple :class:`DynamicRNNLayer` based Seq2seq layer,
-    both encoder and decoder are :class:`DynamicRNNLayer`, network details
-    see `Model <https://camo.githubusercontent.com/242210d7d0151cae91107ee63bff364a860db5dd/687474703a2f2f6936342e74696e797069632e636f6d2f333031333674652e706e67>`_
-    and `Sequence to Sequence Learning with Neural Networks <https://arxiv.org/abs/1409.3215>`_ .
+    The :class:`Seq2Seq` class is a simple :class:`DynamicRNNLayer`-based Seq2seq layer that does not rely on `tf.contrib.seq2seq <https://www.tensorflow.org/api_guides/python/contrib.seq2seq>`_.
+    See `Model <https://camo.githubusercontent.com/242210d7d0151cae91107ee63bff364a860db5dd/687474703a2f2f6936342e74696e797069632e636f6d2f333031333674652e706e67>`_
+    and `Sequence to Sequence Learning with Neural Networks <https://arxiv.org/abs/1409.3215>`_.
+
+    - Please check the example `Twitter Chatbot <>`_.
+    - The author recommends reading the source code of :class:`DynamicRNNLayer` and :class:`Seq2Seq`.

     Parameters
     ----------
@@ -4904,9 +4906,23 @@ class Seq2Seq(Layer):
     ------------
     outputs : a tensor
         The output of RNN decoder.
+    initial_state_encode : a tensor or StateTuple
+        Initial state of RNN encoder.
+    initial_state_decode : a tensor or StateTuple
+        Initial state of RNN decoder.
+    final_state_encode : a tensor or StateTuple
+        Final state of RNN encoder.
+    final_state_decode : a tensor or StateTuple
+        Final state of RNN decoder.

-    final_state : a tensor or StateTuple
-        Final state of decoder, see :class:`DynamicRNNLayer` .
+    Notes
+    --------
+    - How to feed data: `Sequence to Sequence Learning with Neural Networks <https://arxiv.org/pdf/1409.3215v3.pdf>`_
+    - input_seqs : ``['how', 'are', 'you', '<PAD_ID>']``
+    - decode_seqs : ``['<START_ID>', 'I', 'am', 'fine', '<PAD_ID>']``
+    - target_seqs : ``['I', 'am', 'fine', '<END_ID>', '<PAD_ID>']``
+    - target_mask : ``[1, 1, 1, 1, 0]``
+    - related functions : tl.prepro <pad_sequences, process_sequences, sequences_add_start_id, sequences_get_mask>

     Examples
     ----------
@@ -4948,14 +4964,7 @@ class Seq2Seq(Layer):
     >>> y = tf.nn.softmax(net_out.outputs)
     >>> net_out.print_params(False)

-    Notes
-    --------
-    - How to feed data: `Sequence to Sequence Learning with Neural Networks <https://arxiv.org/pdf/1409.3215v3.pdf>`_
-    - input_seqs : ``['how', 'are', 'you', '<PAD_ID'>]``
-    - decode_seqs : ``['<START_ID>', 'I', 'am', 'fine', '<PAD_ID'>]``
-    - target_seqs : ``['I', 'am', 'fine', '<END_ID']``
-    - target_mask : ``[1, 1, 1, 1, 0]``
-    - related functions : tl.prepro <pad_sequences, precess_sequences, sequences_add_start_id, sequences_get_mask>
+
     """
     def __init__(
         self,
@@ -5018,6 +5027,10 @@ def __init__(

             rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

+        # Initial state
+        self.initial_state_encode = network_encode.initial_state
+        self.initial_state_decode = network_decode.initial_state
+
         # Final state
         self.final_state_encode = network_encode.final_state
         self.final_state_decode = network_decode.final_state
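
The Notes added above describe how a training batch is laid out (input_seqs, decode_seqs, target_seqs, target_mask). Below is a minimal sketch of that layout built with the tl.prepro helpers touched by this commit; the word ids, the two sentences, and the pad/start/end id values are invented for illustration and are not part of the commit.

# Illustrative sketch only: preparing one Seq2Seq batch with tl.prepro.
# All ids and sentences below are hypothetical example values.
import tensorlayer as tl

pad_id, start_id, end_id = 0, 2, 3

# encoder inputs, e.g. "how are you" / "hello" as (made-up) word ids
input_seqs  = [[10, 11, 12], [13]]
# decoder answers, e.g. "I am fine" / "hi" as (made-up) word ids
answer_seqs = [[20, 21, 22], [23]]

# decode_seqs = <START_ID> + answer,  target_seqs = answer + <END_ID>
decode_seqs = tl.prepro.sequences_add_start_id(answer_seqs, start_id=start_id, remove_last=False)
target_seqs = tl.prepro.sequences_add_end_id(answer_seqs, end_id=end_id)

# pad every sequence in the batch to the same length (pad id 0)
input_seqs  = tl.prepro.pad_sequences(input_seqs, padding='post')
decode_seqs = tl.prepro.pad_sequences(decode_seqs, padding='post')
target_seqs = tl.prepro.pad_sequences(target_seqs, padding='post')

# 1 for real tokens (including <END_ID>), 0 for padding, as in the Notes
target_mask = tl.prepro.sequences_get_mask(target_seqs, pad_val=pad_id)

At inference time, the newly exposed initial_state_encode/initial_state_decode and final_state_encode/final_state_decode attributes are what a step-by-step decoding loop would feed and fetch, which is presumably why this commit adds them for the Chatbot example.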

tensorlayer/prepro.py

Lines changed: 83 additions & 0 deletions
@@ -1386,6 +1386,33 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', padding='post', truncat
             raise ValueError('Padding type "%s" not understood' % padding)
     return x

+def remove_pad_sequences(sequences, pad_id=0):
+    """Remove padding from the end of each sequence.
+
+    Parameters
+    -----------
+    sequences : list of list.
+    pad_id : int.
+
+    Examples
+    ----------
+    >>> sequences = [[2,3,4,0,0], [5,1,2,3,4,0,0,0], [4,5,0,2,4,0,0,0]]
+    >>> print(remove_pad_sequences(sequences, pad_id=0))
+    ... [[2, 3, 4], [5, 1, 2, 3, 4], [4, 5, 0, 2, 4]]
+    """
+    import copy
+    sequences_out = copy.deepcopy(sequences)
+    for i in range(len(sequences)):
+        # scan from the end for the last non-pad token, then drop the trailing pads;
+        # a sequence without trailing padding is returned unchanged
+        for j in range(1, len(sequences[i]) + 1):
+            if sequences[i][-j] != pad_id:
+                sequences_out[i] = sequences_out[i][:len(sequences[i]) - j + 1]
+                break
+    return sequences_out
+
 def process_sequences(sequences, end_id=0, pad_val=0, is_shorten=True, remain_end_id=False):
     """Set all tokens(ids) after END token to the padding value, and then shorten (option) it to the maximum sequence length in this batch.
@@ -1451,6 +1478,62 @@ def sequences_add_start_id(sequences, start_id=0, remove_last=False):
         sequences_out[i] = [start_id] + sequences[i]
     return sequences_out

+def sequences_add_end_id(sequences, end_id=888):
+    """Add a special end token (id) at the end of each sequence.
+
+    Parameters
+    -----------
+    sequences : list of list.
+    end_id : int.
+
+    Examples
+    ---------
+    >>> sequences = [[1,2,3],[4,5,6,7]]
+    >>> print(sequences_add_end_id(sequences, end_id=999))
+    ... [[1, 2, 3, 999], [4, 5, 6, 7, 999]]
+    """
+    sequences_out = [[] for _ in range(len(sequences))]  # note: [[]] * len(sequences) would alias one inner list
+    for i in range(len(sequences)):
+        sequences_out[i] = sequences[i] + [end_id]
+    return sequences_out
+
+def sequences_add_end_id_after_pad(sequences, end_id=888, pad_id=0):
+    """Add a special end token (id) to each padded sequence by overwriting its
+    first padding position; a sequence without padding is left unchanged.
+
+    Parameters
+    -----------
+    sequences : list of list.
+    end_id : int.
+    pad_id : int.
+
+    Examples
+    ---------
+    >>> sequences = [[1,2,0,0], [1,2,3,0], [1,2,3,4]]
+    >>> print(sequences_add_end_id_after_pad(sequences, end_id=99, pad_id=0))
+    ... [[1, 2, 99, 0], [1, 2, 3, 99], [1, 2, 3, 4]]
+    """
+    import copy
+    sequences_out = copy.deepcopy(sequences)
+    for i in range(len(sequences)):
+        # replace the first padding token of this sequence with the end token
+        for j in range(len(sequences[i])):
+            if sequences[i][j] == pad_id:
+                sequences_out[i][j] = end_id
+                break
+    return sequences_out
+
 def sequences_get_mask(sequences, pad_val=0):
     """Return mask for sequences.
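
A quick sanity check of the three new helpers (illustrative only; the ids below are arbitrary and not part of the commit):

# Toy run of the new prepro helpers; all ids are arbitrary example values.
import tensorlayer as tl

seqs = [[1, 2, 3], [4, 5, 6, 7]]

# append an explicit end id to every sequence
print(tl.prepro.sequences_add_end_id(seqs, end_id=99))
# [[1, 2, 3, 99], [4, 5, 6, 7, 99]]

# pad to equal length by hand, then turn the first pad position into the end id;
# a sequence that has no padding is left unchanged
padded = [[1, 2, 3, 0], [4, 5, 6, 7]]
print(tl.prepro.sequences_add_end_id_after_pad(padded, end_id=99, pad_id=0))
# [[1, 2, 3, 99], [4, 5, 6, 7]]

# strip trailing padding again, e.g. before mapping ids back to words
print(tl.prepro.remove_pad_sequences([[2, 3, 4, 0, 0], [5, 1, 2, 3, 4, 0, 0, 0]], pad_id=0))
# [[2, 3, 4], [5, 1, 2, 3, 4]]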

0 commit comments
