Skip to content

Commit cfc29a4

Browse files
author
Russell Stewart
committed
Simplified python3/python compatibility code.
1 parent 2afd981 commit cfc29a4

File tree

1 file changed

+5
-26
lines changed

1 file changed

+5
-26
lines changed

eval/python/evaluate.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import numpy as np
3-
import sys
43

54
def main():
65
parser = argparse.ArgumentParser()
@@ -14,22 +13,15 @@ def main():
1413
vectors = {}
1514
for line in f:
1615
vals = line.rstrip().split(' ')
17-
if sys.version_info >= (3, 0):
18-
vectors[vals[0]] = list(map(float, vals[1:]))
19-
else:
20-
vectors[vals[0]] = map(float, vals[1:])
16+
vectors[vals[0]] = [float(x) for x in vals[1:]]
2117

2218
vocab_size = len(words)
2319
vocab = {w: idx for idx, w in enumerate(words)}
2420
ivocab = {idx: w for idx, w in enumerate(words)}
2521

2622
vector_dim = len(vectors[ivocab[0]])
2723
W = np.zeros((vocab_size, vector_dim))
28-
if sys.version_info >= (3, 0):
29-
vectors_to_iterate = vectors.items()
30-
else:
31-
vectors_to_iterate = vectors.iteritems()
32-
for word, v in vectors_to_iterate:
24+
for word, v in vectors.items():
3325
if word == '<unk>':
3426
continue
3527
W[vocab[word], :] = v
@@ -64,11 +56,7 @@ def evaluate_vectors(W, vocab, ivocab):
6456
count_tot = 0 # count all questions
6557
full_count = 0 # count all questions, including those with unknown words
6658

67-
if sys.version_info >= (3, 0):
68-
file_iterator = range(len(filenames))
69-
else:
70-
file_iterator = xrange(len(filenames))
71-
for i in file_iterator:
59+
for i in range(len(filenames)):
7260
with open('%s/%s' % (prefix, filenames[i]), 'r') as f:
7361
full_data = [line.rstrip().split(' ') for line in f]
7462
full_count += len(full_data)
@@ -79,24 +67,15 @@ def evaluate_vectors(W, vocab, ivocab):
7967

8068
predictions = np.zeros((len(indices),))
8169
num_iter = int(np.ceil(len(indices) / float(split_size)))
82-
83-
if sys.version_info >= (3, 0):
84-
number_iterator = range(num_iter)
85-
else:
86-
number_iterator = xrange(num_iter)
87-
for j in number_iterator:
70+
for j in range(num_iter):
8871
subset = np.arange(j*split_size, min((j + 1)*split_size, len(ind1)))
8972

9073
pred_vec = (W[ind2[subset], :] - W[ind1[subset], :]
9174
+ W[ind3[subset], :])
9275
#cosine similarity if input W has been normalized
9376
dist = np.dot(W, pred_vec.T)
9477

95-
if sys.version_info >= (3, 0):
96-
subset_iterator = range(len(subset))
97-
else:
98-
subset_iterator = xrange(len(subset))
99-
for k in subset_iterator:
78+
for k in range(len(subset)):
10079
dist[ind1[subset[k]], k] = -np.Inf
10180
dist[ind2[subset[k]], k] = -np.Inf
10281
dist[ind3[subset[k]], k] = -np.Inf

0 commit comments

Comments
 (0)