
Commit 4c1d080

committed
implemented custom process management instead of using pool
1 parent deaf5ab commit 4c1d080
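
For context, the commit replaces the multiprocessing.Pool / Barrier setup in engine/base_client/search.py with explicitly managed Process workers that report results through a Queue. The following is a minimal, generic sketch of the two patterns, using a toy work function and made-up names rather than the benchmark's actual classes:

import multiprocessing as mp

def work(chunk):
    # Toy stand-in for per-chunk work; the real code runs search_one per query.
    return [x * x for x in chunk]

def pool_style(chunks):
    # Pool-based fan-out (roughly what the old code did via starmap).
    with mp.get_context("spawn").Pool(processes=len(chunks)) as pool:
        return [r for part in pool.map(work, chunks) for r in part]

def worker(chunk, queue):
    # One explicitly managed worker per chunk; results go back through a Queue.
    queue.put(work(chunk))

def process_style(chunks):
    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(chunk, queue)) for chunk in chunks]
    for p in procs:
        p.start()
    results = [r for _ in procs for r in queue.get()]  # drain before joining
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    chunks = [[1, 2], [3, 4]]
    print(pool_style(chunks), process_style(chunks))

Explicit Process objects trade Pool's convenience for direct control over per-worker setup and result collection, which is what the search.py diff below implements.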

File tree

3 files changed: +51 additions, -25 deletions


datasets/datasets.json

Lines changed: 8 additions & 0 deletions
@@ -330,6 +330,14 @@
       "type": "h5",
       "path": "laion-img-emb-512/laion-img-emb-512-1M-cosine.hdf5",
       "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion400m/laion-img-emb-512-100M-cosine.hdf5"
+    },
+    {
+      "name": "laion-img-emb-512-1M-100ktrain-cosine",
+      "vector_size": 512,
+      "distance": "cosine",
+      "type": "h5",
+      "path": "laion-img-emb-512/laion-img-emb-512-1M-100ktrain-cosine.hdf5",
+      "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion400m/laion-img-emb-512-100M-cosine.hdf5"
     }
 
 ]
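
The new entry follows the existing schema ("type": "h5", 512-dimensional vectors, cosine distance). Purely as an illustration, a file described this way could be inspected with h5py roughly as below; the local path prefix and the "train" key are assumptions, not taken from this repo's reader:

import h5py  # assumed dependency; the benchmark's own dataset reader may differ

# Hypothetical local copy of the file referenced by the new "path" field.
path = "datasets/laion-img-emb-512/laion-img-emb-512-1M-100ktrain-cosine.hdf5"

with h5py.File(path, "r") as f:
    print(list(f.keys()))   # inspect which datasets the file actually contains
    train = f["train"]      # "train" is an assumed key, common in ANN benchmark files
    print(train.shape)      # expected (n_vectors, 512) given "vector_size": 512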

engine/base_client/client.py

Lines changed: 6 additions & 3 deletions
@@ -36,8 +36,9 @@ def save_search_results(
     ):
         now = datetime.now()
         timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
+        pid = os.getpid()  # Get the current process ID
         experiments_file = (
-            f"{self.name}-{dataset_name}-search-{search_id}-{timestamp}.json"
+            f"{self.name}-{dataset_name}-search-{search_id}-{pid}-{timestamp}.json"
         )
         result_path = RESULTS_DIR / experiments_file
         with open(result_path, "w") as out:
@@ -89,7 +90,8 @@ def run_experiment(
         reader = dataset.get_reader(execution_params.get("normalize", False))
 
         if skip_if_exists:
-            glob_pattern = f"{self.name}-{dataset.config.name}-search-*-*.json"
+            pid = os.getpid()  # Get the current process ID
+            glob_pattern = f"{self.name}-{dataset.config.name}-search-{pid}-*-*.json"
             existing_results = list(RESULTS_DIR.glob(glob_pattern))
             if len(existing_results) == len(self.searchers):
                 print(
@@ -124,8 +126,9 @@ def run_experiment(
         for search_id, searcher in enumerate(self.searchers):
 
             if skip_if_exists:
+                pid = os.getpid()  # Get the current process ID
                 glob_pattern = (
-                    f"{self.name}-{dataset.config.name}-search-{search_id}-*.json"
+                    f"{self.name}-{dataset.config.name}-search-{search_id}-{pid}-*.json"
                 )
                 existing_results = list(RESULTS_DIR.glob(glob_pattern))
                 print("Pattern", glob_pattern, "Results:", existing_results)

engine/base_client/search.py

Lines changed: 37 additions & 22 deletions
@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context, Barrier
+from multiprocessing import get_context, Barrier, Process, Queue
 from typing import Iterable, List, Optional, Tuple
 from itertools import islice
 
@@ -75,42 +75,52 @@ def search_all(
 
         search_one = functools.partial(self.__class__._search_one, top=top)
 
+        # Initialize the start time
+        start = time.perf_counter()
+
         if parallel == 1:
-            start = time.perf_counter()
+            # Single-threaded execution
             precisions, latencies = list(
                 zip(*[search_one(query) for query in tqdm.tqdm(queries)])
             )
         else:
-            ctx = get_context(self.get_mp_start_method())
-
-            # Create a Barrier to synchronize processes
-            barrier = Barrier(parallel)
+            # Dynamically calculate chunk size
+            chunk_size = max(1, len(queries) // parallel)
+            query_chunks = list(chunked_iterable(queries, chunk_size))
 
-            def process_initializer():
-                """Initialize each process before starting the search."""
+            # Function to be executed by each worker process
+            def worker_function(chunk, result_queue):
                 self.__class__.init_client(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
                 )
                 self.setup_search()
-                barrier.wait()  # Wait for all processes to be ready
+                results = process_chunk(chunk, search_one)
+                result_queue.put(results)
 
-            # Dynamically calculate chunk size
-            chunk_size = max(1, len(queries) // parallel)
-            query_chunks = list(chunked_iterable(queries, chunk_size))
+            # Create a queue to collect results
+            result_queue = Queue()
 
-            with ctx.Pool(
-                processes=parallel,
-                initializer=process_initializer,
-            ) as pool:
-                start = time.perf_counter()
-                results = pool.starmap(
-                    process_chunk,
-                    [(chunk, search_one) for chunk in query_chunks],
-                )
-                precisions, latencies = zip(*[result for chunk in results for result in chunk])
+            # Create and start worker processes
+            processes = []
+            for chunk in query_chunks:
+                process = Process(target=worker_function, args=(chunk, result_queue))
+                processes.append(process)
+                process.start()
+
+            # Collect results from all worker processes
+            results = []
+            for _ in processes:
+                results.extend(result_queue.get())
+
+            # Wait for all worker processes to finish
+            for process in processes:
+                process.join()
+
+            # Extract precisions and latencies
+            precisions, latencies = zip(*results)
 
         total_time = time.perf_counter() - start
 
@@ -151,3 +161,8 @@ def chunked_iterable(iterable, size):
 def process_chunk(chunk, search_one):
     """Process a chunk of queries using the search_one function."""
     return [search_one(query) for query in chunk]
+
+
+def process_chunk_wrapper(chunk, search_one):
+    """Wrapper to process a chunk of queries."""
+    return process_chunk(chunk, search_one)
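
One detail of the new flow worth noting: the parent drains the Queue (one get() per worker) before calling join(). Per the multiprocessing documentation, a process that has put items on a queue may not exit until those items are flushed to the underlying pipe, so joining before consuming can deadlock; collecting results first avoids that. A stripped-down sketch of the same collect-then-join ordering, with a toy worker standing in for process_chunk/search_one:

from multiprocessing import Process, Queue

def worker(chunk, result_queue):
    # Toy stand-in for process_chunk(chunk, search_one): one (precision, latency) per query.
    result_queue.put([(1.0, 0.001) for _ in chunk])

if __name__ == "__main__":
    chunks = [range(3), range(4)]
    result_queue = Queue()
    processes = [Process(target=worker, args=(c, result_queue)) for c in chunks]
    for p in processes:
        p.start()

    results = []
    for _ in processes:              # one get() per worker: drain first...
        results.extend(result_queue.get())
    for p in processes:              # ...then join, avoiding the queue/join deadlock
        p.join()

    precisions, latencies = zip(*results)
    print(len(precisions), len(latencies))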
