Skip to content

Commit 2e1e42c

Browse files
committed
2 parents c388aca + bc82c3f commit 2e1e42c

File tree

3 files changed

+31
-34
lines changed

3 files changed

+31
-34
lines changed

example/tutorial_atari_pong.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,17 @@ def prepro(I):
9393

9494
prob = sess.run(
9595
sampling_prob,
96-
feed_dict={t_states: x}
97-
)
96+
feed_dict={t_states: x})
97+
9898
# action. 1: STOP 2: UP 3: DOWN
9999
# action = np.random.choice([1,2,3], p=prob.flatten())
100100
action = tl.rein.choice_action_by_probs(prob.flatten(), [1,2,3])
101101

102102
observation, reward, done, _ = env.step(action)
103103
reward_sum += reward
104-
xs.append(x) # all observations in a episode
105-
ys.append(action - 1) # all fake labels in a episode (action begins from 1, so minus 1)
106-
rs.append(reward) # all rewards in a episode
104+
xs.append(x) # all observations in an episode
105+
ys.append(action - 1) # all fake labels in an episode (action begins from 1, so minus 1)
106+
rs.append(reward) # all rewards in an episode
107107

108108
if done:
109109
episode_number += 1
@@ -125,9 +125,7 @@ def prepro(I):
125125
feed_dict={
126126
t_states: epx,
127127
t_actions: epy,
128-
t_discount_rewards: disR
129-
}
130-
)
128+
t_discount_rewards: disR})
131129

132130
if episode_number % (batch_size * 100) == 0:
133131
tl.files.save_npz(network.all_params, name=model_file_name+'.npz')

example/tutorial_imdb_fasttext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# in addition to unigrams.
4141
N_GRAM = 2
4242

43-
# Size of vocabulary; less frequent works will be treated as "unknown"
43+
# Size of vocabulary; less frequent words will be treated as "unknown"
4444
VOCAB_SIZE = 100000
4545

4646
# Number of buckets used for hashing n-grams
@@ -71,7 +71,7 @@ def __init__(self, vocab_size, embedding_size, n_labels):
7171
tf.int32, shape=[None], name='labels')
7272

7373
# Network structure
74-
network = AverageEmbeddingInputlayer(
74+
network = AverageEmbeddingInputLayer(
7575
self.inputs, self.vocab_size, self.embedding_size)
7676
self.network = DenseLayer(network, self.n_labels)
7777

tensorlayer/layers.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -651,30 +651,34 @@ def __init__(
651651
self.all_drop = {}
652652

653653

654-
class AverageEmbeddingInputlayer(Layer):
655-
"""The :class:`AverageEmbeddingInputlayer` class is for FastText Embedding for sentence classification, see `[1] <http://arxiv.org/abs/1607.01759>`_.
654+
class AverageEmbeddingInputLayer(Layer):
655+
""":class:`AverageEmbeddingInputlayer` averages over embeddings of inputs.
656+
657+
:class:`AverageEmbeddingInputlayer` can be used as the input layer
658+
for models like DAN[1] and FastText[2].
656659
657660
Parameters
658661
------------
659-
inputs : input placeholder or tensor; zeros are paddings
662+
inputs : input placeholder or tensor
660663
vocabulary_size : an integer, the size of vocabulary
661664
embedding_size : an integer, the dimension of embedding vectors
665+
pad_value : an integer, the scalar pad value used in inputs
662666
name : a string, the name of the layer
663667
embeddings_initializer : the initializer of the embedding matrix
664668
embeddings_kwargs : kwargs to get embedding matrix variable
665669
666670
References
667671
------------
668-
- [1] Joulin, A., Grave, E., Bojanowski, P., & Mikolov, T. (2016). `Bag of Tricks for Efficient Text Classification. <http://arxiv.org/abs/1607.01759>`_
669-
- [2] Recht, B., Re, C., Wright, S., & Niu, F. (2011). `Hogwild: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent. <https://arxiv.org/abs/1106.5730>`_ In NPIS 2011 (pp. 693–701).
670-
- [3] `TensorFlow Candidate Sampling <https://www.tensorflow.org/api_guides/python/nn#Candidate_Sampling>`_
672+
- [1] Iyyer, M., Manjunatha, V., Boyd-Graber, J., & Daum’e III, H. (2015). Deep Unordered Composition Rivals Syntactic Methods for Text Classification. In Association for Computational Linguistics.
673+
- [2] Joulin, A., Grave, E., Bojanowski, P., & Mikolov, T. (2016).`Bag of Tricks for Efficient Text Classification. <http://arxiv.org/abs/1607.01759>`_
671674
"""
672675
def __init__(
673676
self, inputs, vocabulary_size, embedding_size,
674-
name='fasttext_layer',
677+
pad_value=0,
678+
name='average_embedding_layer',
675679
embeddings_initializer=tf.random_uniform_initializer(-0.1, 0.1),
676-
embeddings_kwargs={}
677-
):#None):
680+
embeddings_kwargs=None,
681+
):
678682
super().__init__(name=name)
679683

680684
if inputs.get_shape().ndims != 2:
@@ -690,29 +694,24 @@ def __init__(
690694
name='embeddings',
691695
shape=(vocabulary_size, embedding_size),
692696
initializer=embeddings_initializer,
693-
# **(embeddings_kwargs or {}),
694-
**embeddings_kwargs)
697+
**(embeddings_kwargs or {}),
698+
)
695699

696700
word_embeddings = tf.nn.embedding_lookup(
697701
self.embeddings, self.inputs,
698702
name='word_embeddings',
699703
)
700-
701-
# Masks used to ignore padding words
702-
masks = tf.expand_dims(
703-
tf.sign(self.inputs),
704-
axis=-1,
705-
name='masks',
706-
)
707-
sum_word_embeddings = tf.reduce_sum(
708-
word_embeddings * tf.cast(masks, tf.float32),
709-
axis=1,
704+
# Zero out embeddings of pad value
705+
masks = tf.not_equal(self.inputs, pad_value, name='masks')
706+
word_embeddings *= tf.cast(
707+
tf.expand_dims(masks, axis=-1),
708+
tf.float32,
710709
)
710+
sum_word_embeddings = tf.reduce_sum(word_embeddings, axis=1)
711711

712712
# Count number of non-padding words in each sentence
713-
# Used to commute average word embeddings in sentences
714713
sentence_lengths = tf.count_nonzero(
715-
self.inputs,
714+
masks,
716715
axis=1,
717716
keep_dims=True,
718717
dtype=tf.float32,
@@ -721,7 +720,7 @@ def __init__(
721720

722721
sentence_embeddings = tf.divide(
723722
sum_word_embeddings,
724-
sentence_lengths,
723+
sentence_lengths + 1e-8, # Add epsilon to avoid dividing by 0
725724
name='sentence_embeddings'
726725
)
727726

0 commit comments

Comments
 (0)