@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context, Barrier
+from multiprocessing import get_context, Barrier, Process, Queue
 from typing import Iterable, List, Optional, Tuple
 from itertools import islice
@@ -75,42 +75,52 @@ def search_all( |
 
         search_one = functools.partial(self.__class__._search_one, top=top)
 
+        # Initialize the start time
+        start = time.perf_counter()
+
         if parallel == 1:
-            start = time.perf_counter()
+            # Single-threaded execution
             precisions, latencies = list(
                 zip(*[search_one(query) for query in tqdm.tqdm(queries)])
             )
         else:
-            ctx = get_context(self.get_mp_start_method())
-
-            # Create a Barrier to synchronize processes
-            barrier = Barrier(parallel)
+            # Dynamically calculate chunk size
+            chunk_size = max(1, len(queries) // parallel)
+            query_chunks = list(chunked_iterable(queries, chunk_size))
 
-            def process_initializer():
-                """Initialize each process before starting the search."""
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
                 self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
                 )
                 self.setup_search()
-                barrier.wait()  # Wait for all processes to be ready
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
 
-            # Dynamically calculate chunk size
-            chunk_size = max(1, len(queries) // parallel)
-            query_chunks = list(chunked_iterable(queries, chunk_size))
+            # Create a queue to collect results
+            result_queue = Queue()
 
-            with ctx.Pool(
-                processes=parallel,
-                initializer=process_initializer,
-            ) as pool:
-                start = time.perf_counter()
-                results = pool.starmap(
-                    process_chunk,
-                    [(chunk, search_one) for chunk in query_chunks],
-                )
-            precisions, latencies = zip(*[result for chunk in results for result in chunk])
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Extract precisions and latencies
+            precisions, latencies = zip(*results)
 
         total_time = time.perf_counter() - start
 
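Two caveats on the hunk above are worth flagging. First, `worker_function` is a closure defined inside the method, so `Process(target=worker_function, ...)` only runs under the `fork` start method; `spawn` (the default on Windows, and on macOS since Python 3.8) requires a picklable, module-level target, and the patch also drops the previous `get_context(self.get_mp_start_method())` call in favor of the platform default. Second, since `start` is now set before the workers launch, `total_time` includes each worker's `init_client`/`setup_search` cost, which the removed barrier used to keep out of the measurement. Note too that integer division can yield more chunks than `parallel` (10 queries with `parallel=4` give `chunk_size=2`, hence 5 worker processes). A minimal, spawn-safe sketch of the same fan-out/collect pattern with a module-level worker; `_worker` and the sample data are illustrative, not part of the patch:

from multiprocessing import Process, Queue


def _worker(chunk, result_queue):
    # Module-level target: picklable, so it also works under "spawn".
    # Stand-in for init_client/setup_search followed by process_chunk.
    result_queue.put([item.upper() for item in chunk])


if __name__ == "__main__":
    result_queue = Queue()
    chunks = [["a", "b"], ["c"]]
    processes = [
        Process(target=_worker, args=(chunk, result_queue)) for chunk in chunks
    ]
    for p in processes:
        p.start()
    # Drain the queue before join(), as the patch does: a child that wrote a
    # large result cannot exit until its data has been consumed from the queue.
    results = [r for _ in processes for r in result_queue.get()]
    for p in processes:
        p.join()
    print(results)  # e.g. ['A', 'B', 'C']; order across chunks may vary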
@@ -151,3 +161,8 @@ def chunked_iterable(iterable, size): |
 def process_chunk(chunk, search_one):
     """Process a chunk of queries using the search_one function."""
     return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
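The new code path also depends on the existing `chunked_iterable` helper, whose body falls outside this hunk; only its signature appears in the hunk header. For reference, a typical implementation over `itertools.islice` (already imported at the top of the file) might look like the following sketch; this is an assumption, not the committed code:

from itertools import islice


def chunked_iterable(iterable, size):
    """Yield successive tuples of at most `size` items from `iterable`."""
    it = iter(iterable)
    # The walrus operator stops the loop on the first empty chunk (Python 3.8+).
    while chunk := tuple(islice(it, size)):
        yield chunk

For example, `list(chunked_iterable(range(5), 2))` yields `[(0, 1), (2, 3), (4,)]`.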