@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context
+from multiprocessing import get_context, Barrier, Process, Queue
 from typing import Iterable, List, Optional, Tuple
 import itertools
 
@@ -65,6 +65,10 @@ def search_all(
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
+
+        # Materialize queries as a list so len(queries) is available
+        queries = list(queries)
+
         # setup_search may require initialized client
         self.init_client(
             self.host, distance, self.connection_params, self.search_params
@@ -80,31 +84,56 @@ def search_all(
             print(f"Limiting queries to [0:{MAX_QUERIES - 1}]")
 
         if parallel == 1:
+            # Sequential execution in a single process
             start = time.perf_counter()
-            precisions, latencies = list(
-                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
-            )
+
+            results = [search_one(query) for query in tqdm.tqdm(queries)]
+            total_time = time.perf_counter() - start
+
         else:
-            ctx = get_context(self.get_mp_start_method())
+            # Dynamically calculate the chunk size from the worker count
+            chunk_size = max(1, len(queries) // parallel)
+            query_chunks = list(chunked_iterable(queries, chunk_size))
 
-            with ctx.Pool(
-                processes=parallel,
-                initializer=self.__class__.init_client,
-                initargs=(
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
+                self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
-                ),
-            ) as pool:
-                if parallel > 10:
-                    time.sleep(15)  # Wait for all processes to start
-                start = time.perf_counter()
-                precisions, latencies = list(
-                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
+                self.setup_search()
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
+
+            # Create a queue to collect results
+            result_queue = Queue()
+
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Start measuring time for the critical work
+            start = time.perf_counter()
 
-        total_time = time.perf_counter() - start
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Stop measuring time for the critical work
+            total_time = time.perf_counter() - start
+
+        # Extract precisions and latencies (outside the timed section)
+        precisions, latencies = zip(*results)
 
         self.__class__.delete_client()
 
@@ -132,3 +161,20 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
+
+
+def chunked_iterable(iterable, size):
+    """Yield successive chunks of a given size from an iterable."""
+    it = iter(iterable)
+    while chunk := list(itertools.islice(it, size)):
+        yield chunk
+
+
+def process_chunk(chunk, search_one):
+    """Process a chunk of queries using the search_one function."""
+    return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
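
For reference, a minimal standalone sketch of how the new chunking behaves (the sample data and `parallel` value are hypothetical; `chunked_iterable` and the chunk-size formula are copied from the diff above). Note that the integer division can leave a remainder, so the number of chunks, and therefore of worker processes, can exceed `parallel` by one:

    import itertools

    def chunked_iterable(iterable, size):
        """Yield successive chunks of a given size from an iterable."""
        it = iter(iterable)
        while chunk := list(itertools.islice(it, size)):
            yield chunk

    queries = list(range(10))                      # placeholder queries
    parallel = 4                                   # requested worker count
    chunk_size = max(1, len(queries) // parallel)  # 10 // 4 = 2
    print(list(chunked_iterable(queries, chunk_size)))
    # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]  -> 5 chunks for 4 requested workers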
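
Likewise, a self-contained sketch of the Process/Queue fan-out pattern the diff adopts; the worker body here is a stub standing in for the per-process client initialization and `process_chunk` call. Two details are worth noting: the result queue is drained before `join()`, since joining a process that still holds undelivered queue items can deadlock, and a module-level worker stays picklable under the "spawn" start method, which the nested `worker_function` in the diff would not be:

    import time
    from multiprocessing import Process, Queue

    def worker_function(chunk, result_queue):
        # Stub for: init_client(...) + setup_search() + process_chunk(chunk, search_one)
        results = [(1.0, 0.001) for _ in chunk]  # fake (precision, latency) pairs
        result_queue.put(results)

    if __name__ == "__main__":
        query_chunks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
        result_queue = Queue()
        processes = []
        for chunk in query_chunks:
            process = Process(target=worker_function, args=(chunk, result_queue))
            processes.append(process)
            process.start()

        start = time.perf_counter()
        results = []
        for _ in processes:        # drain the queue first ...
            results.extend(result_queue.get())
        for process in processes:
            process.join()         # ... then join the finished workers
        total_time = time.perf_counter() - start

        precisions, latencies = zip(*results)
        print(f"{total_time:.4f}s", precisions[:3], latencies[:3])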