Skip to content

Commit 55b5f76

Browse files
committed
Fix: Allow reserving threads in Java
1 parent f662d31 commit 55b5f76

File tree

4 files changed

+63
-70
lines changed

4 files changed

+63
-70
lines changed

java/cloud/unum/usearch/Index.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,34 @@ public void reserve(long capacity) {
274274
if (c_ptr == 0) {
275275
throw new IllegalStateException("Index already closed");
276276
}
277-
c_reserve(c_ptr, capacity);
277+
// Pass zeros to use all available contexts on the device
278+
c_reserve(c_ptr, capacity, 0, 0);
279+
}
280+
281+
/**
282+
* Reserves memory and configures thread contexts.
283+
* Use this to explicitly control concurrent add/search capacities.
284+
*
285+
* @param capacity desired total capacity
286+
* @param threadsAdd maximum concurrent add contexts
287+
* @param threadsSearch maximum concurrent search contexts
288+
*/
289+
public void reserve(long capacity, long threadsAdd, long threadsSearch) {
290+
if (c_ptr == 0) {
291+
throw new IllegalStateException("Index already closed");
292+
}
293+
c_reserve(c_ptr, capacity, threadsAdd, threadsSearch);
294+
}
295+
296+
/**
297+
* Reserves memory and sets the same number of contexts
298+
* for both add and search operations.
299+
*
300+
* @param capacity desired total capacity
301+
* @param threads number of contexts for both add and search
302+
*/
303+
public void reserve(long capacity, long threads) {
304+
reserve(capacity, threads, threads);
278305
}
279306

280307
/**
@@ -1021,7 +1048,7 @@ private static native long c_create(
10211048

10221049
private static native long c_capacity(long ptr);
10231050

1024-
private static native void c_reserve(long ptr, long capacity);
1051+
private static native void c_reserve(long ptr, long capacity, long threadsAdd, long threadsSearch);
10251052

10261053
private static native void c_save(long ptr, String path);
10271054

java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,22 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1capacity(JNIEnv*, jclas
143143
return reinterpret_cast<index_dense_t*>(c_ptr)->capacity();
144144
}
145145

146-
JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity) {
147-
if (!reinterpret_cast<index_dense_t*>(c_ptr)->try_reserve(static_cast<std::size_t>(capacity))) {
146+
JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity,
147+
jlong threads_add, jlong threads_search) {
148+
std::size_t t_add = static_cast<std::size_t>(threads_add);
149+
std::size_t t_search = static_cast<std::size_t>(threads_search);
150+
if (t_add == 0 || t_search == 0) {
151+
std::size_t hc = std::thread::hardware_concurrency();
152+
if (hc == 0)
153+
hc = 1; // fallback to 1 if the runtime can't report
154+
if (t_add == 0)
155+
t_add = hc;
156+
if (t_search == 0)
157+
t_search = hc;
158+
}
159+
index_limits_t limits(static_cast<std::size_t>(capacity), t_add);
160+
limits.threads_search = t_search;
161+
if (!reinterpret_cast<index_dense_t*>(c_ptr)->try_reserve(limits)) {
148162
jclass jc = (*env).FindClass("java/lang/Error");
149163
if (jc)
150164
(*env).ThrowNew(jc, "Failed to grow vector index!");

java/cloud/unum/usearch/cloud_unum_usearch_Index.h

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

java/test/IndexTest.java

Lines changed: 16 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,14 @@ public void testGetIntoBufferMethods() {
267267
@Test
268268
public void testConcurrentAdd() throws Exception {
269269
try (Index index = new Index.Config().metric("cos").dimensions(4).build()) {
270-
index.reserve(1000);
270+
final int threadsCount = 10;
271+
index.reserve(1000, threadsCount);
271272

272-
ExecutorService executor = Executors.newFixedThreadPool(10);
273+
ExecutorService executor = Executors.newFixedThreadPool(threadsCount);
273274
@SuppressWarnings("unchecked")
274-
CompletableFuture<Void>[] futures = new CompletableFuture[10];
275+
CompletableFuture<Void>[] futures = new CompletableFuture[threadsCount];
275276

276-
for (int t = 0; t < 10; t++) {
277+
for (int t = 0; t < threadsCount; t++) {
277278
final int threadId = t;
278279
futures[t]
279280
= CompletableFuture.runAsync(
@@ -290,25 +291,26 @@ public void testConcurrentAdd() throws Exception {
290291
CompletableFuture.allOf(futures).get(10, TimeUnit.SECONDS);
291292
executor.shutdown();
292293

293-
assertEquals(500, index.size());
294+
assertEquals(50L * threadsCount, index.size());
294295
}
295296
}
296297

297298
@Test
298299
public void testConcurrentSearch() throws Exception {
299300
try (Index index = new Index.Config().metric("cos").dimensions(4).build()) {
300-
index.reserve(100);
301+
final int threadsCount = 5;
302+
index.reserve(100, threadsCount);
301303

302304
// Add some vectors first
303305
for (int i = 0; i < 100; i++) {
304306
index.add(i, randomVector(4));
305307
}
306308

307-
ExecutorService executor = Executors.newFixedThreadPool(5);
309+
ExecutorService executor = Executors.newFixedThreadPool(threadsCount);
308310
@SuppressWarnings("unchecked")
309-
CompletableFuture<long[]>[] futures = new CompletableFuture[5];
311+
CompletableFuture<long[]>[] futures = new CompletableFuture[threadsCount];
310312

311-
for (int t = 0; t < 5; t++) {
313+
for (int t = 0; t < threadsCount; t++) {
312314
futures[t]
313315
= CompletableFuture.supplyAsync(
314316
() -> {
@@ -328,56 +330,6 @@ public void testConcurrentSearch() throws Exception {
328330
}
329331
}
330332

331-
@Test
332-
public void testMixedConcurrency() throws Exception {
333-
try (Index index = new Index.Config().metric("cos").dimensions(3).build()) {
334-
index.reserve(200);
335-
336-
ExecutorService executor = Executors.newFixedThreadPool(8);
337-
@SuppressWarnings("unchecked")
338-
CompletableFuture<Void>[] addFutures = new CompletableFuture[4];
339-
@SuppressWarnings("unchecked")
340-
CompletableFuture<Void>[] searchFutures = new CompletableFuture[4];
341-
342-
// Add operations
343-
for (int t = 0; t < 4; t++) {
344-
final int threadId = t;
345-
addFutures[t]
346-
= CompletableFuture.runAsync(
347-
() -> {
348-
for (int i = 0; i < 30; i++) {
349-
long key = threadId * 30L + i;
350-
index.add(key, randomVector(3));
351-
}
352-
},
353-
executor);
354-
}
355-
356-
// Wait for some adds to complete, then start searches
357-
Thread.sleep(100);
358-
359-
// Search operations
360-
for (int t = 0; t < 4; t++) {
361-
searchFutures[t]
362-
= CompletableFuture.runAsync(
363-
() -> {
364-
for (int i = 0; i < 10; i++) {
365-
float[] queryVector = randomVector(3);
366-
long[] results = index.search(queryVector, 5);
367-
assertTrue(results.length >= 0);
368-
}
369-
},
370-
executor);
371-
}
372-
373-
CompletableFuture.allOf(addFutures).get(15, TimeUnit.SECONDS);
374-
CompletableFuture.allOf(searchFutures).get(15, TimeUnit.SECONDS);
375-
executor.shutdown();
376-
377-
assertEquals(120, index.size());
378-
}
379-
}
380-
381333
@Test
382334
public void testBatchAdd() {
383335
try (Index index = new Index.Config().metric("cos").dimensions(2).build()) {
@@ -751,30 +703,30 @@ public void testPlatformCapabilities() {
751703
String[] available = Index.hardwareAccelerationAvailable();
752704
assertNotEquals("Available capabilities should not be null", null, available);
753705
assertTrue("Platform should have at least serial capability", available.length > 0);
754-
706+
755707
// Test compile-time capabilities
756708
String[] compiled = Index.hardwareAccelerationCompiled();
757709
assertNotEquals("Compiled capabilities should not be null", null, compiled);
758710
assertTrue("Should have at least serial compiled", compiled.length > 0);
759-
711+
760712
// Should always include serial as baseline in both
761713
boolean hasAvailableSerial = false;
762714
boolean hasCompiledSerial = false;
763-
715+
764716
for (String cap : available) {
765717
if ("serial".equals(cap)) {
766718
hasAvailableSerial = true;
767719
break;
768720
}
769721
}
770-
722+
771723
for (String cap : compiled) {
772724
if ("serial".equals(cap)) {
773725
hasCompiledSerial = true;
774726
break;
775727
}
776728
}
777-
729+
778730
assertTrue("Platform should always support serial capability", hasAvailableSerial);
779731
assertTrue("Serial should always be compiled", hasCompiledSerial);
780732

0 commit comments

Comments
 (0)