Skip to content

Commit 2617cb4

Browse files
committed
Added MAX_QUERIES feature
1 parent 32b3843 commit 2617cb4

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

engine/base_client/search.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
import numpy as np
77
import tqdm
8+
import os
89

910
from dataset_reader.base_reader import Query
1011

1112
DEFAULT_TOP = 10
13+
MAX_QUERIES = int(os.getenv("MAX_QUERIES", -1))
14+
1215

1316

1417
class BaseSearcher:
@@ -68,11 +71,15 @@ def search_all(
6871
self.setup_search()
6972

7073
search_one = functools.partial(self.__class__._search_one, top=top)
74+
used_queries = queries
75+
if MAX_QUERIES > 0:
76+
used_queries = queries[0:MAX_QUERIES-1]
77+
print(f"limitting queries to [0:{MAX_QUERIES}]")
7178

7279
if parallel == 1:
7380
start = time.perf_counter()
7481
precisions, latencies = list(
75-
zip(*[search_one(query) for query in tqdm.tqdm(queries)])
82+
zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
7683
)
7784
else:
7885
ctx = get_context(self.get_mp_start_method())
@@ -91,7 +98,7 @@ def search_all(
9198
time.sleep(15) # Wait for all processes to start
9299
start = time.perf_counter()
93100
precisions, latencies = list(
94-
zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(queries)))
101+
zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
95102
)
96103

97104
total_time = time.perf_counter() - start

0 commit comments

Comments
 (0)