File tree Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -148,14 +148,22 @@ def sample_top(a=[], top_k=10):
148148 top_k : int
149149 Number of candidates to be considered.
150150 """
151- a = np .array (a )
152- idx = np .argsort (a )[::- 1 ]
153- idx = idx [:top_k ]
154- # a = a[idx]
151+ idx = np .argpartition (a , - top_k )[- top_k :]
155152 probs = a [idx ]
153+ # print("new", probs)
156154 probs = probs / np .sum (probs )
157155 choice = np .random .choice (idx , p = probs )
158156 return choice
157+ ## old implementation
158+ # a = np.array(a)
159+ # idx = np.argsort(a)[::-1]
160+ # idx = idx[:top_k]
161+ # # a = a[idx]
162+ # probs = a[idx]
163+ # print("prev", probs)
164+ # # probs = probs / np.sum(probs)
165+ # # choice = np.random.choice(idx, p=probs)
166+ # # return choice
159167
160168
161169## Vector representations of words (Advanced) UNDOCUMENT
You can’t perform that action at this time.
0 commit comments