Skip to content

Commit 603a209

Browse files
committed
[nlp] better implementation of sample_top_k
1 parent 761a387 commit 603a209

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

tensorlayer/nlp.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,22 @@ def sample_top(a=[], top_k=10):
148148
top_k : int
149149
Number of candidates to be considered.
150150
"""
151-
a = np.array(a)
152-
idx = np.argsort(a)[::-1]
153-
idx = idx[:top_k]
154-
# a = a[idx]
151+
idx = np.argpartition(a, -top_k)[-top_k:]
155152
probs = a[idx]
153+
# print("new", probs)
156154
probs = probs / np.sum(probs)
157155
choice = np.random.choice(idx, p=probs)
158156
return choice
157+
## old implementation
158+
# a = np.array(a)
159+
# idx = np.argsort(a)[::-1]
160+
# idx = idx[:top_k]
161+
# # a = a[idx]
162+
# probs = a[idx]
163+
# print("prev", probs)
164+
# # probs = probs / np.sum(probs)
165+
# # choice = np.random.choice(idx, p=probs)
166+
# # return choice
159167

160168

161169
## Vector representations of words (Advanced) UNDOCUMENT

0 commit comments

Comments
 (0)