From d1f34af51d4618c9e793e27e710653246f6195f3 Mon Sep 17 00:00:00 2001 From: cpegeric Date: Wed, 3 Sep 2025 10:17:33 +0100 Subject: [PATCH 01/10] Add: New GoLang APIs for more types --- golang/lib.go | 98 +++++++++++-- golang/lib_test.go | 355 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 439 insertions(+), 14 deletions(-) diff --git a/golang/lib.go b/golang/lib.go index 778247ac6..c77aae3e9 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -111,6 +111,25 @@ func (a Quantization) String() string { } } +func (a Quantization) CValue() C.usearch_scalar_kind_t { + switch a { + case F16: + return C.usearch_scalar_f16_k + case F32: + return C.usearch_scalar_f32_k + case F64: + return C.usearch_scalar_f64_k + case I8: + return C.usearch_scalar_i8_k + case B1: + return C.usearch_scalar_b1_k + case BF16: + return C.usearch_scalar_bf16_k + default: + return C.usearch_scalar_unknown_k + } +} + // IndexConfig represents the configuration options for initializing a USearch index. type IndexConfig struct { Quantization Quantization // The scalar kind used for quantization of vector data during indexing. @@ -161,20 +180,7 @@ func NewIndex(conf IndexConfig) (index *Index, err error) { options.metric_kind = conf.Metric.CValue() // Map the quantization method - switch conf.Quantization { - case F16: - options.quantization = C.usearch_scalar_f16_k - case F32: - options.quantization = C.usearch_scalar_f32_k - case F64: - options.quantization = C.usearch_scalar_f64_k - case I8: - options.quantization = C.usearch_scalar_i8_k - case B1: - options.quantization = C.usearch_scalar_b1_k - default: - options.quantization = C.usearch_scalar_unknown_k - } + options.quantization = conf.Quantization.CValue() var errorMessage *C.char ptr := C.usearch_init(&options, (*C.usearch_error_t)(&errorMessage)) @@ -361,6 +367,20 @@ func (index *Index) Add(key Key, vec []float32) error { return nil } +// Add adds a vector with a specified key to the index. +func (index *Index) AddWithPointer(key Key, vec unsafe.Pointer) error { + if index.opaque_handle == nil { + panic("Index is uninitialized") + } + + var errorMessage *C.char + C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), vec, index.config.Quantization.CValue(), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return errors.New(C.GoString(errorMessage)) + } + return nil +} + // Remove removes the vector associated with the given key from the index. func (index *Index) Remove(key Key) error { if index.opaque_handle == nil { @@ -428,6 +448,17 @@ func Distance(vec1 []float32, vec2 []float32, dims uint, metric Metric) (float32 return float32(dist), nil } +// Distance computes the distance between two vectors +func DistanceWithPointer(vec1 unsafe.Pointer, vec2 unsafe.Pointer, dims uint, metric Metric, quantization Quantization) (float32, error) { + + var errorMessage *C.char + dist := C.usearch_distance(vec1, vec2, quantization.CValue(), C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return 0, errors.New(C.GoString(errorMessage)) + } + return float32(dist), nil +} + // Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector. func (index *Index) Search(query []float32, limit uint) (keys []Key, distances []float32, err error) { if index.opaque_handle == nil { @@ -450,6 +481,25 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [ return keys, distances, nil } +// Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector. +func (index *Index) SearchWithPointer(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { + if index.opaque_handle == nil { + panic("Index is uninitialized") + } + + keys = make([]Key, limit) + distances = make([]float32, limit) + var errorMessage *C.char + count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + + keys = keys[:count] + distances = distances[:count] + return keys, distances, nil +} + // ExactSearch is a multithreaded exact nearest neighbors search func ExactSearch(dataset []float32, queries []float32, dataset_size uint, queries_size uint, dataset_stride uint, queries_stride uint, dims uint, metric Metric, @@ -476,6 +526,26 @@ func ExactSearch(dataset []float32, queries []float32, dataset_size uint, querie return keys, distances, nil } +// ExactSearch is a multithreaded exact nearest neighbors search +func ExactSearchWithPointer(dataset unsafe.Pointer, queries unsafe.Pointer, dataset_size uint, queries_size uint, + dataset_stride uint, queries_stride uint, dims uint, metric Metric, quantization Quantization, + count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { + + keys = make([]Key, count) + distances = make([]float32, count) + var errorMessage *C.char + C.usearch_exact_search(dataset, C.size_t(dataset_size), C.size_t(dataset_stride), queries, C.size_t(queries_size), C.size_t(queries_stride), + quantization.CValue(), C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), + (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + + keys = keys[:count] + distances = distances[:count] + return keys, distances, nil +} + // Save saves the index to a specified buffer. func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { if index.opaque_handle == nil { diff --git a/golang/lib_test.go b/golang/lib_test.go index fb30a9016..653eb0a48 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -4,6 +4,7 @@ import ( "math" "runtime" "testing" + "unsafe" ) func TestUSearch(t *testing.T) { @@ -102,6 +103,43 @@ func TestUSearch(t *testing.T) { } }) + t.Run("Test Insertion with pointer", func(t *testing.T) { + dim := uint(128) + conf := DefaultConfig(dim) + ind, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind.Destroy() + + err = ind.Reserve(100) + if err != nil { + t.Fatalf("Failed to reserve capacity: %s", err) + } + + err = ind.ChangeThreadsAdd(10) + if err != nil { + t.Fatalf("Failed to change threads add: %s", err) + } + + vec := make([]float32, dim) + vec[0] = 40.0 + vec[1] = 2.0 + + err = ind.AddWithPointer(42, unsafe.Pointer(&vec[0])) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + + found_len, err := ind.Len() + if err != nil { + t.Fatalf("Failed to retrieve size after insertion: %s", err) + } + if found_len != 1 { + t.Fatalf("Expected size to be 1, got %d", found_len) + } + }) + t.Run("Test Search", func(t *testing.T) { dim := uint(128) conf := DefaultConfig(dim) @@ -143,6 +181,48 @@ func TestUSearch(t *testing.T) { // TODO: Add exact search }) + t.Run("Test Search With Pointer", func(t *testing.T) { + dim := uint(128) + conf := DefaultConfig(dim) + ind, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind.Destroy() + + err = ind.Reserve(100) + if err != nil { + t.Fatalf("Failed to reserve capacity: %s", err) + } + + err = ind.ChangeThreadsSearch(10) + if err != nil { + t.Fatalf("Failed to change threads search: %s", err) + } + + vec := make([]float32, dim) + vec[0] = 40.0 + vec[1] = 2.0 + + ptr := unsafe.Pointer(&vec[0]) + err = ind.AddWithPointer(42, ptr) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + + keys, distances, err := ind.SearchWithPointer(ptr, 10) + if err != nil { + t.Fatalf("Failed to search: %s", err) + } + + const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 + if keys[0] != 42 || math.Abs(float64(distances[0])) > tolerance { + t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + } + + // TODO: Add exact search + }) + t.Run("Test Save and Load", func(t *testing.T) { dim := uint(128) conf := DefaultConfig(dim) @@ -226,4 +306,279 @@ func TestUSearch(t *testing.T) { // TODO: Check file save/load/metadata }) + + t.Run("Test Save and Load With Pointer", func(t *testing.T) { + dim := uint(128) + conf := DefaultConfig(dim) + ind, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind.Destroy() + ind2, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind2.Destroy() + indView, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer indView.Destroy() + + err = ind.Reserve(100) + if err != nil { + t.Fatalf("Failed to reserve capacity: %s", err) + } + + vec := make([]float32, dim) + for i := uint(0); i < dim; i++ { + vec[i] = float32(i) + 0.2 + err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + } + + ind_length, err := ind.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + + // TODO: Add invalid save and loads? + buffer_size := uint(1 * 1024 * 1024) + buf := make([]byte, buffer_size) + err = ind.SaveBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to save the index to a buffer: %s", err) + } + + err = ind2.LoadBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the index from a buffer: %s", err) + } + + ind2_length, err := ind2.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != ind2_length { + t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + } + // TODO: Check some values + + err = indView.ViewBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the view from a buffer: %s", err) + } + + indView_length, err := indView.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != indView_length { + t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + } + + conf, err = MetadataBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the metadata from a buffer: %s", err) + } + if conf != ind.config { + t.Fatalf("Loaded metadata doesn't match the index metadata") + } + + // TODO: Check file save/load/metadata + }) +} + +func TestUSearchDataType(t *testing.T) { + + t.Run("Test Save and Load With F64", func(t *testing.T) { + dim := uint(128) + conf := DefaultConfig(dim) + conf.Quantization = F64 + ind, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind.Destroy() + ind2, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind2.Destroy() + indView, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer indView.Destroy() + + err = ind.Reserve(100) + if err != nil { + t.Fatalf("Failed to reserve capacity: %s", err) + } + + vec := make([]float64, dim) + for i := uint(0); i < dim; i++ { + vec[i] = float64(i) + 0.2 + err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + } + + ind_length, err := ind.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + + // TODO: Add invalid save and loads? + buffer_size := uint(1 * 1024 * 1024) + buf := make([]byte, buffer_size) + err = ind.SaveBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to save the index to a buffer: %s", err) + } + + err = ind2.LoadBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the index from a buffer: %s", err) + } + + ind2_length, err := ind2.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != ind2_length { + t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + } + // TODO: Check some values + + err = indView.ViewBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the view from a buffer: %s", err) + } + + indView_length, err := indView.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != indView_length { + t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + } + + conf, err = MetadataBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the metadata from a buffer: %s", err) + } + if conf != ind.config { + t.Fatalf("Loaded metadata doesn't match the index metadata") + } + + // TODO: Check file save/load/metadata + keys, distances, err := ind.SearchWithPointer(unsafe.Pointer(&vec[0]), dim) + if err != nil { + t.Fatalf("Failed to search: %s", err) + } + + const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 + if keys[0] != 127 || math.Abs(float64(distances[0])) > tolerance { + t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + } + }) + + t.Run("Test Save and Load With F64", func(t *testing.T) { + dim := uint(128) + conf := DefaultConfig(dim) + conf.Quantization = F64 + ind, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind.Destroy() + ind2, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer ind2.Destroy() + indView, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to construct the index: %s", err) + } + defer indView.Destroy() + + err = ind.Reserve(100) + if err != nil { + t.Fatalf("Failed to reserve capacity: %s", err) + } + + vec := make([]float64, dim) + for i := uint(0); i < dim; i++ { + vec[i] = float64(i) + 0.2 + err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + } + + ind_length, err := ind.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + + // TODO: Add invalid save and loads? + buffer_size := uint(1 * 1024 * 1024) + buf := make([]byte, buffer_size) + err = ind.SaveBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to save the index to a buffer: %s", err) + } + + err = ind2.LoadBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the index from a buffer: %s", err) + } + + ind2_length, err := ind2.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != ind2_length { + t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + } + // TODO: Check some values + + err = indView.ViewBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the view from a buffer: %s", err) + } + + indView_length, err := indView.Len() + if err != nil { + t.Fatalf("Failed to retrieve size: %s", err) + } + if ind_length != indView_length { + t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + } + + conf, err = MetadataBuffer(buf, buffer_size) + if err != nil { + t.Fatalf("Failed to load the metadata from a buffer: %s", err) + } + if conf != ind.config { + t.Fatalf("Loaded metadata doesn't match the index metadata") + } + + // TODO: Check file save/load/metadata + keys, distances, err := ind.SearchWithPointer(unsafe.Pointer(&vec[0]), dim) + if err != nil { + t.Fatalf("Failed to search: %s", err) + } + + const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 + if keys[0] != 127 || math.Abs(float64(distances[0])) > tolerance { + t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + } + }) } From f222b5f8fae4fb7682515ba6594c35b2dc7c67d4 Mon Sep 17 00:00:00 2001 From: cpegeric Date: Wed, 3 Sep 2025 10:48:03 +0100 Subject: [PATCH 02/10] Improve: Test new GoLang APIs --- golang/lib_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/golang/lib_test.go b/golang/lib_test.go index 653eb0a48..13131c9c7 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -394,10 +394,9 @@ func TestUSearch(t *testing.T) { func TestUSearchDataType(t *testing.T) { - t.Run("Test Save and Load With F64", func(t *testing.T) { + t.Run("Test Save and Load With F32", func(t *testing.T) { dim := uint(128) conf := DefaultConfig(dim) - conf.Quantization = F64 ind, err := NewIndex(conf) if err != nil { t.Fatalf("Failed to construct the index: %s", err) @@ -419,9 +418,9 @@ func TestUSearchDataType(t *testing.T) { t.Fatalf("Failed to reserve capacity: %s", err) } - vec := make([]float64, dim) + vec := make([]float32, dim) for i := uint(0); i < dim; i++ { - vec[i] = float64(i) + 0.2 + vec[i] = float32(i) + 0.2 err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) if err != nil { t.Fatalf("Failed to insert: %s", err) From 12ee3528e5a27286d898b4f188a1eb0b8e36c3f5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:20:30 +0000 Subject: [PATCH 03/10] Improve: Rename *WithPointer to *Unsafe --- golang/lib.go | 98 ++++++++++++++++++++++++++++++++++++++++++---- golang/lib_test.go | 16 ++++---- 2 files changed, 98 insertions(+), 16 deletions(-) diff --git a/golang/lib.go b/golang/lib.go index c77aae3e9..cb9f0d2c1 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -367,8 +367,12 @@ func (index *Index) Add(key Key, vec []float32) error { return nil } -// Add adds a vector with a specified key to the index. -func (index *Index) AddWithPointer(key Key, vec unsafe.Pointer) error { +// AddUnsafe adds a vector with a specified key using an unsafe pointer. +// +// The memory behind `vec` must contain exactly `Dimensions()` scalars of the +// type matching the index quantization (e.g., F32/F64/I8/B1/F16/BF16). +// Passing a pointer to data of a different type is undefined behavior. +func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { if index.opaque_handle == nil { panic("Index is uninitialized") } @@ -448,8 +452,11 @@ func Distance(vec1 []float32, vec2 []float32, dims uint, metric Metric) (float32 return float32(dist), nil } -// Distance computes the distance between two vectors -func DistanceWithPointer(vec1 unsafe.Pointer, vec2 unsafe.Pointer, dims uint, metric Metric, quantization Quantization) (float32, error) { +// DistanceUnsafe computes the distance between two vectors using unsafe pointers. +// +// The memory behind `vec1` and `vec2` must contain exactly `dims` scalars of the +// type specified by `quantization`. Passing mismatched types is undefined behavior. +func DistanceUnsafe(vec1 unsafe.Pointer, vec2 unsafe.Pointer, dims uint, metric Metric, quantization Quantization) (float32, error) { var errorMessage *C.char dist := C.usearch_distance(vec1, vec2, quantization.CValue(), C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) @@ -481,8 +488,11 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [ return keys, distances, nil } -// Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector. -func (index *Index) SearchWithPointer(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { +// SearchUnsafe performs k-Approximate Nearest Neighbors Search using an unsafe pointer. +// +// The memory behind `query` must contain exactly `Dimensions()` scalars of the +// type matching the index quantization. Passing mismatched types is undefined behavior. +func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { if index.opaque_handle == nil { panic("Index is uninitialized") } @@ -526,8 +536,11 @@ func ExactSearch(dataset []float32, queries []float32, dataset_size uint, querie return keys, distances, nil } -// ExactSearch is a multithreaded exact nearest neighbors search -func ExactSearchWithPointer(dataset unsafe.Pointer, queries unsafe.Pointer, dataset_size uint, queries_size uint, +// ExactSearchUnsafe is a multithreaded exact nearest neighbors search using unsafe pointers. +// +// `dataset` and `queries` must point to contiguous memory with the element type specified by `quantization`. +// Stride and sizes are in bytes and elements, respectively. +func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, dataset_size uint, queries_size uint, dataset_stride uint, queries_stride uint, dims uint, metric Metric, quantization Quantization, count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { @@ -546,6 +559,75 @@ func ExactSearchWithPointer(dataset unsafe.Pointer, queries unsafe.Pointer, data return keys, distances, nil } +// Convenience I8 helpers + +// AddI8 adds a vector provided as int8 slice. The slice length must equal Dimensions(). +func (index *Index) AddI8(key Key, vec []int8) error { + if index.opaque_handle == nil { + panic("Index is uninitialized") + } + if len(vec) == 0 { + return errors.New("empty vector") + } + var errorMessage *C.char + C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_i8_k, (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return errors.New(C.GoString(errorMessage)) + } + return nil +} + +// SearchI8 searches with a query provided as int8 slice. The slice length must equal Dimensions(). +func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []float32, err error) { + if index.opaque_handle == nil { + panic("Index is uninitialized") + } + if len(query) == 0 { + return nil, nil, errors.New("empty query") + } + keys = make([]Key, limit) + distances = make([]float32, limit) + var errorMessage *C.char + count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + keys = keys[:count] + distances = distances[:count] + return keys, distances, nil +} + +// DistanceI8 computes distance between two int8 vectors. +func DistanceI8(vec1 []int8, vec2 []int8, dims uint, metric Metric) (float32, error) { + if len(vec1) == 0 || len(vec2) == 0 { + return 0, errors.New("empty vectors") + } + var errorMessage *C.char + dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_i8_k, C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return 0, errors.New(C.GoString(errorMessage)) + } + return float32(dist), nil +} + +// ExactSearchI8 performs exact search on int8 matrices. +func ExactSearchI8(dataset []int8, queries []int8, dataset_size uint, queries_size uint, + dataset_stride uint, queries_stride uint, dims uint, metric Metric, + count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { + keys = make([]Key, count) + distances = make([]float32, count) + var errorMessage *C.char + C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride), + C.usearch_scalar_i8_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), + (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + keys = keys[:count] + distances = distances[:count] + return keys, distances, nil +} + // Save saves the index to a specified buffer. func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { if index.opaque_handle == nil { diff --git a/golang/lib_test.go b/golang/lib_test.go index 13131c9c7..1a7692693 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -126,7 +126,7 @@ func TestUSearch(t *testing.T) { vec[0] = 40.0 vec[1] = 2.0 - err = ind.AddWithPointer(42, unsafe.Pointer(&vec[0])) + err = ind.AddUnsafe(42, unsafe.Pointer(&vec[0])) if err != nil { t.Fatalf("Failed to insert: %s", err) } @@ -205,12 +205,12 @@ func TestUSearch(t *testing.T) { vec[1] = 2.0 ptr := unsafe.Pointer(&vec[0]) - err = ind.AddWithPointer(42, ptr) + err = ind.AddUnsafe(42, ptr) if err != nil { t.Fatalf("Failed to insert: %s", err) } - keys, distances, err := ind.SearchWithPointer(ptr, 10) + keys, distances, err := ind.SearchUnsafe(ptr, 10) if err != nil { t.Fatalf("Failed to search: %s", err) } @@ -334,7 +334,7 @@ func TestUSearch(t *testing.T) { vec := make([]float32, dim) for i := uint(0); i < dim; i++ { vec[i] = float32(i) + 0.2 - err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) if err != nil { t.Fatalf("Failed to insert: %s", err) } @@ -421,7 +421,7 @@ func TestUSearchDataType(t *testing.T) { vec := make([]float32, dim) for i := uint(0); i < dim; i++ { vec[i] = float32(i) + 0.2 - err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) if err != nil { t.Fatalf("Failed to insert: %s", err) } @@ -476,7 +476,7 @@ func TestUSearchDataType(t *testing.T) { } // TODO: Check file save/load/metadata - keys, distances, err := ind.SearchWithPointer(unsafe.Pointer(&vec[0]), dim) + keys, distances, err := ind.SearchUnsafe(unsafe.Pointer(&vec[0]), dim) if err != nil { t.Fatalf("Failed to search: %s", err) } @@ -515,7 +515,7 @@ func TestUSearchDataType(t *testing.T) { vec := make([]float64, dim) for i := uint(0); i < dim; i++ { vec[i] = float64(i) + 0.2 - err = ind.AddWithPointer(uint64(i), unsafe.Pointer(&vec[0])) + err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) if err != nil { t.Fatalf("Failed to insert: %s", err) } @@ -570,7 +570,7 @@ func TestUSearchDataType(t *testing.T) { } // TODO: Check file save/load/metadata - keys, distances, err := ind.SearchWithPointer(unsafe.Pointer(&vec[0]), dim) + keys, distances, err := ind.SearchUnsafe(unsafe.Pointer(&vec[0]), dim) if err != nil { t.Fatalf("Failed to search: %s", err) } From e70b57f9394b75f9e98636a39dff5fcc7f6ab508 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:35:03 +0000 Subject: [PATCH 04/10] Make: `reinstall cmake` on newer macOS runners --- .github/workflows/prerelease.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index 86e04e8a2..167ce3da4 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -244,7 +244,7 @@ jobs: - name: Build C/C++ run: | brew update - brew install cmake + brew reinstall cmake cmake -B build_artifacts -D CMAKE_BUILD_TYPE=RelWithDebInfo -D USEARCH_BUILD_TEST_CPP=1 -D USEARCH_BUILD_TEST_C=1 -D USEARCH_BUILD_LIB_C=1 -D USEARCH_BUILD_SQLITE=1 cmake --build build_artifacts --config RelWithDebInfo - name: Test C++ From 07c890d0d3aa7892db66481b6bacc4a821ef8c63 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:35:50 +0000 Subject: [PATCH 05/10] Fix: Avoid default `=0` outside of the implementation --- javascript/usearch.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/usearch.ts b/javascript/usearch.ts index 63dba1ca9..109a07bd3 100644 --- a/javascript/usearch.ts +++ b/javascript/usearch.ts @@ -386,8 +386,8 @@ export class Index { * @throws Will throw an error if `k` is not a positive integer or if the size of the vectors is not a multiple of dimensions. * @throws Will throw an error if `vectors` is not a valid input type (TypedArray or an array of TypedArray) or if its flattened size is not a multiple of dimensions. */ - search(vectors: Vector, k: number, threads: number = 0): Matches; - search(vectors: Matrix, k: number, threads: number = 0): BatchMatches; + search(vectors: Vector, k: number, threads: number): Matches; + search(vectors: Matrix, k: number, threads: number): BatchMatches; search(vectors: VectorOrMatrix, k: number, threads: number = 0): Matches | BatchMatches { if ((!Number.isNaN(k) && typeof k !== "number") || k <= 0) { throw new Error( From 584e2015d5616197234a397215a3b5695b82c00f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:45:32 +0000 Subject: [PATCH 06/10] Improve: Pass Go static checks --- CONTRIBUTING.md | 9 + golang/lib.go | 504 +++++++++++++++++++++++++++++++++++++----------- 2 files changed, 400 insertions(+), 113 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index abbd85873..c4cfc937f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -397,6 +397,15 @@ cp c/usearch.h golang/ # to make the header available to Go cd golang && LD_LIBRARY_PATH=. go test -v ; cd .. ``` +For static checks: + +```sh +cd golang +go vet ./... +staticcheck ./... # if installed +golangci-lint run # if installed +``` + ## Java USearch provides Java bindings as a fat-JAR published with prebuilt JNI libraries via GitHub releases. Installation via Maven Central is deprecated; prefer downloading the fat-JAR from the latest GitHub release. The compilation settings are controlled by `build.gradle` and are independent from CMake used for C/C++ builds. diff --git a/golang/lib.go b/golang/lib.go index cb9f0d2c1..e7b1ddcfe 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -1,7 +1,27 @@ +// Package usearch provides Go bindings for the USearch library, a high-performance +// approximate nearest neighbor search implementation. +// +// Basic usage: +// +// conf := usearch.DefaultConfig(128) // 128-dimensional vectors +// index, err := usearch.NewIndex(conf) +// if err != nil { +// log.Fatal(err) +// } +// defer index.Destroy() +// +// // Add vectors +// vec := make([]float32, 128) +// err = index.Add(42, vec) +// +// // Search +// keys, distances, err := index.Search(vec, 10) package usearch import ( "errors" + "fmt" + "runtime" "unsafe" ) @@ -12,22 +32,37 @@ import ( */ import "C" -// Key represents the type for keys used in the USearch index. +// Key represents a unique identifier for vectors in the index. +// Keys must be unique within an index; adding a vector with an existing +// key will update the associated vector. type Key = uint64 -// Metric represents the type for different metrics used in distance calculations. +// Metric defines the distance calculation method used for comparing vectors. +// Different metrics are suitable for different use cases: +// - Cosine: Normalized dot product, ideal for text embeddings +// - L2sq: Squared Euclidean distance, for spatial data +// - InnerProduct: Dot product, for recommendation systems type Metric uint8 // Different metric kinds supported by the USearch library. const ( + // InnerProduct computes the dot product between vectors InnerProduct Metric = iota + // Cosine computes cosine similarity (normalized dot product) Cosine + // L2sq computes squared Euclidean distance L2sq + // Haversine computes great-circle distance for geographic coordinates Haversine + // Divergence computes Jensen-Shannon divergence Divergence + // Pearson computes Pearson correlation coefficient Pearson + // Hamming computes Hamming distance for binary data Hamming + // Tanimoto computes Tanimoto/Jaccard coefficient Tanimoto + // Sorensen computes Sørensen-Dice coefficient Sorensen ) @@ -53,7 +88,7 @@ func (m Metric) String() string { case Sorensen: return "sorensen" default: - panic("Unknown metric") + panic("unknown metric") } } func (m Metric) CValue() C.usearch_metric_kind_t { @@ -78,16 +113,23 @@ func (m Metric) CValue() C.usearch_metric_kind_t { return C.usearch_metric_l2sq_k } -// Quantization represents the type for different scalar kinds used in quantization. +// Quantization represents the scalar type used for storing vectors in the index. +// Different quantization types offer different trade-offs between memory usage and precision. type Quantization uint8 // Different quantization kinds supported by the USearch library. const ( + // F32 uses 32-bit floating point (standard precision) F32 Quantization = iota + // BF16 uses brain floating-point format (16-bit) BF16 + // F16 uses half-precision floating point (16-bit) F16 + // F64 uses 64-bit double precision floating point F64 + // I8 uses 8-bit signed integers (quantized) I8 + // B1 uses binary representation (1-bit per dimension) B1 ) @@ -107,7 +149,7 @@ func (a Quantization) String() string { case B1: return "B1" default: - panic("Unknown quantization") + panic("unknown quantization") } } @@ -131,21 +173,31 @@ func (a Quantization) CValue() C.usearch_scalar_kind_t { } // IndexConfig represents the configuration options for initializing a USearch index. +// +// Zero values for optional parameters (Connectivity, ExpansionAdd, ExpansionSearch) +// will be replaced with optimal defaults by the C library. type IndexConfig struct { Quantization Quantization // The scalar kind used for quantization of vector data during indexing. Metric Metric // The metric kind used for distance calculation between vectors. Dimensions uint // The number of dimensions in the vectors to be indexed. - Connectivity uint // The optional connectivity parameter that limits connections-per-node in the graph. - ExpansionAdd uint // The optional expansion factor used for index construction when adding vectors. - ExpansionSearch uint // The optional expansion factor used for index construction during search operations. + Connectivity uint // The optional connectivity parameter that limits connections-per-node in the graph (0 for default). + ExpansionAdd uint // The optional expansion factor used for index construction when adding vectors (0 for default). + ExpansionSearch uint // The optional expansion factor used for index construction during search operations (0 for default). Multi bool // Indicates whether multiple vectors can map to the same key. } // DefaultConfig returns an IndexConfig with default values for the specified number of dimensions. +// Uses Cosine metric and F32 quantization by default. +// +// Example: +// +// config := usearch.DefaultConfig(128) // 128-dimensional vectors +// index, err := usearch.NewIndex(config) func DefaultConfig(dimensions uint) IndexConfig { c := IndexConfig{} c.Dimensions = dimensions c.Metric = Cosine + c.Quantization = F32 // Zeros will be replaced by the underlying C implementation c.Connectivity = 0 c.ExpansionAdd = 0 @@ -154,14 +206,32 @@ func DefaultConfig(dimensions uint) IndexConfig { return c } -// Index represents a USearch index. +// Index represents a USearch approximate nearest neighbor index. +// +// The index must be properly initialized with NewIndex() and destroyed +// with Destroy() when no longer needed to free resources. type Index struct { opaque_handle *C.void config IndexConfig } -// NewIndex initializes a new instance of the index with the specified configuration. +// NewIndex creates a new approximate nearest neighbor index with the specified configuration. +// +// The index must be destroyed with Destroy() when no longer needed. +// +// Example: +// +// config := usearch.DefaultConfig(128) +// config.Metric = usearch.L2sq +// index, err := usearch.NewIndex(config) +// if err != nil { +// log.Fatal(err) +// } +// defer index.Destroy() func NewIndex(conf IndexConfig) (index *Index, err error) { + if conf.Dimensions == 0 { + return nil, errors.New("dimensions must be greater than 0") + } index = &Index{config: conf} conf = index.config @@ -326,7 +396,7 @@ func (index *Index) HardwareAcceleration() (string, error) { // Destroy frees the resources associated with the index. func (index *Index) Destroy() error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } var errorMessage *C.char @@ -342,7 +412,7 @@ func (index *Index) Destroy() error { // Reserve reserves memory for a specified number of incoming vectors. func (index *Index) Reserve(capacity uint) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } var errorMessage *C.char @@ -353,28 +423,53 @@ func (index *Index) Reserve(capacity uint) error { return nil } -// Add adds a vector with a specified key to the index. +// Add inserts or updates a vector in the index with the specified key. +// The vector must have exactly Dimensions() elements. +// If a vector with this key already exists, it will be replaced. +// +// Returns an error if: +// - The index is not initialized +// - The vector is empty or has wrong dimensions +// - The underlying C library reports an error func (index *Index) Add(key Key, vec []float32) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if len(vec) == 0 { + return errors.New("vector cannot be empty") + } + if uint(len(vec)) != index.config.Dimensions { + return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vec), index.config.Dimensions) } var errorMessage *C.char C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// AddUnsafe adds a vector with a specified key using an unsafe pointer. +// AddUnsafe adds a vector using a raw pointer, bypassing Go's type safety. +// +// SAFETY REQUIREMENTS: +// - vec must not be nil +// - Memory at vec must contain exactly Dimensions() scalars +// - Scalar type must match index.config.Quantization +// - Memory must remain valid for the duration of the call +// - Caller is responsible for ensuring correct data layout // -// The memory behind `vec` must contain exactly `Dimensions()` scalars of the -// type matching the index quantization (e.g., F32/F64/I8/B1/F16/BF16). -// Passing a pointer to data of a different type is undefined behavior. +// Use Add() or AddI8() instead unless you need maximum performance +// and understand the safety implications. func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if vec == nil { + return errors.New("vector pointer cannot be nil") } var errorMessage *C.char @@ -388,7 +483,7 @@ func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { // Remove removes the vector associated with the given key from the index. func (index *Index) Remove(key Key) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } var errorMessage *C.char @@ -402,7 +497,7 @@ func (index *Index) Remove(key Key) error { // Contains checks if the index contains a vector with a specific key. func (index *Index) Contains(key Key) (found bool, err error) { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } var errorMessage *C.char @@ -414,14 +509,20 @@ func (index *Index) Contains(key Key) (found bool, err error) { } // Get retrieves the vectors associated with the given key from the index. -func (index *Index) Get(key Key, count uint) (vectors []float32, err error) { +// Returns nil if the key is not found. +func (index *Index) Get(key Key, maxCount uint) (vectors []float32, err error) { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if maxCount == 0 { + return nil, nil } - vectors = make([]float32, index.config.Dimensions*count) + vectors = make([]float32, index.config.Dimensions*maxCount) var errorMessage *C.char - found := uint(C.usearch_get((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (C.size_t)(count), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) + found := uint(C.usearch_get((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (C.size_t)(maxCount), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(vectors) if errorMessage != nil { return nil, errors.New(C.GoString(errorMessage)) } @@ -441,11 +542,20 @@ func (index *Index) Rename(from Key, to Key) error { return nil } -// Distance computes the distance between two vectors -func Distance(vec1 []float32, vec2 []float32, dims uint, metric Metric) (float32, error) { +// Distance computes the distance between two float32 vectors using the specified metric. +// Both vectors must have exactly 'dims' elements. +func Distance(vec1 []float32, vec2 []float32, vectorDimensions uint, metric Metric) (float32, error) { + if len(vec1) == 0 || len(vec2) == 0 { + return 0, errors.New("vectors cannot be empty") + } + if uint(len(vec1)) < vectorDimensions || uint(len(vec2)) < vectorDimensions { + return 0, fmt.Errorf("vectors too short for specified dimensions: need %d elements", vectorDimensions) + } var errorMessage *C.char - dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec1) + runtime.KeepAlive(vec2) if errorMessage != nil { return 0, errors.New(C.GoString(errorMessage)) } @@ -454,226 +564,389 @@ func Distance(vec1 []float32, vec2 []float32, dims uint, metric Metric) (float32 // DistanceUnsafe computes the distance between two vectors using unsafe pointers. // -// The memory behind `vec1` and `vec2` must contain exactly `dims` scalars of the -// type specified by `quantization`. Passing mismatched types is undefined behavior. -func DistanceUnsafe(vec1 unsafe.Pointer, vec2 unsafe.Pointer, dims uint, metric Metric, quantization Quantization) (float32, error) { +// SAFETY REQUIREMENTS: +// - vec1 and vec2 must not be nil +// - Memory at both pointers must contain exactly 'dims' scalars +// - Scalar type must match the specified quantization +// - Memory must remain valid for the duration of the call +func DistanceUnsafe(vec1 unsafe.Pointer, vec2 unsafe.Pointer, vectorDimensions uint, metric Metric, quantization Quantization) (float32, error) { + if vec1 == nil || vec2 == nil { + return 0, errors.New("vector pointers cannot be nil") + } + if vectorDimensions == 0 { + return 0, errors.New("dimensions must be greater than zero") + } var errorMessage *C.char - dist := C.usearch_distance(vec1, vec2, quantization.CValue(), C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + dist := C.usearch_distance(vec1, vec2, quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return 0, errors.New(C.GoString(errorMessage)) } return float32(dist), nil } -// Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector. +// Search finds the k nearest neighbors to the query vector. +// +// Parameters: +// - query: Must have exactly Dimensions() elements +// - limit: Maximum number of results to return +// +// Returns: +// - keys: IDs of the nearest vectors (up to limit) +// - distances: Distance to each result (same length as keys) +// - err: Error if query is invalid or search fails +// +// The actual number of results may be less than limit if the index +// contains fewer vectors. func (index *Index) Search(query []float32, limit uint) (keys []Key, distances []float32, err error) { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } - if len(query) != int(index.config.Dimensions) { - return nil, nil, errors.New("Number of dimensions doesn't match!") + + if len(query) == 0 { + return nil, nil, errors.New("query vector cannot be empty") + } + if uint(len(query)) != index.config.Dimensions { + return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions) + } + if limit == 0 { + return []Key{}, []float32{}, nil } keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(query) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:resultCount] + distances = distances[:resultCount] return keys, distances, nil } // SearchUnsafe performs k-Approximate Nearest Neighbors Search using an unsafe pointer. // -// The memory behind `query` must contain exactly `Dimensions()` scalars of the -// type matching the index quantization. Passing mismatched types is undefined behavior. +// SAFETY REQUIREMENTS: +// - query must not be nil +// - Memory at query must contain exactly Dimensions() scalars +// - Scalar type must match index.config.Quantization +// - Memory must remain valid for the duration of the call +// +// Use Search() or SearchI8() instead unless you need maximum performance +// and understand the safety implications. func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if query == nil { + return nil, nil, errors.New("query pointer cannot be nil") + } + if limit == 0 { + return []Key{}, []float32{}, nil } keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:resultCount] + distances = distances[:resultCount] return keys, distances, nil } -// ExactSearch is a multithreaded exact nearest neighbors search -func ExactSearch(dataset []float32, queries []float32, dataset_size uint, queries_size uint, - dataset_stride uint, queries_stride uint, dims uint, metric Metric, - count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { - if (len(dataset) % int(dims)) != 0 { - return nil, nil, errors.New("Dataset length must be a multiple of the dimensions") - } - if (len(queries) % int(dims)) != 0 { - return nil, nil, errors.New("Queries length must be a multiple of the dimensions") - } - - keys = make([]Key, count) - distances = make([]float32, count) +// ExactSearch performs multithreaded exact nearest neighbors search. +// Unlike the index-based search, this computes distances to all vectors in the dataset. +// +// Parameters: +// - dataset: Flattened array of vectors (datasetSize x vectorDimensions) +// - queries: Flattened array of query vectors (queryCount x vectorDimensions) +// - datasetSize, queryCount: Number of vectors in dataset and queries +// - datasetStride, queryStride: Memory stride in bytes between consecutive vectors (use vectorDimensions * sizeof(float32) for contiguous data) +// - vectorDimensions: Number of dimensions per vector +// - metric: Distance metric to use +// - maxResults: Maximum results per query +// - numThreads: Number of threads to use (0 for auto-detection) +func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if len(dataset) == 0 || len(queries) == 0 { + return nil, nil, errors.New("dataset and queries cannot be empty") + } + if vectorDimensions == 0 { + return nil, nil, errors.New("dimensions must be greater than zero") + } + if (len(dataset) % int(vectorDimensions)) != 0 { + return nil, nil, errors.New("dataset length must be a multiple of the dimensions") + } + if (len(queries) % int(vectorDimensions)) != 0 { + return nil, nil, errors.New("queries length must be a multiple of the dimensions") + } + + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) var errorMessage *C.char - C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride), - C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), - (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), + C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(dataset) + runtime.KeepAlive(queries) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:maxResults] + distances = distances[:maxResults] return keys, distances, nil } -// ExactSearchUnsafe is a multithreaded exact nearest neighbors search using unsafe pointers. +// ExactSearchUnsafe performs multithreaded exact nearest neighbors search using unsafe pointers. // -// `dataset` and `queries` must point to contiguous memory with the element type specified by `quantization`. -// Stride and sizes are in bytes and elements, respectively. -func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, dataset_size uint, queries_size uint, - dataset_stride uint, queries_stride uint, dims uint, metric Metric, quantization Quantization, - count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { +// SAFETY REQUIREMENTS: +// - dataset and queries must not be nil +// - Memory must contain contiguous vectors of the specified quantization type +// - dataset must contain datasetSize vectors of vectorDimensions elements each +// - queries must contain queryCount vectors of vectorDimensions elements each +// - Memory must remain valid for the duration of the call +// +// Stride parameters specify memory offset in bytes between consecutive vectors. +// For contiguous data, use vectorDimensions * sizeof(element_type). +func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, quantization Quantization, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if dataset == nil || queries == nil { + return nil, nil, errors.New("dataset and queries pointers cannot be nil") + } + if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 { + return nil, nil, errors.New("dimensions and sizes must be greater than zero") + } - keys = make([]Key, count) - distances = make([]float32, count) + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) var errorMessage *C.char - C.usearch_exact_search(dataset, C.size_t(dataset_size), C.size_t(dataset_stride), queries, C.size_t(queries_size), C.size_t(queries_stride), - quantization.CValue(), C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), - (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + C.usearch_exact_search(dataset, C.size_t(datasetSize), C.size_t(datasetStride), queries, C.size_t(queryCount), C.size_t(queryStride), + quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:maxResults] + distances = distances[:maxResults] return keys, distances, nil } // Convenience I8 helpers -// AddI8 adds a vector provided as int8 slice. The slice length must equal Dimensions(). +// AddI8 adds an int8 vector to the index. +// The vector must have exactly Dimensions() elements. +// +// This is a convenience method for indexes using I8 quantization. func (index *Index) AddI8(key Key, vec []int8) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } if len(vec) == 0 { - return errors.New("empty vector") + return errors.New("vector cannot be empty") + } + if uint(len(vec)) != index.config.Dimensions { + return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vec), index.config.Dimensions) } var errorMessage *C.char C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_i8_k, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// SearchI8 searches with a query provided as int8 slice. The slice length must equal Dimensions(). +// SearchI8 searches for nearest neighbors using an int8 query vector. +// The query must have exactly Dimensions() elements. +// +// This is a convenience method for indexes using I8 quantization. func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []float32, err error) { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } if len(query) == 0 { - return nil, nil, errors.New("empty query") + return nil, nil, errors.New("query vector cannot be empty") + } + if uint(len(query)) != index.config.Dimensions { + return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions) + } + if limit == 0 { + return []Key{}, []float32{}, nil } keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(query) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:resultCount] + distances = distances[:resultCount] return keys, distances, nil } -// DistanceI8 computes distance between two int8 vectors. -func DistanceI8(vec1 []int8, vec2 []int8, dims uint, metric Metric) (float32, error) { +// DistanceI8 computes the distance between two int8 vectors. +// +// Example: +// +// vec1 := []int8{1, 2, 3, 4} +// vec2 := []int8{5, 6, 7, 8} +// dist, err := usearch.DistanceI8(vec1, vec2, 4, usearch.L2sq) +func DistanceI8(vec1 []int8, vec2 []int8, vectorDimensions uint, metric Metric) (float32, error) { if len(vec1) == 0 || len(vec2) == 0 { - return 0, errors.New("empty vectors") + return 0, errors.New("vectors cannot be empty") + } + if uint(len(vec1)) < vectorDimensions || uint(len(vec2)) < vectorDimensions { + return 0, fmt.Errorf("vectors too short for specified dimensions: need %d elements", vectorDimensions) } var errorMessage *C.char - dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_i8_k, C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec1) + runtime.KeepAlive(vec2) if errorMessage != nil { return 0, errors.New(C.GoString(errorMessage)) } return float32(dist), nil } -// ExactSearchI8 performs exact search on int8 matrices. -func ExactSearchI8(dataset []int8, queries []int8, dataset_size uint, queries_size uint, - dataset_stride uint, queries_stride uint, dims uint, metric Metric, - count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { - keys = make([]Key, count) - distances = make([]float32, count) +// ExactSearchI8 performs exact nearest neighbors search on int8 vectors. +// This computes distances to all vectors in the dataset without using an index. +// +// Stride parameters specify memory offset in bytes between consecutive vectors. +// For contiguous int8 data, use vectorDimensions * 1 byte. +func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if len(dataset) == 0 || len(queries) == 0 { + return nil, nil, errors.New("dataset and queries cannot be empty") + } + if vectorDimensions == 0 { + return nil, nil, errors.New("dimensions must be greater than zero") + } + + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) var errorMessage *C.char - C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride), - C.usearch_scalar_i8_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), - (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), + C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(dataset) + runtime.KeepAlive(queries) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:maxResults] + distances = distances[:maxResults] return keys, distances, nil } -// Save saves the index to a specified buffer. +// SaveBuffer serializes the index into a byte buffer. +// The buffer must be large enough to hold the serialized index. +// Use SerializedLength() to determine the required buffer size. func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char C.usearch_save_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the index from a specified buffer. +// LoadBuffer loads a serialized index from a byte buffer. +// The buffer must contain a valid serialized index. func (index *Index) LoadBuffer(buf []byte, buffer_size uint) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char C.usearch_load_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the index from a specified buffer without copying the data. +// ViewBuffer creates a view of a serialized index without copying the data. +// The buffer must remain valid for the lifetime of the index. +// Changes to the buffer will affect the index. func (index *Index) ViewBuffer(buf []byte, buffer_size uint) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char C.usearch_view_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the metadata from a specified buffer. +// MetadataBuffer extracts index configuration metadata from a serialized buffer. +// This can be used to inspect an index before loading it. func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { - if buf == nil { - panic("Buffer is uninitialized") + if len(buf) == 0 { + return c, errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return c, fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } c = IndexConfig{} @@ -681,6 +954,7 @@ func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { var errorMessage *C.char C.usearch_metadata_buffer(unsafe.Pointer(&buf[0]), C.size_t(buffer_size), &options, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return c, errors.New(C.GoString(errorMessage)) } @@ -728,8 +1002,12 @@ func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { return c, nil } -// Metadata loads the metadata from a specified file. +// Metadata loads the index configuration metadata from a file. +// This can be used to inspect an index file before loading it. func Metadata(path string) (c IndexConfig, err error) { + if path == "" { + return c, errors.New("path cannot be empty") + } c_path := C.CString(path) defer C.free(unsafe.Pointer(c_path)) @@ -788,7 +1066,7 @@ func Metadata(path string) (c IndexConfig, err error) { // Save saves the index to a specified file. func (index *Index) Save(path string) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } c_path := C.CString(path) @@ -805,7 +1083,7 @@ func (index *Index) Save(path string) error { // Load loads the index from a specified file. func (index *Index) Load(path string) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } c_path := C.CString(path) @@ -822,7 +1100,7 @@ func (index *Index) Load(path string) error { // View creates a view of the index from a specified file without loading it into memory. func (index *Index) View(path string) error { if index.opaque_handle == nil { - panic("Index is uninitialized") + panic("index is uninitialized") } c_path := C.CString(path) From fd79bf7e5090d90eb87e8ec9bb5dc39e3b0c4271 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 12:17:45 +0000 Subject: [PATCH 07/10] Add: `io.Closer` interface support --- golang/lib.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/golang/lib.go b/golang/lib.go index e7b1ddcfe..3b6d399fb 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -207,9 +207,10 @@ func DefaultConfig(dimensions uint) IndexConfig { } // Index represents a USearch approximate nearest neighbor index. +// It implements io.Closer for idiomatic resource cleanup. // // The index must be properly initialized with NewIndex() and destroyed -// with Destroy() when no longer needed to free resources. +// with Destroy() or Close() when no longer needed to free resources. type Index struct { opaque_handle *C.void config IndexConfig @@ -409,6 +410,12 @@ func (index *Index) Destroy() error { return nil } +// Close implements io.Closer interface and calls Destroy() to free resources. +// This provides idiomatic Go resource cleanup that can be used with defer statements. +func (index *Index) Close() error { + return index.Destroy() +} + // Reserve reserves memory for a specified number of incoming vectors. func (index *Index) Reserve(capacity uint) error { if index.opaque_handle == nil { From 1f190b47810e0a2cdd12298d71aa41ca1ca1eb5f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 12:30:48 +0000 Subject: [PATCH 08/10] Add: All new test suite for Go --- golang/lib.go | 106 ++--- golang/lib_test.go | 1044 ++++++++++++++++++++++++++++---------------- 2 files changed, 732 insertions(+), 418 deletions(-) diff --git a/golang/lib.go b/golang/lib.go index 3b6d399fb..02d81d5cd 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "runtime" + "sync" "unsafe" ) @@ -212,8 +213,9 @@ func DefaultConfig(dimensions uint) IndexConfig { // The index must be properly initialized with NewIndex() and destroyed // with Destroy() or Close() when no longer needed to free resources. type Index struct { - opaque_handle *C.void - config IndexConfig + handle C.usearch_index_t + config IndexConfig + mu sync.Mutex } // NewIndex creates a new approximate nearest neighbor index with the specified configuration. @@ -259,14 +261,14 @@ func NewIndex(conf IndexConfig) (index *Index, err error) { return nil, errors.New(C.GoString(errorMessage)) } - index.opaque_handle = (*C.void)(unsafe.Pointer(ptr)) + index.handle = ptr return index, nil } // Len returns the number of vectors in the index. func (index *Index) Len() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_size((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_size(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -276,7 +278,7 @@ func (index *Index) Len() (len uint, err error) { // SerializedLength reports the expected file size after serialization. func (index *Index) SerializedLength() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_serialized_length((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_serialized_length(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -286,7 +288,7 @@ func (index *Index) SerializedLength() (len uint, err error) { // MemoryUsage reports the memory usage of the index func (index *Index) MemoryUsage() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_memory_usage((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_memory_usage(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -296,7 +298,7 @@ func (index *Index) MemoryUsage() (len uint, err error) { // ExpansionAdd returns the expansion value used during index creation func (index *Index) ExpansionAdd() (val uint, err error) { var errorMessage *C.char - val = uint(C.usearch_expansion_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + val = uint(C.usearch_expansion_add(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -306,7 +308,7 @@ func (index *Index) ExpansionAdd() (val uint, err error) { // ExpansionSearch returns the expansion value used during search func (index *Index) ExpansionSearch() (val uint, err error) { var errorMessage *C.char - val = uint(C.usearch_expansion_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + val = uint(C.usearch_expansion_search(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -316,7 +318,7 @@ func (index *Index) ExpansionSearch() (val uint, err error) { // ChangeExpansionAdd sets the expansion value used during index creation func (index *Index) ChangeExpansionAdd(val uint) error { var errorMessage *C.char - C.usearch_change_expansion_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_expansion_add(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -326,7 +328,7 @@ func (index *Index) ChangeExpansionAdd(val uint) error { // ChangeExpansionSearch sets the expansion value used during search func (index *Index) ChangeExpansionSearch(val uint) error { var errorMessage *C.char - C.usearch_change_expansion_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_expansion_search(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -336,7 +338,7 @@ func (index *Index) ChangeExpansionSearch(val uint) error { // ChangeThreadsAdd sets the threads limit for add func (index *Index) ChangeThreadsAdd(val uint) error { var errorMessage *C.char - C.usearch_change_threads_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_threads_add(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -346,7 +348,7 @@ func (index *Index) ChangeThreadsAdd(val uint) error { // ChangeThreadsSearch sets the threads limit for search func (index *Index) ChangeThreadsSearch(val uint) error { var errorMessage *C.char - C.usearch_change_threads_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_threads_search(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -356,7 +358,7 @@ func (index *Index) ChangeThreadsSearch(val uint) error { // Connectivity returns the connectivity parameter of the index. func (index *Index) Connectivity() (con uint, err error) { var errorMessage *C.char - con = uint(C.usearch_connectivity((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + con = uint(C.usearch_connectivity(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -366,7 +368,7 @@ func (index *Index) Connectivity() (con uint, err error) { // Dimensions returns the number of dimensions of the vectors in the index. func (index *Index) Dimensions() (dim uint, err error) { var errorMessage *C.char - dim = uint(C.usearch_dimensions((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + dim = uint(C.usearch_dimensions(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -376,7 +378,7 @@ func (index *Index) Dimensions() (dim uint, err error) { // Capacity returns the capacity (maximum number of vectors) of the index. func (index *Index) Capacity() (cap uint, err error) { var errorMessage *C.char - cap = uint(C.usearch_capacity((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + cap = uint(C.usearch_capacity(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -387,7 +389,7 @@ func (index *Index) Capacity() (cap uint, err error) { func (index *Index) HardwareAcceleration() (string, error) { var str *C.char var errorMessage *C.char - str = C.usearch_hardware_acceleration((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage)) + str = C.usearch_hardware_acceleration(index.handle, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return C.GoString(nil), errors.New(C.GoString(errorMessage)) } @@ -396,16 +398,16 @@ func (index *Index) HardwareAcceleration() (string, error) { // Destroy frees the resources associated with the index. func (index *Index) Destroy() error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } var errorMessage *C.char - C.usearch_free((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage)) + C.usearch_free(index.handle, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } - index.opaque_handle = nil + index.handle = nil index.config = IndexConfig{} return nil } @@ -418,12 +420,12 @@ func (index *Index) Close() error { // Reserve reserves memory for a specified number of incoming vectors. func (index *Index) Reserve(capacity uint) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } var errorMessage *C.char - C.usearch_reserve((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.size_t)(capacity), (*C.usearch_error_t)(&errorMessage)) + C.usearch_reserve(index.handle, (C.size_t)(capacity), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -439,7 +441,7 @@ func (index *Index) Reserve(capacity uint) error { // - The vector is empty or has wrong dimensions // - The underlying C library reports an error func (index *Index) Add(key Key, vec []float32) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -451,7 +453,7 @@ func (index *Index) Add(key Key, vec []float32) error { } var errorMessage *C.char - C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)) + C.usearch_add(index.handle, (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)) runtime.KeepAlive(vec) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) @@ -471,7 +473,7 @@ func (index *Index) Add(key Key, vec []float32) error { // Use Add() or AddI8() instead unless you need maximum performance // and understand the safety implications. func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -480,7 +482,7 @@ func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { } var errorMessage *C.char - C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), vec, index.config.Quantization.CValue(), (*C.usearch_error_t)(&errorMessage)) + C.usearch_add(index.handle, (C.usearch_key_t)(key), vec, index.config.Quantization.CValue(), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -489,12 +491,12 @@ func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { // Remove removes the vector associated with the given key from the index. func (index *Index) Remove(key Key) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } var errorMessage *C.char - C.usearch_remove((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage)) + C.usearch_remove(index.handle, (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -503,12 +505,12 @@ func (index *Index) Remove(key Key) error { // Contains checks if the index contains a vector with a specific key. func (index *Index) Contains(key Key) (found bool, err error) { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } var errorMessage *C.char - found = bool(C.usearch_contains((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage))) + found = bool(C.usearch_contains(index.handle, (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { return found, errors.New(C.GoString(errorMessage)) } @@ -518,7 +520,7 @@ func (index *Index) Contains(key Key) (found bool, err error) { // Get retrieves the vectors associated with the given key from the index. // Returns nil if the key is not found. func (index *Index) Get(key Key, maxCount uint) (vectors []float32, err error) { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -528,7 +530,7 @@ func (index *Index) Get(key Key, maxCount uint) (vectors []float32, err error) { vectors = make([]float32, index.config.Dimensions*maxCount) var errorMessage *C.char - found := uint(C.usearch_get((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (C.size_t)(maxCount), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) + found := uint(C.usearch_get(index.handle, (C.usearch_key_t)(key), (C.size_t)(maxCount), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) runtime.KeepAlive(vectors) if errorMessage != nil { return nil, errors.New(C.GoString(errorMessage)) @@ -542,7 +544,7 @@ func (index *Index) Get(key Key, maxCount uint) (vectors []float32, err error) { // Rename the vector at key from to key to func (index *Index) Rename(from Key, to Key) error { var errorMessage *C.char - C.usearch_rename((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.usearch_key_t(from), C.usearch_key_t(to), (*C.usearch_error_t)(&errorMessage)) + C.usearch_rename(index.handle, C.usearch_key_t(from), C.usearch_key_t(to), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -606,7 +608,7 @@ func DistanceUnsafe(vec1 unsafe.Pointer, vec2 unsafe.Pointer, vectorDimensions u // The actual number of results may be less than limit if the index // contains fewer vectors. func (index *Index) Search(query []float32, limit uint) (keys []Key, distances []float32, err error) { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -623,7 +625,7 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [ keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) runtime.KeepAlive(query) runtime.KeepAlive(keys) runtime.KeepAlive(distances) @@ -647,7 +649,7 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [ // Use Search() or SearchI8() instead unless you need maximum performance // and understand the safety implications. func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -661,7 +663,7 @@ func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search(index.handle, query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) runtime.KeepAlive(keys) runtime.KeepAlive(distances) if errorMessage != nil { @@ -767,7 +769,7 @@ func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSi // // This is a convenience method for indexes using I8 quantization. func (index *Index) AddI8(key Key, vec []int8) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } if len(vec) == 0 { @@ -777,7 +779,7 @@ func (index *Index) AddI8(key Key, vec []int8) error { return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vec), index.config.Dimensions) } var errorMessage *C.char - C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_i8_k, (*C.usearch_error_t)(&errorMessage)) + C.usearch_add(index.handle, (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_i8_k, (*C.usearch_error_t)(&errorMessage)) runtime.KeepAlive(vec) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) @@ -790,7 +792,7 @@ func (index *Index) AddI8(key Key, vec []int8) error { // // This is a convenience method for indexes using I8 quantization. func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []float32, err error) { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } if len(query) == 0 { @@ -805,7 +807,7 @@ func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances [] keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - resultCount := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) runtime.KeepAlive(query) runtime.KeepAlive(keys) runtime.KeepAlive(distances) @@ -879,7 +881,7 @@ func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount // The buffer must be large enough to hold the serialized index. // Use SerializedLength() to determine the required buffer size. func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -891,7 +893,7 @@ func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { } var errorMessage *C.char - C.usearch_save_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_save_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) @@ -902,7 +904,7 @@ func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { // LoadBuffer loads a serialized index from a byte buffer. // The buffer must contain a valid serialized index. func (index *Index) LoadBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -914,7 +916,7 @@ func (index *Index) LoadBuffer(buf []byte, buffer_size uint) error { } var errorMessage *C.char - C.usearch_load_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_load_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) @@ -926,7 +928,7 @@ func (index *Index) LoadBuffer(buf []byte, buffer_size uint) error { // The buffer must remain valid for the lifetime of the index. // Changes to the buffer will affect the index. func (index *Index) ViewBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -938,7 +940,7 @@ func (index *Index) ViewBuffer(buf []byte, buffer_size uint) error { } var errorMessage *C.char - C.usearch_view_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_view_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) @@ -1072,7 +1074,7 @@ func Metadata(path string) (c IndexConfig, err error) { // Save saves the index to a specified file. func (index *Index) Save(path string) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -1080,7 +1082,7 @@ func (index *Index) Save(path string) error { defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_save((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_save((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -1089,7 +1091,7 @@ func (index *Index) Save(path string) error { // Load loads the index from a specified file. func (index *Index) Load(path string) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -1097,7 +1099,7 @@ func (index *Index) Load(path string) error { defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_load((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_load((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -1106,7 +1108,7 @@ func (index *Index) Load(path string) error { // View creates a view of the index from a specified file without loading it into memory. func (index *Index) View(path string) error { - if index.opaque_handle == nil { + if index.handle == nil { panic("index is uninitialized") } @@ -1114,7 +1116,7 @@ func (index *Index) View(path string) error { defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_view((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_view((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } diff --git a/golang/lib_test.go b/golang/lib_test.go index 1a7692693..6d026bbf9 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -1,583 +1,895 @@ package usearch import ( + "fmt" + "io" "math" "runtime" + "sync" "testing" "unsafe" ) -func TestUSearch(t *testing.T) { - runtime.LockOSThread() +// Test constants +const ( + defaultTestDimensions = 128 + distanceTolerance = 1e-2 + bufferSize = 1024 * 1024 +) - t.Run("Test Index Initialization", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) - } - defer ind.Destroy() +// Helper functions to reduce code duplication + +func createTestIndex(t *testing.T, dimensions uint, quantization Quantization) *Index { + conf := DefaultConfig(dimensions) + conf.Quantization = quantization + index, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to create test index: %v", err) + } + return index +} + +func generateTestVector(dimensions uint) []float32 { + vec := make([]float32, dimensions) + for i := uint(0); i < dimensions; i++ { + vec[i] = float32(i) + 0.1 + } + return vec +} + +func generateTestVectorI8(dimensions uint) []int8 { + vec := make([]int8, dimensions) + for i := uint(0); i < dimensions; i++ { + vec[i] = int8((i % 127) + 1) + } + return vec +} + +func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { + vectors := make([][]float32, vectorCount) + err := index.Reserve(uint(vectorCount)) + if err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + + dimensions, err := index.Dimensions() + if err != nil { + t.Fatalf("Failed to get dimensions: %v", err) + } + + for i := 0; i < vectorCount; i++ { + vec := generateTestVector(dimensions) + vec[0] = float32(i) // Make each vector unique + vectors[i] = vec - found_dims, err := ind.Dimensions() + err = index.Add(Key(i), vec) if err != nil { - t.Fatalf("Failed to retrieve dimensions: %s", err) - } - if found_dims != dim { - t.Fatalf("Expected %d dimensions, got %d", dim, found_dims) + t.Fatalf("Failed to add vector %d: %v", i, err) } + } + return vectors +} + +// Core functionality tests (improved versions of existing) + +func TestIndexLifecycle(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Index creation and configuration", func(t *testing.T) { + dimensions := uint(64) + index := createTestIndex(t, dimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - found_len, err := ind.Len() + // Verify dimensions + actualDimensions, err := index.Dimensions() if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("Failed to retrieve dimensions: %v", err) } - if found_len != 0 { - t.Fatalf("Expected size to be 0, got %d", found_len) + if actualDimensions != dimensions { + t.Fatalf("Expected %d dimensions, got %d", dimensions, actualDimensions) } - found_len, err = ind.SerializedLength() + // Verify empty index + size, err := index.Len() if err != nil { - t.Fatalf("Failed to retrieve serialized length: %s", err) + t.Fatalf("Failed to retrieve size: %v", err) } - if found_len != 112 { - t.Fatalf("Expected serialized length to be 112, got %d", found_len) + if size != 0 { + t.Fatalf("Expected empty index, got size %d", size) } - err = ind.Reserve(100) + // Capacity may be zero before any reservation; ensure Reserve works + if err := index.Reserve(10); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + capacity, err := index.Capacity() if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Failed to retrieve capacity: %v", err) + } + if capacity < 10 { + t.Fatalf("Expected capacity >= 10 after reserve, got %d", capacity) } - mem, err := ind.MemoryUsage() + // Verify memory usage + memUsage, err := index.MemoryUsage() if err != nil { - t.Fatalf("Failed to retrieve serialized length: %s", err) + t.Fatalf("Failed to retrieve memory usage: %v", err) } - if mem == 0 { - t.Fatalf("Expected the empty index memory usage to be positive, got zero") + if memUsage == 0 { + t.Fatalf("Expected positive memory usage") } - s, err := ind.HardwareAcceleration() + // Verify hardware acceleration info + hwAccel, err := index.HardwareAcceleration() if err != nil { - t.Fatalf("Failed to retrieve hardware acceleration: %s", err) + t.Fatalf("Failed to retrieve hardware acceleration: %v", err) } - if s == "" { - t.Fatalf("An empty string was returned from HardwareAcceleration") + if hwAccel == "" { + t.Fatalf("Expected non-empty hardware acceleration string") } - }) - t.Run("Test Insertion", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Run("Index configuration validation", func(t *testing.T) { + // Test different configurations + configs := []struct { + name string + dimensions uint + quantization Quantization + metric Metric + }{ + {"F32-Cosine", 128, F32, Cosine}, + {"F64-L2sq", 64, F64, L2sq}, + {"I8-InnerProduct", 32, I8, InnerProduct}, + } + + for _, config := range configs { + t.Run(config.name, func(t *testing.T) { + conf := DefaultConfig(config.dimensions) + conf.Quantization = config.quantization + conf.Metric = config.metric + + index, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to create index with config %s: %v", config.name, err) + } + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + actualDims, err := index.Dimensions() + if err != nil || actualDims != config.dimensions { + t.Fatalf("Configuration mismatch for %s", config.name) + } + }) } - defer ind.Destroy() + }) +} - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) - } +func TestBasicOperations(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() - err = ind.ChangeThreadsAdd(10) - if err != nil { - t.Fatalf("Failed to change threads add: %s", err) + t.Run("Add and retrieve", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Ensure capacity before first add + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) } - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 + // Add a vector + vec := generateTestVector(defaultTestDimensions) + vec[0] = 42.0 + vec[1] = 24.0 - err = ind.Add(42, vec) + err := index.Add(100, vec) if err != nil { - t.Fatalf("Failed to insert: %s", err) + t.Fatalf("Failed to add vector: %v", err) } - found_len, err := ind.Len() + // Verify index size + size, err := index.Len() if err != nil { - t.Fatalf("Failed to retrieve size after insertion: %s", err) + t.Fatalf("Failed to get index size: %v", err) } - if found_len != 1 { - t.Fatalf("Expected size to be 1, got %d", found_len) + if size != 1 { + t.Fatalf("Expected size 1, got %d", size) } - }) - t.Run("Test Insertion with pointer", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + // Test Contains + found, err := index.Contains(100) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Contains check failed: %v", err) } - defer ind.Destroy() - - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + if !found { + t.Fatalf("Expected to find key 100") } - err = ind.ChangeThreadsAdd(10) + // Test Get + retrieved, err := index.Get(100, 1) if err != nil { - t.Fatalf("Failed to change threads add: %s", err) + t.Fatalf("Failed to retrieve vector: %v", err) } + if retrieved == nil || len(retrieved) != int(defaultTestDimensions) { + t.Fatalf("Retrieved vector has wrong dimensions") + } + }) + + t.Run("Search functionality", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 + // Add test data + testVectors := populateIndex(t, index, 10) - err = ind.AddUnsafe(42, unsafe.Pointer(&vec[0])) + // Search with first vector (should find itself) + keys, distances, err := index.Search(testVectors[0], 5) if err != nil { - t.Fatalf("Failed to insert: %s", err) + t.Fatalf("Search failed: %v", err) } - found_len, err := ind.Len() - if err != nil { - t.Fatalf("Failed to retrieve size after insertion: %s", err) + if len(keys) == 0 || len(distances) == 0 { + t.Fatalf("Search returned no results") + } + + // First result should be the exact match with near-zero distance + if keys[0] != 0 { + t.Fatalf("Expected first result to be key 0, got %d", keys[0]) } - if found_len != 1 { - t.Fatalf("Expected size to be 1, got %d", found_len) + + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) } }) - t.Run("Test Search", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + t.Run("Remove operations", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Add vectors + populateIndex(t, index, 5) + + // Remove one vector + err := index.Remove(2) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Failed to remove vector: %v", err) } - defer ind.Destroy() - err = ind.Reserve(100) + // Verify it's gone + found, err := index.Contains(2) if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Contains check failed after removal: %v", err) + } + if found { + t.Fatalf("Key 2 should have been removed") } - err = ind.ChangeThreadsSearch(10) + // Verify size decreased + size, err := index.Len() if err != nil { - t.Fatalf("Failed to change threads search: %s", err) + t.Fatalf("Failed to get size after removal: %v", err) } + if size != 4 { + t.Fatalf("Expected size 4 after removal, got %d", size) + } + }) +} - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 +func TestIOCloser(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() - err = ind.Add(42, vec) - if err != nil { - t.Fatalf("Failed to insert: %s", err) - } + t.Run("io.Closer interface compliance", func(t *testing.T) { + index := createTestIndex(t, 32, F32) - keys, distances, err := ind.Search(vec, 10) + // Verify that Index can be used as io.Closer + var closer io.Closer = index + + // Test Close method works like Destroy + err := closer.Close() if err != nil { - t.Fatalf("Failed to search: %s", err) + t.Fatalf("Close failed: %v", err) } + }) +} - const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 - if keys[0] != 42 || math.Abs(float64(distances[0])) > tolerance { - t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) - } +func TestSerialization(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Buffer save/load/view operations", func(t *testing.T) { + // Create and populate original index + originalIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := originalIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy original index: %v", err) + } + }() - // TODO: Add exact search - }) + testVectors := populateIndex(t, originalIndex, 50) - t.Run("Test Search With Pointer", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + originalSize, err := originalIndex.Len() if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Failed to get original index size: %v", err) } - defer ind.Destroy() - err = ind.Reserve(100) + // Save to buffer + buf := make([]byte, bufferSize) + err = originalIndex.SaveBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Failed to save index to buffer: %v", err) } - err = ind.ChangeThreadsSearch(10) + // Test metadata extraction + metadata, err := MetadataBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to change threads search: %s", err) + t.Fatalf("Failed to extract metadata: %v", err) } - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 + if metadata.Dimensions != defaultTestDimensions { + t.Fatalf("Metadata dimensions mismatch: expected %d, got %d", + defaultTestDimensions, metadata.Dimensions) + } - ptr := unsafe.Pointer(&vec[0]) - err = ind.AddUnsafe(42, ptr) + // Test LoadBuffer + loadedIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := loadedIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy loaded index: %v", err) + } + }() + + err = loadedIndex.LoadBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to insert: %s", err) + t.Fatalf("Failed to load index from buffer: %v", err) } - keys, distances, err := ind.SearchUnsafe(ptr, 10) + loadedSize, err := loadedIndex.Len() if err != nil { - t.Fatalf("Failed to search: %s", err) + t.Fatalf("Failed to get loaded index size: %v", err) } - const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 - if keys[0] != 42 || math.Abs(float64(distances[0])) > tolerance { - t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + if loadedSize != originalSize { + t.Fatalf("Loaded index size mismatch: expected %d, got %d", + originalSize, loadedSize) } - // TODO: Add exact search - }) - - t.Run("Test Save and Load", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) - } - defer ind.Destroy() - ind2, err := NewIndex(conf) + // Verify search results are consistent + keys, distances, err := loadedIndex.Search(testVectors[0], 3) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Search failed on loaded index: %v", err) } - defer ind2.Destroy() - indView, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + + if len(keys) == 0 || keys[0] != 0 { + t.Fatalf("Loaded index search results inconsistent") } - defer indView.Destroy() - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + // Verify distance is near zero for exact match + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) } - vec := make([]float32, dim) - for i := uint(0); i < dim; i++ { - vec[i] = float32(i) + 0.2 - err = ind.Add(uint64(i), vec) - if err != nil { - t.Fatalf("Failed to insert: %s", err) + // Test ViewBuffer + viewIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := viewIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy view index: %v", err) } - } + }() - ind_length, err := ind.Len() + err = viewIndex.ViewBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("Failed to create view from buffer: %v", err) } - // TODO: Add invalid save and loads? - buffer_size := uint(1 * 1024 * 1024) - buf := make([]byte, buffer_size) - err = ind.SaveBuffer(buf, buffer_size) + viewSize, err := viewIndex.Len() if err != nil { - t.Fatalf("Failed to save the index to a buffer: %s", err) + t.Fatalf("Failed to get view index size: %v", err) } - err = ind2.LoadBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the index from a buffer: %s", err) + if viewSize != originalSize { + t.Fatalf("View index size mismatch: expected %d, got %d", + originalSize, viewSize) } + }) +} - ind2_length, err := ind2.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) +func TestInputValidation(t *testing.T) { + t.Run("Zero dimensions", func(t *testing.T) { + conf := DefaultConfig(0) + _, err := NewIndex(conf) + if err == nil { + t.Fatalf("Expected error for zero dimensions") } - if ind_length != ind2_length { - t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) - } - // TODO: Check some values + }) - err = indView.ViewBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the view from a buffer: %s", err) - } + t.Run("Empty vectors", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - indView_length, err := indView.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) - } - if ind_length != indView_length { - t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + // Test Add with empty vector + err := index.Add(1, []float32{}) + if err == nil { + t.Fatalf("Expected error for empty vector in Add") } - conf, err = MetadataBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the metadata from a buffer: %s", err) + // Test Search with empty vector + _, _, err = index.Search([]float32{}, 10) + if err == nil { + t.Fatalf("Expected error for empty vector in Search") } - if conf != ind.config { - t.Fatalf("Loaded metadata doesn't match the index metadata") + }) + + t.Run("Dimension mismatches", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test Add with wrong dimensions + wrongVec := make([]float32, 32) // Should be 64 + err := index.Add(1, wrongVec) + if err == nil { + t.Fatalf("Expected error for dimension mismatch in Add") } - // TODO: Check file save/load/metadata + // Test Search with wrong dimensions + _, _, err = index.Search(wrongVec, 10) + if err == nil { + t.Fatalf("Expected error for dimension mismatch in Search") + } }) - t.Run("Test Save and Load With Pointer", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Run("Nil pointers", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test AddUnsafe with nil pointer + err := index.AddUnsafe(1, nil) + if err == nil { + t.Fatalf("Expected error for nil pointer in AddUnsafe") } - defer ind.Destroy() - ind2, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + + // Test SearchUnsafe with nil pointer + _, _, err = index.SearchUnsafe(nil, 10) + if err == nil { + t.Fatalf("Expected error for nil pointer in SearchUnsafe") } - defer ind2.Destroy() - indView, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + }) + + t.Run("Buffer validation", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test SaveBuffer with empty buffer + err := index.SaveBuffer([]byte{}, 100) + if err == nil { + t.Fatalf("Expected error for empty buffer in SaveBuffer") } - defer indView.Destroy() - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + // Test LoadBuffer with empty buffer + err = index.LoadBuffer([]byte{}, 100) + if err == nil { + t.Fatalf("Expected error for empty buffer in LoadBuffer") } + }) +} - vec := make([]float32, dim) - for i := uint(0); i < dim; i++ { - vec[i] = float32(i) + 0.2 - err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) - if err != nil { - t.Fatalf("Failed to insert: %s", err) +func TestQuantizationTypes(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("F32 operations", func(t *testing.T) { + index := createTestIndex(t, 32, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) } - } + }() - ind_length, err := ind.Len() + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + vec := generateTestVector(32) + err := index.Add(1, vec) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("F32 Add failed: %v", err) } - // TODO: Add invalid save and loads? - buffer_size := uint(1 * 1024 * 1024) - buf := make([]byte, buffer_size) - err = ind.SaveBuffer(buf, buffer_size) + keys, _, err := index.Search(vec, 1) if err != nil { - t.Fatalf("Failed to save the index to a buffer: %s", err) + t.Fatalf("F32 Search failed: %v", err) } - err = ind2.LoadBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the index from a buffer: %s", err) + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("F32 search results incorrect") } + }) - ind2_length, err := ind2.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Run("F64 operations", func(t *testing.T) { + index := createTestIndex(t, 32, F64) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) } - if ind_length != ind2_length { - t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + vec := make([]float64, 32) + for i := range vec { + vec[i] = float64(i) + 0.5 } - // TODO: Check some values - err = indView.ViewBuffer(buf, buffer_size) + err := index.AddUnsafe(1, unsafe.Pointer(&vec[0])) if err != nil { - t.Fatalf("Failed to load the view from a buffer: %s", err) + t.Fatalf("F64 AddUnsafe failed: %v", err) } - indView_length, err := indView.Len() + keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vec[0]), 1) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) - } - if ind_length != indView_length { - t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + t.Fatalf("F64 SearchUnsafe failed: %v", err) } - conf, err = MetadataBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the metadata from a buffer: %s", err) - } - if conf != ind.config { - t.Fatalf("Loaded metadata doesn't match the index metadata") + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("F64 search results incorrect") } - - // TODO: Check file save/load/metadata }) -} -func TestUSearchDataType(t *testing.T) { + t.Run("I8 operations", func(t *testing.T) { + index := createTestIndex(t, 32, I8) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - t.Run("Test Save and Load With F32", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) } - defer ind.Destroy() - ind2, err := NewIndex(conf) + vec := generateTestVectorI8(32) + err := index.AddI8(1, vec) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("I8 Add failed: %v", err) } - defer ind2.Destroy() - indView, err := NewIndex(conf) + + keys, _, err := index.SearchI8(vec, 1) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("I8 Search failed: %v", err) } - defer indView.Destroy() - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("I8 search results incorrect") } + }) +} - vec := make([]float32, dim) - for i := uint(0); i < dim; i++ { - vec[i] = float32(i) + 0.2 - err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) - if err != nil { - t.Fatalf("Failed to insert: %s", err) +func TestUnsafeOperations(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Unsafe pointer operations", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) } - } + }() - ind_length, err := ind.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) } + vec := generateTestVector(64) + ptr := unsafe.Pointer(&vec[0]) - // TODO: Add invalid save and loads? - buffer_size := uint(1 * 1024 * 1024) - buf := make([]byte, buffer_size) - err = ind.SaveBuffer(buf, buffer_size) + // Test AddUnsafe + err := index.AddUnsafe(100, ptr) if err != nil { - t.Fatalf("Failed to save the index to a buffer: %s", err) + t.Fatalf("AddUnsafe failed: %v", err) } - err = ind2.LoadBuffer(buf, buffer_size) + // Verify vector was added + size, err := index.Len() if err != nil { - t.Fatalf("Failed to load the index from a buffer: %s", err) + t.Fatalf("Failed to get size after AddUnsafe: %v", err) + } + if size != 1 { + t.Fatalf("Expected size 1 after AddUnsafe, got %d", size) } - ind2_length, err := ind2.Len() + // Test SearchUnsafe + keys, distances, err := index.SearchUnsafe(ptr, 5) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("SearchUnsafe failed: %v", err) } - if ind_length != ind2_length { - t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + + if len(keys) == 0 || keys[0] != 100 { + t.Fatalf("SearchUnsafe returned incorrect results") } - // TODO: Check some values - err = indView.ViewBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the view from a buffer: %s", err) + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) } + }) +} + +func TestConcurrentInsertions(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Multiple concurrent insertions", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + const numGoroutines = 50 + const vectorsPerGoroutine = 20 + const totalVectors = numGoroutines * vectorsPerGoroutine - indView_length, err := indView.Len() + err := index.Reserve(totalVectors) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) - } - if ind_length != indView_length { - t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + t.Fatalf("Failed to reserve capacity: %v", err) } - conf, err = MetadataBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the metadata from a buffer: %s", err) + var wg sync.WaitGroup + errorChan := make(chan error, numGoroutines) + + // Only concurrent insertions - no mixed operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(startID int) { + defer wg.Done() + + for j := 0; j < vectorsPerGoroutine; j++ { + vec := generateTestVector(64) + vec[0] = float32(startID*vectorsPerGoroutine + j) // Unique identifier + + err := index.Add(Key(startID*vectorsPerGoroutine+j), vec) + if err != nil { + errorChan <- err + return + } + } + }(i) } - if conf != ind.config { - t.Fatalf("Loaded metadata doesn't match the index metadata") + + wg.Wait() + close(errorChan) + + // Check for any errors + for err := range errorChan { + t.Fatalf("Concurrent insertion failed: %v", err) } - // TODO: Check file save/load/metadata - keys, distances, err := ind.SearchUnsafe(unsafe.Pointer(&vec[0]), dim) + // Verify final count + finalSize, err := index.Len() if err != nil { - t.Fatalf("Failed to search: %s", err) + t.Fatalf("Failed to get final size: %v", err) } - const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 - if keys[0] != 127 || math.Abs(float64(distances[0])) > tolerance { - t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + if finalSize != totalVectors { + t.Fatalf("Expected %d vectors after concurrent insertions, got %d", + totalVectors, finalSize) } }) +} - t.Run("Test Save and Load With F64", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - conf.Quantization = F64 - ind, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) - } - defer ind.Destroy() - ind2, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) - } - defer ind2.Destroy() - indView, err := NewIndex(conf) - if err != nil { - t.Fatalf("Failed to construct the index: %s", err) - } - defer indView.Destroy() +func TestConcurrentSearches(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() - err = ind.Reserve(100) - if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Run("Multiple concurrent searches", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Pre-populate with data + testVectors := populateIndex(t, index, 200) + + const numGoroutines = 30 + const searchesPerGoroutine = 50 + + var wg sync.WaitGroup + errorChan := make(chan error, numGoroutines) + + // Only concurrent searches - no mixed operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + for j := 0; j < searchesPerGoroutine; j++ { + // Use different query vectors + queryIndex := (goroutineID*searchesPerGoroutine + j) % len(testVectors) + query := testVectors[queryIndex] + + keys, distances, err := index.Search(query, 10) + if err != nil { + errorChan <- err + return + } + + // Basic validation - should find at least the exact match + if len(keys) == 0 || len(distances) == 0 { + errorChan <- fmt.Errorf("search returned empty results") + return + } + + // First result should be the exact match + if keys[0] != Key(queryIndex) || math.Abs(float64(distances[0])) > distanceTolerance { + errorChan <- fmt.Errorf("search results inconsistent: expected key %d, got %d", queryIndex, keys[0]) + return + } + } + }(i) } - vec := make([]float64, dim) - for i := uint(0); i < dim; i++ { - vec[i] = float64(i) + 0.2 - err = ind.AddUnsafe(uint64(i), unsafe.Pointer(&vec[0])) - if err != nil { - t.Fatalf("Failed to insert: %s", err) - } + wg.Wait() + close(errorChan) + + // Check for any errors + for err := range errorChan { + t.Fatalf("Concurrent search failed: %v", err) } + }) +} - ind_length, err := ind.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) +func TestExactSearch(t *testing.T) { + t.Run("Float32 exact search", func(t *testing.T) { + // Create dataset and queries + const datasetSize = 100 + const querySize = 10 + const vectorDims = 32 + + dataset := make([]float32, datasetSize*vectorDims) + queries := make([]float32, querySize*vectorDims) + + // Fill with test data + for i := 0; i < len(dataset); i++ { + dataset[i] = float32(i%100) + 0.1 } - // TODO: Add invalid save and loads? - buffer_size := uint(1 * 1024 * 1024) - buf := make([]byte, buffer_size) - err = ind.SaveBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to save the index to a buffer: %s", err) + for i := 0; i < len(queries); i++ { + queries[i] = float32(i%50) + 0.1 } - err = ind2.LoadBuffer(buf, buffer_size) + keys, distances, err := ExactSearch( + dataset, queries, + datasetSize, querySize, + vectorDims*4, vectorDims*4, // Stride in bytes for float32 + vectorDims, Cosine, + 5, 0, // maxResults=5, numThreads=0 (auto) + 8, 4, // resultKeysStride, resultDistancesStride + ) + if err != nil { - t.Fatalf("Failed to load the index from a buffer: %s", err) + t.Fatalf("ExactSearch failed: %v", err) } - ind2_length, err := ind2.Len() - if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + if len(keys) != 5 || len(distances) != 5 { + t.Fatalf("Expected 5 results from ExactSearch, got %d keys and %d distances", + len(keys), len(distances)) } - if ind_length != ind2_length { - t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + }) + + t.Run("I8 exact search", func(t *testing.T) { + const datasetSize = 50 + const querySize = 5 + const vectorDims = 16 + + dataset := make([]int8, datasetSize*vectorDims) + queries := make([]int8, querySize*vectorDims) + + // Fill with test data + for i := 0; i < len(dataset); i++ { + dataset[i] = int8((i % 100) + 1) } - // TODO: Check some values - err = indView.ViewBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the view from a buffer: %s", err) + for i := 0; i < len(queries); i++ { + queries[i] = int8((i % 50) + 1) } - indView_length, err := indView.Len() + keys, distances, err := ExactSearchI8( + dataset, queries, + datasetSize, querySize, + vectorDims, vectorDims, // Stride in bytes for int8 + vectorDims, L2sq, + 3, 0, // maxResults=3, numThreads=0 (auto) + 8, 4, // resultKeysStride, resultDistancesStride + ) + if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("ExactSearchI8 failed: %v", err) } - if ind_length != indView_length { - t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + + if len(keys) != 3 || len(distances) != 3 { + t.Fatalf("Expected 3 results from ExactSearchI8, got %d keys and %d distances", + len(keys), len(distances)) } + }) +} - conf, err = MetadataBuffer(buf, buffer_size) - if err != nil { - t.Fatalf("Failed to load the metadata from a buffer: %s", err) +func TestDistanceCalculations(t *testing.T) { + t.Run("Float32 distance calculations", func(t *testing.T) { + vec1 := []float32{1.0, 0.0, 0.0} + vec2 := []float32{0.0, 1.0, 0.0} + + // Test different metrics + metrics := []struct { + metric Metric + expected float32 + tolerance float32 + }{ + {Cosine, 1.0, 0.01}, // Perpendicular vectors + {L2sq, 2.0, 0.01}, // Squared Euclidean distance } - if conf != ind.config { - t.Fatalf("Loaded metadata doesn't match the index metadata") + + for _, test := range metrics { + distance, err := Distance(vec1, vec2, 3, test.metric) + if err != nil { + t.Fatalf("Distance calculation failed for %v: %v", test.metric, err) + } + + if math.Abs(float64(distance-test.expected)) > float64(test.tolerance) { + t.Fatalf("Distance mismatch for %v: expected %f, got %f", + test.metric, test.expected, distance) + } } + }) + + t.Run("I8 distance calculations", func(t *testing.T) { + vec1 := []int8{10, 0, 0} + vec2 := []int8{0, 10, 0} - // TODO: Check file save/load/metadata - keys, distances, err := ind.SearchUnsafe(unsafe.Pointer(&vec[0]), dim) + distance, err := DistanceI8(vec1, vec2, 3, L2sq) if err != nil { - t.Fatalf("Failed to search: %s", err) + t.Fatalf("DistanceI8 failed: %v", err) } - const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 - if keys[0] != 127 || math.Abs(float64(distances[0])) > tolerance { - t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + expected := float32(200.0) // 10^2 + 10^2 = 200 + if math.Abs(float64(distance-expected)) > 0.1 { + t.Fatalf("I8 distance mismatch: expected %f, got %f", expected, distance) } }) } From 1d6ae880e775880611f8951965eb98d6d2ba7659 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 12:42:17 +0000 Subject: [PATCH 09/10] Docs: Cleaner code & reserve semantics --- golang/README.md | 8 ++- golang/lib.go | 8 +-- golang/lib_test.go | 122 +++++++++++++++++++-------------------------- 3 files changed, 63 insertions(+), 75 deletions(-) diff --git a/golang/README.md b/golang/README.md index 9b820e782..e52c0b35f 100644 --- a/golang/README.md +++ b/golang/README.md @@ -64,8 +64,10 @@ func main() { } defer index.Destroy() - // Add to Index + // Reserve capacity and configure internal threading err = index.Reserve(uint(vectorsCount)) + _ = index.ChangeThreadsAdd(uint(runtime.NumCPU())) + _ = index.ChangeThreadsSearch(uint(runtime.NumCPU())) for i := 0; i < vectorsCount; i++ { err = index.Add(usearch.Key(i), []float32{float32(i), float32(i + 1), float32(i + 2)}) if err != nil { @@ -82,6 +84,10 @@ func main() { } ``` +Notes: +- Always call `Reserve(capacity)` before the first write. +- Prefer single-caller writes with internal parallelism via `ChangeThreadsAdd` and internal parallel searches via `ChangeThreadsSearch`, instead of calling `Add` concurrently. + 3. Get USearch: ```sh diff --git a/golang/lib.go b/golang/lib.go index 02d81d5cd..8fa39a230 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "runtime" - "sync" "unsafe" ) @@ -215,7 +214,6 @@ func DefaultConfig(dimensions uint) IndexConfig { type Index struct { handle C.usearch_index_t config IndexConfig - mu sync.Mutex } // NewIndex creates a new approximate nearest neighbor index with the specified configuration. @@ -335,7 +333,8 @@ func (index *Index) ChangeExpansionSearch(val uint) error { return nil } -// ChangeThreadsAdd sets the threads limit for add +// ChangeThreadsAdd sets the maximum number of CPU threads used by the index +// during add/build operations. This controls internal parallelism for indexing. func (index *Index) ChangeThreadsAdd(val uint) error { var errorMessage *C.char C.usearch_change_threads_add(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) @@ -345,7 +344,8 @@ func (index *Index) ChangeThreadsAdd(val uint) error { return nil } -// ChangeThreadsSearch sets the threads limit for search +// ChangeThreadsSearch sets the maximum number of CPU threads used by the index +// during search operations. This controls internal parallelism for queries. func (index *Index) ChangeThreadsSearch(val uint) error { var errorMessage *C.char C.usearch_change_threads_search(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) diff --git a/golang/lib_test.go b/golang/lib_test.go index 6d026bbf9..6409dce6c 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -30,19 +30,19 @@ func createTestIndex(t *testing.T, dimensions uint, quantization Quantization) * } func generateTestVector(dimensions uint) []float32 { - vec := make([]float32, dimensions) - for i := uint(0); i < dimensions; i++ { - vec[i] = float32(i) + 0.1 - } - return vec + vector := make([]float32, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = float32(i) + 0.1 + } + return vector } func generateTestVectorI8(dimensions uint) []int8 { - vec := make([]int8, dimensions) - for i := uint(0); i < dimensions; i++ { - vec[i] = int8((i % 127) + 1) - } - return vec + vector := make([]int8, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = int8((i % 127) + 1) + } + return vector } func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { @@ -57,16 +57,16 @@ func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { t.Fatalf("Failed to get dimensions: %v", err) } - for i := 0; i < vectorCount; i++ { - vec := generateTestVector(dimensions) - vec[0] = float32(i) // Make each vector unique - vectors[i] = vec + for i := 0; i < vectorCount; i++ { + vector := generateTestVector(dimensions) + vector[0] = float32(i) // Make each vector unique + vectors[i] = vector - err = index.Add(Key(i), vec) - if err != nil { - t.Fatalf("Failed to add vector %d: %v", i, err) - } - } + err = index.Add(Key(i), vector) + if err != nil { + t.Fatalf("Failed to add vector %d: %v", i, err) + } + } return vectors } @@ -189,12 +189,12 @@ func TestBasicOperations(t *testing.T) { t.Fatalf("Failed to reserve capacity: %v", err) } - // Add a vector - vec := generateTestVector(defaultTestDimensions) - vec[0] = 42.0 - vec[1] = 24.0 + // Add a vector + vector := generateTestVector(defaultTestDimensions) + vector[0] = 42.0 + vector[1] = 24.0 - err := index.Add(100, vec) + err := index.Add(100, vector) if err != nil { t.Fatalf("Failed to add vector: %v", err) } @@ -524,13 +524,13 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vec := generateTestVector(32) - err := index.Add(1, vec) + vector := generateTestVector(32) + err := index.Add(1, vector) if err != nil { t.Fatalf("F32 Add failed: %v", err) } - keys, _, err := index.Search(vec, 1) + keys, _, err := index.Search(vector, 1) if err != nil { t.Fatalf("F32 Search failed: %v", err) } @@ -551,17 +551,17 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vec := make([]float64, 32) - for i := range vec { - vec[i] = float64(i) + 0.5 - } + vector := make([]float64, 32) + for i := range vector { + vector[i] = float64(i) + 0.5 + } - err := index.AddUnsafe(1, unsafe.Pointer(&vec[0])) + err := index.AddUnsafe(1, unsafe.Pointer(&vector[0])) if err != nil { t.Fatalf("F64 AddUnsafe failed: %v", err) } - keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vec[0]), 1) + keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vector[0]), 1) if err != nil { t.Fatalf("F64 SearchUnsafe failed: %v", err) } @@ -582,13 +582,13 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vec := generateTestVectorI8(32) - err := index.AddI8(1, vec) + vector := generateTestVectorI8(32) + err := index.AddI8(1, vector) if err != nil { t.Fatalf("I8 Add failed: %v", err) } - keys, _, err := index.SearchI8(vec, 1) + keys, _, err := index.SearchI8(vector, 1) if err != nil { t.Fatalf("I8 Search failed: %v", err) } @@ -614,8 +614,8 @@ func TestUnsafeOperations(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vec := generateTestVector(64) - ptr := unsafe.Pointer(&vec[0]) + vector := generateTestVector(64) + ptr := unsafe.Pointer(&vector[0]) // Test AddUnsafe err := index.AddUnsafe(100, ptr) @@ -652,7 +652,7 @@ func TestConcurrentInsertions(t *testing.T) { runtime.LockOSThread() defer runtime.UnlockOSThread() - t.Run("Multiple concurrent insertions", func(t *testing.T) { + t.Run("Parallelized insertions via internal threads", func(t *testing.T) { index := createTestIndex(t, 64, F32) defer func() { if err := index.Destroy(); err != nil { @@ -660,44 +660,23 @@ func TestConcurrentInsertions(t *testing.T) { } }() - const numGoroutines = 50 - const vectorsPerGoroutine = 20 - const totalVectors = numGoroutines * vectorsPerGoroutine + const totalVectors = 1000 err := index.Reserve(totalVectors) if err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - var wg sync.WaitGroup - errorChan := make(chan error, numGoroutines) - - // Only concurrent insertions - no mixed operations - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func(startID int) { - defer wg.Done() - - for j := 0; j < vectorsPerGoroutine; j++ { - vec := generateTestVector(64) - vec[0] = float32(startID*vectorsPerGoroutine + j) // Unique identifier - - err := index.Add(Key(startID*vectorsPerGoroutine+j), vec) - if err != nil { - errorChan <- err - return - } - } - }(i) - } - - wg.Wait() - close(errorChan) + // Let the library parallelize inserts internally + _ = index.ChangeThreadsAdd(uint(runtime.NumCPU())) - // Check for any errors - for err := range errorChan { - t.Fatalf("Concurrent insertion failed: %v", err) - } + for i := 0; i < totalVectors; i++ { + vector := generateTestVector(64) + vector[0] = float32(i) + if err := index.Add(Key(i), vector); err != nil { + t.Fatalf("Insertion failed at %d: %v", i, err) + } + } // Verify final count finalSize, err := index.Len() @@ -727,6 +706,9 @@ func TestConcurrentSearches(t *testing.T) { // Pre-populate with data testVectors := populateIndex(t, index, 200) + // Let the library parallelize search internally as well + _ = index.ChangeThreadsSearch(uint(runtime.NumCPU())) + const numGoroutines = 30 const searchesPerGoroutine = 50 From 55b5f76552276d6f0e5fd28e0746b3ffe694d42c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:11:29 +0000 Subject: [PATCH 10/10] Fix: Allow reserving threads in Java --- java/cloud/unum/usearch/Index.java | 31 ++++++- .../unum/usearch/cloud_unum_usearch_Index.cpp | 18 ++++- .../unum/usearch/cloud_unum_usearch_Index.h | 4 +- java/test/IndexTest.java | 80 ++++--------------- 4 files changed, 63 insertions(+), 70 deletions(-) diff --git a/java/cloud/unum/usearch/Index.java b/java/cloud/unum/usearch/Index.java index f942aaf56..c15958437 100644 --- a/java/cloud/unum/usearch/Index.java +++ b/java/cloud/unum/usearch/Index.java @@ -274,7 +274,34 @@ public void reserve(long capacity) { if (c_ptr == 0) { throw new IllegalStateException("Index already closed"); } - c_reserve(c_ptr, capacity); + // Pass zeros to use all available contexts on the device + c_reserve(c_ptr, capacity, 0, 0); + } + + /** + * Reserves memory and configures thread contexts. + * Use this to explicitly control concurrent add/search capacities. + * + * @param capacity desired total capacity + * @param threadsAdd maximum concurrent add contexts + * @param threadsSearch maximum concurrent search contexts + */ + public void reserve(long capacity, long threadsAdd, long threadsSearch) { + if (c_ptr == 0) { + throw new IllegalStateException("Index already closed"); + } + c_reserve(c_ptr, capacity, threadsAdd, threadsSearch); + } + + /** + * Reserves memory and sets the same number of contexts + * for both add and search operations. + * + * @param capacity desired total capacity + * @param threads number of contexts for both add and search + */ + public void reserve(long capacity, long threads) { + reserve(capacity, threads, threads); } /** @@ -1021,7 +1048,7 @@ private static native long c_create( private static native long c_capacity(long ptr); - private static native void c_reserve(long ptr, long capacity); + private static native void c_reserve(long ptr, long capacity, long threadsAdd, long threadsSearch); private static native void c_save(long ptr, String path); diff --git a/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp b/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp index bd6a8d42e..685bdc8d8 100644 --- a/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp +++ b/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp @@ -143,8 +143,22 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1capacity(JNIEnv*, jclas return reinterpret_cast(c_ptr)->capacity(); } -JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity) { - if (!reinterpret_cast(c_ptr)->try_reserve(static_cast(capacity))) { +JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity, + jlong threads_add, jlong threads_search) { + std::size_t t_add = static_cast(threads_add); + std::size_t t_search = static_cast(threads_search); + if (t_add == 0 || t_search == 0) { + std::size_t hc = std::thread::hardware_concurrency(); + if (hc == 0) + hc = 1; // fallback to 1 if the runtime can't report + if (t_add == 0) + t_add = hc; + if (t_search == 0) + t_search = hc; + } + index_limits_t limits(static_cast(capacity), t_add); + limits.threads_search = t_search; + if (!reinterpret_cast(c_ptr)->try_reserve(limits)) { jclass jc = (*env).FindClass("java/lang/Error"); if (jc) (*env).ThrowNew(jc, "Failed to grow vector index!"); diff --git a/java/cloud/unum/usearch/cloud_unum_usearch_Index.h b/java/cloud/unum/usearch/cloud_unum_usearch_Index.h index b2494c29f..16cbc2110 100644 --- a/java/cloud/unum/usearch/cloud_unum_usearch_Index.h +++ b/java/cloud/unum/usearch/cloud_unum_usearch_Index.h @@ -66,10 +66,10 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1capacity /* * Class: cloud_unum_usearch_Index * Method: c_reserve - * Signature: (JJ)V + * Signature: (JJJJ)V */ JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve - (JNIEnv *, jclass, jlong, jlong); + (JNIEnv *, jclass, jlong, jlong, jlong, jlong); /* * Class: cloud_unum_usearch_Index diff --git a/java/test/IndexTest.java b/java/test/IndexTest.java index 455c2cd2f..817ed6334 100644 --- a/java/test/IndexTest.java +++ b/java/test/IndexTest.java @@ -267,13 +267,14 @@ public void testGetIntoBufferMethods() { @Test public void testConcurrentAdd() throws Exception { try (Index index = new Index.Config().metric("cos").dimensions(4).build()) { - index.reserve(1000); + final int threadsCount = 10; + index.reserve(1000, threadsCount); - ExecutorService executor = Executors.newFixedThreadPool(10); + ExecutorService executor = Executors.newFixedThreadPool(threadsCount); @SuppressWarnings("unchecked") - CompletableFuture[] futures = new CompletableFuture[10]; + CompletableFuture[] futures = new CompletableFuture[threadsCount]; - for (int t = 0; t < 10; t++) { + for (int t = 0; t < threadsCount; t++) { final int threadId = t; futures[t] = CompletableFuture.runAsync( @@ -290,25 +291,26 @@ public void testConcurrentAdd() throws Exception { CompletableFuture.allOf(futures).get(10, TimeUnit.SECONDS); executor.shutdown(); - assertEquals(500, index.size()); + assertEquals(50L * threadsCount, index.size()); } } @Test public void testConcurrentSearch() throws Exception { try (Index index = new Index.Config().metric("cos").dimensions(4).build()) { - index.reserve(100); + final int threadsCount = 5; + index.reserve(100, threadsCount); // Add some vectors first for (int i = 0; i < 100; i++) { index.add(i, randomVector(4)); } - ExecutorService executor = Executors.newFixedThreadPool(5); + ExecutorService executor = Executors.newFixedThreadPool(threadsCount); @SuppressWarnings("unchecked") - CompletableFuture[] futures = new CompletableFuture[5]; + CompletableFuture[] futures = new CompletableFuture[threadsCount]; - for (int t = 0; t < 5; t++) { + for (int t = 0; t < threadsCount; t++) { futures[t] = CompletableFuture.supplyAsync( () -> { @@ -328,56 +330,6 @@ public void testConcurrentSearch() throws Exception { } } - @Test - public void testMixedConcurrency() throws Exception { - try (Index index = new Index.Config().metric("cos").dimensions(3).build()) { - index.reserve(200); - - ExecutorService executor = Executors.newFixedThreadPool(8); - @SuppressWarnings("unchecked") - CompletableFuture[] addFutures = new CompletableFuture[4]; - @SuppressWarnings("unchecked") - CompletableFuture[] searchFutures = new CompletableFuture[4]; - - // Add operations - for (int t = 0; t < 4; t++) { - final int threadId = t; - addFutures[t] - = CompletableFuture.runAsync( - () -> { - for (int i = 0; i < 30; i++) { - long key = threadId * 30L + i; - index.add(key, randomVector(3)); - } - }, - executor); - } - - // Wait for some adds to complete, then start searches - Thread.sleep(100); - - // Search operations - for (int t = 0; t < 4; t++) { - searchFutures[t] - = CompletableFuture.runAsync( - () -> { - for (int i = 0; i < 10; i++) { - float[] queryVector = randomVector(3); - long[] results = index.search(queryVector, 5); - assertTrue(results.length >= 0); - } - }, - executor); - } - - CompletableFuture.allOf(addFutures).get(15, TimeUnit.SECONDS); - CompletableFuture.allOf(searchFutures).get(15, TimeUnit.SECONDS); - executor.shutdown(); - - assertEquals(120, index.size()); - } - } - @Test public void testBatchAdd() { try (Index index = new Index.Config().metric("cos").dimensions(2).build()) { @@ -751,30 +703,30 @@ public void testPlatformCapabilities() { String[] available = Index.hardwareAccelerationAvailable(); assertNotEquals("Available capabilities should not be null", null, available); assertTrue("Platform should have at least serial capability", available.length > 0); - + // Test compile-time capabilities String[] compiled = Index.hardwareAccelerationCompiled(); assertNotEquals("Compiled capabilities should not be null", null, compiled); assertTrue("Should have at least serial compiled", compiled.length > 0); - + // Should always include serial as baseline in both boolean hasAvailableSerial = false; boolean hasCompiledSerial = false; - + for (String cap : available) { if ("serial".equals(cap)) { hasAvailableSerial = true; break; } } - + for (String cap : compiled) { if ("serial".equals(cap)) { hasCompiledSerial = true; break; } } - + assertTrue("Platform should always support serial capability", hasAvailableSerial); assertTrue("Serial should always be compiled", hasCompiledSerial);