diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index 86e04e8a2..167ce3da4 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -244,7 +244,7 @@ jobs: - name: Build C/C++ run: | brew update - brew install cmake + brew reinstall cmake cmake -B build_artifacts -D CMAKE_BUILD_TYPE=RelWithDebInfo -D USEARCH_BUILD_TEST_CPP=1 -D USEARCH_BUILD_TEST_C=1 -D USEARCH_BUILD_LIB_C=1 -D USEARCH_BUILD_SQLITE=1 cmake --build build_artifacts --config RelWithDebInfo - name: Test C++ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index abbd85873..c4cfc937f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -397,6 +397,15 @@ cp c/usearch.h golang/ # to make the header available to Go cd golang && LD_LIBRARY_PATH=. go test -v ; cd .. ``` +For static checks: + +```sh +cd golang +go vet ./... +staticcheck ./... # if installed +golangci-lint run # if installed +``` + ## Java USearch provides Java bindings as a fat-JAR published with prebuilt JNI libraries via GitHub releases. Installation via Maven Central is deprecated; prefer downloading the fat-JAR from the latest GitHub release. The compilation settings are controlled by `build.gradle` and are independent from CMake used for C/C++ builds. diff --git a/golang/README.md b/golang/README.md index 9b820e782..e52c0b35f 100644 --- a/golang/README.md +++ b/golang/README.md @@ -64,8 +64,10 @@ func main() { } defer index.Destroy() - // Add to Index + // Reserve capacity and configure internal threading err = index.Reserve(uint(vectorsCount)) + _ = index.ChangeThreadsAdd(uint(runtime.NumCPU())) + _ = index.ChangeThreadsSearch(uint(runtime.NumCPU())) for i := 0; i < vectorsCount; i++ { err = index.Add(usearch.Key(i), []float32{float32(i), float32(i + 1), float32(i + 2)}) if err != nil { @@ -82,6 +84,10 @@ func main() { } ``` +Notes: +- Always call `Reserve(capacity)` before the first write. +- Prefer single-caller writes with internal parallelism via `ChangeThreadsAdd` and internal parallel searches via `ChangeThreadsSearch`, instead of calling `Add` concurrently. + 3. Get USearch: ```sh diff --git a/golang/lib.go b/golang/lib.go index 778247ac6..8fa39a230 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -1,7 +1,27 @@ +// Package usearch provides Go bindings for the USearch library, a high-performance +// approximate nearest neighbor search implementation. +// +// Basic usage: +// +// conf := usearch.DefaultConfig(128) // 128-dimensional vectors +// index, err := usearch.NewIndex(conf) +// if err != nil { +// log.Fatal(err) +// } +// defer index.Destroy() +// +// // Add vectors +// vec := make([]float32, 128) +// err = index.Add(42, vec) +// +// // Search +// keys, distances, err := index.Search(vec, 10) package usearch import ( "errors" + "fmt" + "runtime" "unsafe" ) @@ -12,22 +32,37 @@ import ( */ import "C" -// Key represents the type for keys used in the USearch index. +// Key represents a unique identifier for vectors in the index. +// Keys must be unique within an index; adding a vector with an existing +// key will update the associated vector. type Key = uint64 -// Metric represents the type for different metrics used in distance calculations. +// Metric defines the distance calculation method used for comparing vectors. +// Different metrics are suitable for different use cases: +// - Cosine: Normalized dot product, ideal for text embeddings +// - L2sq: Squared Euclidean distance, for spatial data +// - InnerProduct: Dot product, for recommendation systems type Metric uint8 // Different metric kinds supported by the USearch library. const ( + // InnerProduct computes the dot product between vectors InnerProduct Metric = iota + // Cosine computes cosine similarity (normalized dot product) Cosine + // L2sq computes squared Euclidean distance L2sq + // Haversine computes great-circle distance for geographic coordinates Haversine + // Divergence computes Jensen-Shannon divergence Divergence + // Pearson computes Pearson correlation coefficient Pearson + // Hamming computes Hamming distance for binary data Hamming + // Tanimoto computes Tanimoto/Jaccard coefficient Tanimoto + // Sorensen computes Sørensen-Dice coefficient Sorensen ) @@ -53,7 +88,7 @@ func (m Metric) String() string { case Sorensen: return "sorensen" default: - panic("Unknown metric") + panic("unknown metric") } } func (m Metric) CValue() C.usearch_metric_kind_t { @@ -78,16 +113,23 @@ func (m Metric) CValue() C.usearch_metric_kind_t { return C.usearch_metric_l2sq_k } -// Quantization represents the type for different scalar kinds used in quantization. +// Quantization represents the scalar type used for storing vectors in the index. +// Different quantization types offer different trade-offs between memory usage and precision. type Quantization uint8 // Different quantization kinds supported by the USearch library. const ( + // F32 uses 32-bit floating point (standard precision) F32 Quantization = iota + // BF16 uses brain floating-point format (16-bit) BF16 + // F16 uses half-precision floating point (16-bit) F16 + // F64 uses 64-bit double precision floating point F64 + // I8 uses 8-bit signed integers (quantized) I8 + // B1 uses binary representation (1-bit per dimension) B1 ) @@ -107,26 +149,55 @@ func (a Quantization) String() string { case B1: return "B1" default: - panic("Unknown quantization") + panic("unknown quantization") + } +} + +func (a Quantization) CValue() C.usearch_scalar_kind_t { + switch a { + case F16: + return C.usearch_scalar_f16_k + case F32: + return C.usearch_scalar_f32_k + case F64: + return C.usearch_scalar_f64_k + case I8: + return C.usearch_scalar_i8_k + case B1: + return C.usearch_scalar_b1_k + case BF16: + return C.usearch_scalar_bf16_k + default: + return C.usearch_scalar_unknown_k } } // IndexConfig represents the configuration options for initializing a USearch index. +// +// Zero values for optional parameters (Connectivity, ExpansionAdd, ExpansionSearch) +// will be replaced with optimal defaults by the C library. type IndexConfig struct { Quantization Quantization // The scalar kind used for quantization of vector data during indexing. Metric Metric // The metric kind used for distance calculation between vectors. Dimensions uint // The number of dimensions in the vectors to be indexed. - Connectivity uint // The optional connectivity parameter that limits connections-per-node in the graph. - ExpansionAdd uint // The optional expansion factor used for index construction when adding vectors. - ExpansionSearch uint // The optional expansion factor used for index construction during search operations. + Connectivity uint // The optional connectivity parameter that limits connections-per-node in the graph (0 for default). + ExpansionAdd uint // The optional expansion factor used for index construction when adding vectors (0 for default). + ExpansionSearch uint // The optional expansion factor used for index construction during search operations (0 for default). Multi bool // Indicates whether multiple vectors can map to the same key. } // DefaultConfig returns an IndexConfig with default values for the specified number of dimensions. +// Uses Cosine metric and F32 quantization by default. +// +// Example: +// +// config := usearch.DefaultConfig(128) // 128-dimensional vectors +// index, err := usearch.NewIndex(config) func DefaultConfig(dimensions uint) IndexConfig { c := IndexConfig{} c.Dimensions = dimensions c.Metric = Cosine + c.Quantization = F32 // Zeros will be replaced by the underlying C implementation c.Connectivity = 0 c.ExpansionAdd = 0 @@ -135,14 +206,33 @@ func DefaultConfig(dimensions uint) IndexConfig { return c } -// Index represents a USearch index. +// Index represents a USearch approximate nearest neighbor index. +// It implements io.Closer for idiomatic resource cleanup. +// +// The index must be properly initialized with NewIndex() and destroyed +// with Destroy() or Close() when no longer needed to free resources. type Index struct { - opaque_handle *C.void - config IndexConfig + handle C.usearch_index_t + config IndexConfig } -// NewIndex initializes a new instance of the index with the specified configuration. +// NewIndex creates a new approximate nearest neighbor index with the specified configuration. +// +// The index must be destroyed with Destroy() when no longer needed. +// +// Example: +// +// config := usearch.DefaultConfig(128) +// config.Metric = usearch.L2sq +// index, err := usearch.NewIndex(config) +// if err != nil { +// log.Fatal(err) +// } +// defer index.Destroy() func NewIndex(conf IndexConfig) (index *Index, err error) { + if conf.Dimensions == 0 { + return nil, errors.New("dimensions must be greater than 0") + } index = &Index{config: conf} conf = index.config @@ -161,20 +251,7 @@ func NewIndex(conf IndexConfig) (index *Index, err error) { options.metric_kind = conf.Metric.CValue() // Map the quantization method - switch conf.Quantization { - case F16: - options.quantization = C.usearch_scalar_f16_k - case F32: - options.quantization = C.usearch_scalar_f32_k - case F64: - options.quantization = C.usearch_scalar_f64_k - case I8: - options.quantization = C.usearch_scalar_i8_k - case B1: - options.quantization = C.usearch_scalar_b1_k - default: - options.quantization = C.usearch_scalar_unknown_k - } + options.quantization = conf.Quantization.CValue() var errorMessage *C.char ptr := C.usearch_init(&options, (*C.usearch_error_t)(&errorMessage)) @@ -182,14 +259,14 @@ func NewIndex(conf IndexConfig) (index *Index, err error) { return nil, errors.New(C.GoString(errorMessage)) } - index.opaque_handle = (*C.void)(unsafe.Pointer(ptr)) + index.handle = ptr return index, nil } // Len returns the number of vectors in the index. func (index *Index) Len() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_size((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_size(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -199,7 +276,7 @@ func (index *Index) Len() (len uint, err error) { // SerializedLength reports the expected file size after serialization. func (index *Index) SerializedLength() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_serialized_length((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_serialized_length(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -209,7 +286,7 @@ func (index *Index) SerializedLength() (len uint, err error) { // MemoryUsage reports the memory usage of the index func (index *Index) MemoryUsage() (len uint, err error) { var errorMessage *C.char - len = uint(C.usearch_memory_usage((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + len = uint(C.usearch_memory_usage(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -219,7 +296,7 @@ func (index *Index) MemoryUsage() (len uint, err error) { // ExpansionAdd returns the expansion value used during index creation func (index *Index) ExpansionAdd() (val uint, err error) { var errorMessage *C.char - val = uint(C.usearch_expansion_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + val = uint(C.usearch_expansion_add(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -229,7 +306,7 @@ func (index *Index) ExpansionAdd() (val uint, err error) { // ExpansionSearch returns the expansion value used during search func (index *Index) ExpansionSearch() (val uint, err error) { var errorMessage *C.char - val = uint(C.usearch_expansion_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + val = uint(C.usearch_expansion_search(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -239,7 +316,7 @@ func (index *Index) ExpansionSearch() (val uint, err error) { // ChangeExpansionAdd sets the expansion value used during index creation func (index *Index) ChangeExpansionAdd(val uint) error { var errorMessage *C.char - C.usearch_change_expansion_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_expansion_add(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -249,27 +326,29 @@ func (index *Index) ChangeExpansionAdd(val uint) error { // ChangeExpansionSearch sets the expansion value used during search func (index *Index) ChangeExpansionSearch(val uint) error { var errorMessage *C.char - C.usearch_change_expansion_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_expansion_search(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// ChangeThreadsAdd sets the threads limit for add +// ChangeThreadsAdd sets the maximum number of CPU threads used by the index +// during add/build operations. This controls internal parallelism for indexing. func (index *Index) ChangeThreadsAdd(val uint) error { var errorMessage *C.char - C.usearch_change_threads_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_threads_add(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// ChangeThreadsSearch sets the threads limit for search +// ChangeThreadsSearch sets the maximum number of CPU threads used by the index +// during search operations. This controls internal parallelism for queries. func (index *Index) ChangeThreadsSearch(val uint) error { var errorMessage *C.char - C.usearch_change_threads_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage)) + C.usearch_change_threads_search(index.handle, C.size_t(val), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -279,7 +358,7 @@ func (index *Index) ChangeThreadsSearch(val uint) error { // Connectivity returns the connectivity parameter of the index. func (index *Index) Connectivity() (con uint, err error) { var errorMessage *C.char - con = uint(C.usearch_connectivity((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + con = uint(C.usearch_connectivity(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -289,7 +368,7 @@ func (index *Index) Connectivity() (con uint, err error) { // Dimensions returns the number of dimensions of the vectors in the index. func (index *Index) Dimensions() (dim uint, err error) { var errorMessage *C.char - dim = uint(C.usearch_dimensions((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + dim = uint(C.usearch_dimensions(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -299,7 +378,7 @@ func (index *Index) Dimensions() (dim uint, err error) { // Capacity returns the capacity (maximum number of vectors) of the index. func (index *Index) Capacity() (cap uint, err error) { var errorMessage *C.char - cap = uint(C.usearch_capacity((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage))) + cap = uint(C.usearch_capacity(index.handle, (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { err = errors.New(C.GoString(errorMessage)) } @@ -310,7 +389,7 @@ func (index *Index) Capacity() (cap uint, err error) { func (index *Index) HardwareAcceleration() (string, error) { var str *C.char var errorMessage *C.char - str = C.usearch_hardware_acceleration((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage)) + str = C.usearch_hardware_acceleration(index.handle, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return C.GoString(nil), errors.New(C.GoString(errorMessage)) } @@ -319,42 +398,91 @@ func (index *Index) HardwareAcceleration() (string, error) { // Destroy frees the resources associated with the index. func (index *Index) Destroy() error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } var errorMessage *C.char - C.usearch_free((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (*C.usearch_error_t)(&errorMessage)) + C.usearch_free(index.handle, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } - index.opaque_handle = nil + index.handle = nil index.config = IndexConfig{} return nil } +// Close implements io.Closer interface and calls Destroy() to free resources. +// This provides idiomatic Go resource cleanup that can be used with defer statements. +func (index *Index) Close() error { + return index.Destroy() +} + // Reserve reserves memory for a specified number of incoming vectors. func (index *Index) Reserve(capacity uint) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } var errorMessage *C.char - C.usearch_reserve((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.size_t)(capacity), (*C.usearch_error_t)(&errorMessage)) + C.usearch_reserve(index.handle, (C.size_t)(capacity), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Add adds a vector with a specified key to the index. +// Add inserts or updates a vector in the index with the specified key. +// The vector must have exactly Dimensions() elements. +// If a vector with this key already exists, it will be replaced. +// +// Returns an error if: +// - The index is not initialized +// - The vector is empty or has wrong dimensions +// - The underlying C library reports an error func (index *Index) Add(key Key, vec []float32) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") + } + + if len(vec) == 0 { + return errors.New("vector cannot be empty") + } + if uint(len(vec)) != index.config.Dimensions { + return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vec), index.config.Dimensions) + } + + var errorMessage *C.char + C.usearch_add(index.handle, (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec) + if errorMessage != nil { + return errors.New(C.GoString(errorMessage)) + } + return nil +} + +// AddUnsafe adds a vector using a raw pointer, bypassing Go's type safety. +// +// SAFETY REQUIREMENTS: +// - vec must not be nil +// - Memory at vec must contain exactly Dimensions() scalars +// - Scalar type must match index.config.Quantization +// - Memory must remain valid for the duration of the call +// - Caller is responsible for ensuring correct data layout +// +// Use Add() or AddI8() instead unless you need maximum performance +// and understand the safety implications. +func (index *Index) AddUnsafe(key Key, vec unsafe.Pointer) error { + if index.handle == nil { + panic("index is uninitialized") + } + + if vec == nil { + return errors.New("vector pointer cannot be nil") } var errorMessage *C.char - C.usearch_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)) + C.usearch_add(index.handle, (C.usearch_key_t)(key), vec, index.config.Quantization.CValue(), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -363,12 +491,12 @@ func (index *Index) Add(key Key, vec []float32) error { // Remove removes the vector associated with the given key from the index. func (index *Index) Remove(key Key) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } var errorMessage *C.char - C.usearch_remove((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage)) + C.usearch_remove(index.handle, (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -377,12 +505,12 @@ func (index *Index) Remove(key Key) error { // Contains checks if the index contains a vector with a specific key. func (index *Index) Contains(key Key) (found bool, err error) { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } var errorMessage *C.char - found = bool(C.usearch_contains((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage))) + found = bool(C.usearch_contains(index.handle, (C.usearch_key_t)(key), (*C.usearch_error_t)(&errorMessage))) if errorMessage != nil { return found, errors.New(C.GoString(errorMessage)) } @@ -390,14 +518,20 @@ func (index *Index) Contains(key Key) (found bool, err error) { } // Get retrieves the vectors associated with the given key from the index. -func (index *Index) Get(key Key, count uint) (vectors []float32, err error) { - if index.opaque_handle == nil { - panic("Index is uninitialized") +// Returns nil if the key is not found. +func (index *Index) Get(key Key, maxCount uint) (vectors []float32, err error) { + if index.handle == nil { + panic("index is uninitialized") } - vectors = make([]float32, index.config.Dimensions*count) + if maxCount == 0 { + return nil, nil + } + + vectors = make([]float32, index.config.Dimensions*maxCount) var errorMessage *C.char - found := uint(C.usearch_get((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (C.size_t)(count), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) + found := uint(C.usearch_get(index.handle, (C.usearch_key_t)(key), (C.size_t)(maxCount), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(vectors) if errorMessage != nil { return nil, errors.New(C.GoString(errorMessage)) } @@ -410,118 +544,418 @@ func (index *Index) Get(key Key, count uint) (vectors []float32, err error) { // Rename the vector at key from to key to func (index *Index) Rename(from Key, to Key) error { var errorMessage *C.char - C.usearch_rename((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.usearch_key_t(from), C.usearch_key_t(to), (*C.usearch_error_t)(&errorMessage)) + C.usearch_rename(index.handle, C.usearch_key_t(from), C.usearch_key_t(to), (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Distance computes the distance between two vectors -func Distance(vec1 []float32, vec2 []float32, dims uint, metric Metric) (float32, error) { +// Distance computes the distance between two float32 vectors using the specified metric. +// Both vectors must have exactly 'dims' elements. +func Distance(vec1 []float32, vec2 []float32, vectorDimensions uint, metric Metric) (float32, error) { + if len(vec1) == 0 || len(vec2) == 0 { + return 0, errors.New("vectors cannot be empty") + } + if uint(len(vec1)) < vectorDimensions || uint(len(vec2)) < vectorDimensions { + return 0, fmt.Errorf("vectors too short for specified dimensions: need %d elements", vectorDimensions) + } var errorMessage *C.char - dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec1) + runtime.KeepAlive(vec2) if errorMessage != nil { return 0, errors.New(C.GoString(errorMessage)) } return float32(dist), nil } -// Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector. +// DistanceUnsafe computes the distance between two vectors using unsafe pointers. +// +// SAFETY REQUIREMENTS: +// - vec1 and vec2 must not be nil +// - Memory at both pointers must contain exactly 'dims' scalars +// - Scalar type must match the specified quantization +// - Memory must remain valid for the duration of the call +func DistanceUnsafe(vec1 unsafe.Pointer, vec2 unsafe.Pointer, vectorDimensions uint, metric Metric, quantization Quantization) (float32, error) { + if vec1 == nil || vec2 == nil { + return 0, errors.New("vector pointers cannot be nil") + } + if vectorDimensions == 0 { + return 0, errors.New("dimensions must be greater than zero") + } + + var errorMessage *C.char + dist := C.usearch_distance(vec1, vec2, quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + if errorMessage != nil { + return 0, errors.New(C.GoString(errorMessage)) + } + return float32(dist), nil +} + +// Search finds the k nearest neighbors to the query vector. +// +// Parameters: +// - query: Must have exactly Dimensions() elements +// - limit: Maximum number of results to return +// +// Returns: +// - keys: IDs of the nearest vectors (up to limit) +// - distances: Distance to each result (same length as keys) +// - err: Error if query is invalid or search fails +// +// The actual number of results may be less than limit if the index +// contains fewer vectors. func (index *Index) Search(query []float32, limit uint) (keys []Key, distances []float32, err error) { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") + } + + if len(query) == 0 { + return nil, nil, errors.New("query vector cannot be empty") + } + if uint(len(query)) != index.config.Dimensions { + return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions) } - if len(query) != int(index.config.Dimensions) { - return nil, nil, errors.New("Number of dimensions doesn't match!") + if limit == 0 { + return []Key{}, []float32{}, nil } keys = make([]Key, limit) distances = make([]float32, limit) var errorMessage *C.char - count := uint(C.usearch_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + resultCount := uint(C.usearch_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(query) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:resultCount] + distances = distances[:resultCount] return keys, distances, nil } -// ExactSearch is a multithreaded exact nearest neighbors search -func ExactSearch(dataset []float32, queries []float32, dataset_size uint, queries_size uint, - dataset_stride uint, queries_stride uint, dims uint, metric Metric, - count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) { - if (len(dataset) % int(dims)) != 0 { - return nil, nil, errors.New("Dataset length must be a multiple of the dimensions") +// SearchUnsafe performs k-Approximate Nearest Neighbors Search using an unsafe pointer. +// +// SAFETY REQUIREMENTS: +// - query must not be nil +// - Memory at query must contain exactly Dimensions() scalars +// - Scalar type must match index.config.Quantization +// - Memory must remain valid for the duration of the call +// +// Use Search() or SearchI8() instead unless you need maximum performance +// and understand the safety implications. +func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, distances []float32, err error) { + if index.handle == nil { + panic("index is uninitialized") } - if (len(queries) % int(dims)) != 0 { - return nil, nil, errors.New("Queries length must be a multiple of the dimensions") + + if query == nil { + return nil, nil, errors.New("query pointer cannot be nil") + } + if limit == 0 { + return []Key{}, []float32{}, nil } - keys = make([]Key, count) - distances = make([]float32, count) + keys = make([]Key, limit) + distances = make([]float32, limit) var errorMessage *C.char - C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride), - C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads), - (*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage)) + resultCount := uint(C.usearch_search(index.handle, query, index.config.Quantization.CValue(), (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:count] - distances = distances[:count] + keys = keys[:resultCount] + distances = distances[:resultCount] return keys, distances, nil } -// Save saves the index to a specified buffer. +// ExactSearch performs multithreaded exact nearest neighbors search. +// Unlike the index-based search, this computes distances to all vectors in the dataset. +// +// Parameters: +// - dataset: Flattened array of vectors (datasetSize x vectorDimensions) +// - queries: Flattened array of query vectors (queryCount x vectorDimensions) +// - datasetSize, queryCount: Number of vectors in dataset and queries +// - datasetStride, queryStride: Memory stride in bytes between consecutive vectors (use vectorDimensions * sizeof(float32) for contiguous data) +// - vectorDimensions: Number of dimensions per vector +// - metric: Distance metric to use +// - maxResults: Maximum results per query +// - numThreads: Number of threads to use (0 for auto-detection) +func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if len(dataset) == 0 || len(queries) == 0 { + return nil, nil, errors.New("dataset and queries cannot be empty") + } + if vectorDimensions == 0 { + return nil, nil, errors.New("dimensions must be greater than zero") + } + if (len(dataset) % int(vectorDimensions)) != 0 { + return nil, nil, errors.New("dataset length must be a multiple of the dimensions") + } + if (len(queries) % int(vectorDimensions)) != 0 { + return nil, nil, errors.New("queries length must be a multiple of the dimensions") + } + + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) + var errorMessage *C.char + C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), + C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(dataset) + runtime.KeepAlive(queries) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + + keys = keys[:maxResults] + distances = distances[:maxResults] + return keys, distances, nil +} + +// ExactSearchUnsafe performs multithreaded exact nearest neighbors search using unsafe pointers. +// +// SAFETY REQUIREMENTS: +// - dataset and queries must not be nil +// - Memory must contain contiguous vectors of the specified quantization type +// - dataset must contain datasetSize vectors of vectorDimensions elements each +// - queries must contain queryCount vectors of vectorDimensions elements each +// - Memory must remain valid for the duration of the call +// +// Stride parameters specify memory offset in bytes between consecutive vectors. +// For contiguous data, use vectorDimensions * sizeof(element_type). +func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, quantization Quantization, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if dataset == nil || queries == nil { + return nil, nil, errors.New("dataset and queries pointers cannot be nil") + } + if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 { + return nil, nil, errors.New("dimensions and sizes must be greater than zero") + } + + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) + var errorMessage *C.char + C.usearch_exact_search(dataset, C.size_t(datasetSize), C.size_t(datasetStride), queries, C.size_t(queryCount), C.size_t(queryStride), + quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + + keys = keys[:maxResults] + distances = distances[:maxResults] + return keys, distances, nil +} + +// Convenience I8 helpers + +// AddI8 adds an int8 vector to the index. +// The vector must have exactly Dimensions() elements. +// +// This is a convenience method for indexes using I8 quantization. +func (index *Index) AddI8(key Key, vec []int8) error { + if index.handle == nil { + panic("index is uninitialized") + } + if len(vec) == 0 { + return errors.New("vector cannot be empty") + } + if uint(len(vec)) != index.config.Dimensions { + return fmt.Errorf("vector dimension mismatch: got %d, expected %d", len(vec), index.config.Dimensions) + } + var errorMessage *C.char + C.usearch_add(index.handle, (C.usearch_key_t)(key), unsafe.Pointer(&vec[0]), C.usearch_scalar_i8_k, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec) + if errorMessage != nil { + return errors.New(C.GoString(errorMessage)) + } + return nil +} + +// SearchI8 searches for nearest neighbors using an int8 query vector. +// The query must have exactly Dimensions() elements. +// +// This is a convenience method for indexes using I8 quantization. +func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []float32, err error) { + if index.handle == nil { + panic("index is uninitialized") + } + if len(query) == 0 { + return nil, nil, errors.New("query vector cannot be empty") + } + if uint(len(query)) != index.config.Dimensions { + return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions) + } + if limit == 0 { + return []Key{}, []float32{}, nil + } + keys = make([]Key, limit) + distances = make([]float32, limit) + var errorMessage *C.char + resultCount := uint(C.usearch_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit), (*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage))) + runtime.KeepAlive(query) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + keys = keys[:resultCount] + distances = distances[:resultCount] + return keys, distances, nil +} + +// DistanceI8 computes the distance between two int8 vectors. +// +// Example: +// +// vec1 := []int8{1, 2, 3, 4} +// vec2 := []int8{5, 6, 7, 8} +// dist, err := usearch.DistanceI8(vec1, vec2, 4, usearch.L2sq) +func DistanceI8(vec1 []int8, vec2 []int8, vectorDimensions uint, metric Metric) (float32, error) { + if len(vec1) == 0 || len(vec2) == 0 { + return 0, errors.New("vectors cannot be empty") + } + if uint(len(vec1)) < vectorDimensions || uint(len(vec2)) < vectorDimensions { + return 0, fmt.Errorf("vectors too short for specified dimensions: need %d elements", vectorDimensions) + } + var errorMessage *C.char + dist := C.usearch_distance(unsafe.Pointer(&vec1[0]), unsafe.Pointer(&vec2[0]), C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(vec1) + runtime.KeepAlive(vec2) + if errorMessage != nil { + return 0, errors.New(C.GoString(errorMessage)) + } + return float32(dist), nil +} + +// ExactSearchI8 performs exact nearest neighbors search on int8 vectors. +// This computes distances to all vectors in the dataset without using an index. +// +// Stride parameters specify memory offset in bytes between consecutive vectors. +// For contiguous int8 data, use vectorDimensions * 1 byte. +func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount uint, + datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, + maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + + if len(dataset) == 0 || len(queries) == 0 { + return nil, nil, errors.New("dataset and queries cannot be empty") + } + if vectorDimensions == 0 { + return nil, nil, errors.New("dimensions must be greater than zero") + } + + keys = make([]Key, maxResults) + distances = make([]float32, maxResults) + var errorMessage *C.char + C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), + C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), + (*C.usearch_key_t)(&keys[0]), C.size_t(resultKeysStride), (*C.usearch_distance_t)(&distances[0]), C.size_t(resultDistancesStride), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(dataset) + runtime.KeepAlive(queries) + runtime.KeepAlive(keys) + runtime.KeepAlive(distances) + if errorMessage != nil { + return nil, nil, errors.New(C.GoString(errorMessage)) + } + keys = keys[:maxResults] + distances = distances[:maxResults] + return keys, distances, nil +} + +// SaveBuffer serializes the index into a byte buffer. +// The buffer must be large enough to hold the serialized index. +// Use SerializedLength() to determine the required buffer size. func (index *Index) SaveBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char - C.usearch_save_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_save_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the index from a specified buffer. +// LoadBuffer loads a serialized index from a byte buffer. +// The buffer must contain a valid serialized index. func (index *Index) LoadBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char - C.usearch_load_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_load_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the index from a specified buffer without copying the data. +// ViewBuffer creates a view of a serialized index without copying the data. +// The buffer must remain valid for the lifetime of the index. +// Changes to the buffer will affect the index. func (index *Index) ViewBuffer(buf []byte, buffer_size uint) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") + } + + if len(buf) == 0 { + return errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } var errorMessage *C.char - C.usearch_view_buffer((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + C.usearch_view_buffer(index.handle, unsafe.Pointer(&buf[0]), C.size_t(buffer_size), (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } return nil } -// Loads the metadata from a specified buffer. +// MetadataBuffer extracts index configuration metadata from a serialized buffer. +// This can be used to inspect an index before loading it. func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { - if buf == nil { - panic("Buffer is uninitialized") + if len(buf) == 0 { + return c, errors.New("buffer cannot be empty") + } + if uint(len(buf)) < buffer_size { + return c, fmt.Errorf("buffer too small: has %d bytes, need %d", len(buf), buffer_size) } c = IndexConfig{} @@ -529,6 +963,7 @@ func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { var errorMessage *C.char C.usearch_metadata_buffer(unsafe.Pointer(&buf[0]), C.size_t(buffer_size), &options, (*C.usearch_error_t)(&errorMessage)) + runtime.KeepAlive(buf) if errorMessage != nil { return c, errors.New(C.GoString(errorMessage)) } @@ -576,8 +1011,12 @@ func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) { return c, nil } -// Metadata loads the metadata from a specified file. +// Metadata loads the index configuration metadata from a file. +// This can be used to inspect an index file before loading it. func Metadata(path string) (c IndexConfig, err error) { + if path == "" { + return c, errors.New("path cannot be empty") + } c_path := C.CString(path) defer C.free(unsafe.Pointer(c_path)) @@ -635,15 +1074,15 @@ func Metadata(path string) (c IndexConfig, err error) { // Save saves the index to a specified file. func (index *Index) Save(path string) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } c_path := C.CString(path) defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_save((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_save((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -652,15 +1091,15 @@ func (index *Index) Save(path string) error { // Load loads the index from a specified file. func (index *Index) Load(path string) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } c_path := C.CString(path) defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_load((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_load((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } @@ -669,15 +1108,15 @@ func (index *Index) Load(path string) error { // View creates a view of the index from a specified file without loading it into memory. func (index *Index) View(path string) error { - if index.opaque_handle == nil { - panic("Index is uninitialized") + if index.handle == nil { + panic("index is uninitialized") } c_path := C.CString(path) defer C.free(unsafe.Pointer(c_path)) var errorMessage *C.char - C.usearch_view((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), c_path, (*C.usearch_error_t)(&errorMessage)) + C.usearch_view((C.usearch_index_t)(unsafe.Pointer(index.handle)), c_path, (*C.usearch_error_t)(&errorMessage)) if errorMessage != nil { return errors.New(C.GoString(errorMessage)) } diff --git a/golang/lib_test.go b/golang/lib_test.go index fb30a9016..6409dce6c 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -1,229 +1,877 @@ package usearch import ( + "fmt" + "io" "math" "runtime" + "sync" "testing" + "unsafe" ) -func TestUSearch(t *testing.T) { +// Test constants +const ( + defaultTestDimensions = 128 + distanceTolerance = 1e-2 + bufferSize = 1024 * 1024 +) + +// Helper functions to reduce code duplication + +func createTestIndex(t *testing.T, dimensions uint, quantization Quantization) *Index { + conf := DefaultConfig(dimensions) + conf.Quantization = quantization + index, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to create test index: %v", err) + } + return index +} + +func generateTestVector(dimensions uint) []float32 { + vector := make([]float32, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = float32(i) + 0.1 + } + return vector +} + +func generateTestVectorI8(dimensions uint) []int8 { + vector := make([]int8, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = int8((i % 127) + 1) + } + return vector +} + +func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { + vectors := make([][]float32, vectorCount) + err := index.Reserve(uint(vectorCount)) + if err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + + dimensions, err := index.Dimensions() + if err != nil { + t.Fatalf("Failed to get dimensions: %v", err) + } + + for i := 0; i < vectorCount; i++ { + vector := generateTestVector(dimensions) + vector[0] = float32(i) // Make each vector unique + vectors[i] = vector + + err = index.Add(Key(i), vector) + if err != nil { + t.Fatalf("Failed to add vector %d: %v", i, err) + } + } + return vectors +} + +// Core functionality tests (improved versions of existing) + +func TestIndexLifecycle(t *testing.T) { runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Index creation and configuration", func(t *testing.T) { + dimensions := uint(64) + index := createTestIndex(t, dimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - t.Run("Test Index Initialization", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + // Verify dimensions + actualDimensions, err := index.Dimensions() if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Failed to retrieve dimensions: %v", err) + } + if actualDimensions != dimensions { + t.Fatalf("Expected %d dimensions, got %d", dimensions, actualDimensions) } - defer ind.Destroy() - found_dims, err := ind.Dimensions() + // Verify empty index + size, err := index.Len() if err != nil { - t.Fatalf("Failed to retrieve dimensions: %s", err) + t.Fatalf("Failed to retrieve size: %v", err) } - if found_dims != dim { - t.Fatalf("Expected %d dimensions, got %d", dim, found_dims) + if size != 0 { + t.Fatalf("Expected empty index, got size %d", size) } - found_len, err := ind.Len() + // Capacity may be zero before any reservation; ensure Reserve works + if err := index.Reserve(10); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + capacity, err := index.Capacity() if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("Failed to retrieve capacity: %v", err) + } + if capacity < 10 { + t.Fatalf("Expected capacity >= 10 after reserve, got %d", capacity) + } + + // Verify memory usage + memUsage, err := index.MemoryUsage() + if err != nil { + t.Fatalf("Failed to retrieve memory usage: %v", err) + } + if memUsage == 0 { + t.Fatalf("Expected positive memory usage") + } + + // Verify hardware acceleration info + hwAccel, err := index.HardwareAcceleration() + if err != nil { + t.Fatalf("Failed to retrieve hardware acceleration: %v", err) + } + if hwAccel == "" { + t.Fatalf("Expected non-empty hardware acceleration string") + } + }) + + t.Run("Index configuration validation", func(t *testing.T) { + // Test different configurations + configs := []struct { + name string + dimensions uint + quantization Quantization + metric Metric + }{ + {"F32-Cosine", 128, F32, Cosine}, + {"F64-L2sq", 64, F64, L2sq}, + {"I8-InnerProduct", 32, I8, InnerProduct}, + } + + for _, config := range configs { + t.Run(config.name, func(t *testing.T) { + conf := DefaultConfig(config.dimensions) + conf.Quantization = config.quantization + conf.Metric = config.metric + + index, err := NewIndex(conf) + if err != nil { + t.Fatalf("Failed to create index with config %s: %v", config.name, err) + } + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + actualDims, err := index.Dimensions() + if err != nil || actualDims != config.dimensions { + t.Fatalf("Configuration mismatch for %s", config.name) + } + }) } - if found_len != 0 { - t.Fatalf("Expected size to be 0, got %d", found_len) + }) +} + +func TestBasicOperations(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Add and retrieve", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Ensure capacity before first add + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) } - found_len, err = ind.SerializedLength() + // Add a vector + vector := generateTestVector(defaultTestDimensions) + vector[0] = 42.0 + vector[1] = 24.0 + + err := index.Add(100, vector) + if err != nil { + t.Fatalf("Failed to add vector: %v", err) + } + + // Verify index size + size, err := index.Len() if err != nil { - t.Fatalf("Failed to retrieve serialized length: %s", err) + t.Fatalf("Failed to get index size: %v", err) } - if found_len != 112 { - t.Fatalf("Expected serialized length to be 112, got %d", found_len) + if size != 1 { + t.Fatalf("Expected size 1, got %d", size) } - err = ind.Reserve(100) + // Test Contains + found, err := index.Contains(100) if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Contains check failed: %v", err) + } + if !found { + t.Fatalf("Expected to find key 100") } - mem, err := ind.MemoryUsage() + // Test Get + retrieved, err := index.Get(100, 1) if err != nil { - t.Fatalf("Failed to retrieve serialized length: %s", err) + t.Fatalf("Failed to retrieve vector: %v", err) } - if mem == 0 { - t.Fatalf("Expected the empty index memory usage to be positive, got zero") + if retrieved == nil || len(retrieved) != int(defaultTestDimensions) { + t.Fatalf("Retrieved vector has wrong dimensions") } + }) - s, err := ind.HardwareAcceleration() + t.Run("Search functionality", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Add test data + testVectors := populateIndex(t, index, 10) + + // Search with first vector (should find itself) + keys, distances, err := index.Search(testVectors[0], 5) if err != nil { - t.Fatalf("Failed to retrieve hardware acceleration: %s", err) + t.Fatalf("Search failed: %v", err) } - if s == "" { - t.Fatalf("An empty string was returned from HardwareAcceleration") + + if len(keys) == 0 || len(distances) == 0 { + t.Fatalf("Search returned no results") } + // First result should be the exact match with near-zero distance + if keys[0] != 0 { + t.Fatalf("Expected first result to be key 0, got %d", keys[0]) + } + + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) + } }) - t.Run("Test Insertion", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + t.Run("Remove operations", func(t *testing.T) { + index := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Add vectors + populateIndex(t, index, 5) + + // Remove one vector + err := index.Remove(2) + if err != nil { + t.Fatalf("Failed to remove vector: %v", err) + } + + // Verify it's gone + found, err := index.Contains(2) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Contains check failed after removal: %v", err) + } + if found { + t.Fatalf("Key 2 should have been removed") } - defer ind.Destroy() - err = ind.Reserve(100) + // Verify size decreased + size, err := index.Len() if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Failed to get size after removal: %v", err) } + if size != 4 { + t.Fatalf("Expected size 4 after removal, got %d", size) + } + }) +} + +func TestIOCloser(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("io.Closer interface compliance", func(t *testing.T) { + index := createTestIndex(t, 32, F32) + + // Verify that Index can be used as io.Closer + var closer io.Closer = index - err = ind.ChangeThreadsAdd(10) + // Test Close method works like Destroy + err := closer.Close() if err != nil { - t.Fatalf("Failed to change threads add: %s", err) + t.Fatalf("Close failed: %v", err) } + }) +} - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 +func TestSerialization(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Buffer save/load/view operations", func(t *testing.T) { + // Create and populate original index + originalIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := originalIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy original index: %v", err) + } + }() - err = ind.Add(42, vec) + testVectors := populateIndex(t, originalIndex, 50) + + originalSize, err := originalIndex.Len() if err != nil { - t.Fatalf("Failed to insert: %s", err) + t.Fatalf("Failed to get original index size: %v", err) } - found_len, err := ind.Len() + // Save to buffer + buf := make([]byte, bufferSize) + err = originalIndex.SaveBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to retrieve size after insertion: %s", err) + t.Fatalf("Failed to save index to buffer: %v", err) } - if found_len != 1 { - t.Fatalf("Expected size to be 1, got %d", found_len) + + // Test metadata extraction + metadata, err := MetadataBuffer(buf, bufferSize) + if err != nil { + t.Fatalf("Failed to extract metadata: %v", err) + } + + if metadata.Dimensions != defaultTestDimensions { + t.Fatalf("Metadata dimensions mismatch: expected %d, got %d", + defaultTestDimensions, metadata.Dimensions) } - }) - t.Run("Test Search", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + // Test LoadBuffer + loadedIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := loadedIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy loaded index: %v", err) + } + }() + + err = loadedIndex.LoadBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("Failed to load index from buffer: %v", err) } - defer ind.Destroy() - err = ind.Reserve(100) + loadedSize, err := loadedIndex.Len() if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("Failed to get loaded index size: %v", err) + } + + if loadedSize != originalSize { + t.Fatalf("Loaded index size mismatch: expected %d, got %d", + originalSize, loadedSize) } - err = ind.ChangeThreadsSearch(10) + // Verify search results are consistent + keys, distances, err := loadedIndex.Search(testVectors[0], 3) if err != nil { - t.Fatalf("Failed to change threads search: %s", err) + t.Fatalf("Search failed on loaded index: %v", err) } - vec := make([]float32, dim) - vec[0] = 40.0 - vec[1] = 2.0 + if len(keys) == 0 || keys[0] != 0 { + t.Fatalf("Loaded index search results inconsistent") + } + + // Verify distance is near zero for exact match + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) + } - err = ind.Add(42, vec) + // Test ViewBuffer + viewIndex := createTestIndex(t, defaultTestDimensions, F32) + defer func() { + if err := viewIndex.Destroy(); err != nil { + t.Errorf("Failed to destroy view index: %v", err) + } + }() + + err = viewIndex.ViewBuffer(buf, bufferSize) if err != nil { - t.Fatalf("Failed to insert: %s", err) + t.Fatalf("Failed to create view from buffer: %v", err) } - keys, distances, err := ind.Search(vec, 10) + viewSize, err := viewIndex.Len() if err != nil { - t.Fatalf("Failed to search: %s", err) + t.Fatalf("Failed to get view index size: %v", err) + } + + if viewSize != originalSize { + t.Fatalf("View index size mismatch: expected %d, got %d", + originalSize, viewSize) + } + }) +} + +func TestInputValidation(t *testing.T) { + t.Run("Zero dimensions", func(t *testing.T) { + conf := DefaultConfig(0) + _, err := NewIndex(conf) + if err == nil { + t.Fatalf("Expected error for zero dimensions") + } + }) + + t.Run("Empty vectors", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test Add with empty vector + err := index.Add(1, []float32{}) + if err == nil { + t.Fatalf("Expected error for empty vector in Add") } - const tolerance = 1e-2 // For example, this sets the tolerance to 0.01 - if keys[0] != 42 || math.Abs(float64(distances[0])) > tolerance { - t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0]) + // Test Search with empty vector + _, _, err = index.Search([]float32{}, 10) + if err == nil { + t.Fatalf("Expected error for empty vector in Search") + } + }) + + t.Run("Dimension mismatches", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test Add with wrong dimensions + wrongVec := make([]float32, 32) // Should be 64 + err := index.Add(1, wrongVec) + if err == nil { + t.Fatalf("Expected error for dimension mismatch in Add") + } + + // Test Search with wrong dimensions + _, _, err = index.Search(wrongVec, 10) + if err == nil { + t.Fatalf("Expected error for dimension mismatch in Search") + } + }) + + t.Run("Nil pointers", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test AddUnsafe with nil pointer + err := index.AddUnsafe(1, nil) + if err == nil { + t.Fatalf("Expected error for nil pointer in AddUnsafe") + } + + // Test SearchUnsafe with nil pointer + _, _, err = index.SearchUnsafe(nil, 10) + if err == nil { + t.Fatalf("Expected error for nil pointer in SearchUnsafe") + } + }) + + t.Run("Buffer validation", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Test SaveBuffer with empty buffer + err := index.SaveBuffer([]byte{}, 100) + if err == nil { + t.Fatalf("Expected error for empty buffer in SaveBuffer") } - // TODO: Add exact search + // Test LoadBuffer with empty buffer + err = index.LoadBuffer([]byte{}, 100) + if err == nil { + t.Fatalf("Expected error for empty buffer in LoadBuffer") + } }) +} + +func TestQuantizationTypes(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("F32 operations", func(t *testing.T) { + index := createTestIndex(t, 32, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() - t.Run("Test Save and Load", func(t *testing.T) { - dim := uint(128) - conf := DefaultConfig(dim) - ind, err := NewIndex(conf) + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + vector := generateTestVector(32) + err := index.Add(1, vector) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("F32 Add failed: %v", err) } - defer ind.Destroy() - ind2, err := NewIndex(conf) + + keys, _, err := index.Search(vector, 1) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("F32 Search failed: %v", err) } - defer ind2.Destroy() - indView, err := NewIndex(conf) + + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("F32 search results incorrect") + } + }) + + t.Run("F64 operations", func(t *testing.T) { + index := createTestIndex(t, 32, F64) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + vector := make([]float64, 32) + for i := range vector { + vector[i] = float64(i) + 0.5 + } + + err := index.AddUnsafe(1, unsafe.Pointer(&vector[0])) if err != nil { - t.Fatalf("Failed to construct the index: %s", err) + t.Fatalf("F64 AddUnsafe failed: %v", err) } - defer indView.Destroy() - err = ind.Reserve(100) + keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vector[0]), 1) if err != nil { - t.Fatalf("Failed to reserve capacity: %s", err) + t.Fatalf("F64 SearchUnsafe failed: %v", err) } - vec := make([]float32, dim) - for i := uint(0); i < dim; i++ { - vec[i] = float32(i) + 0.2 - err = ind.Add(uint64(i), vec) - if err != nil { - t.Fatalf("Failed to insert: %s", err) + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("F64 search results incorrect") + } + }) + + t.Run("I8 operations", func(t *testing.T) { + index := createTestIndex(t, 32, I8) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) } + }() + + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + vector := generateTestVectorI8(32) + err := index.AddI8(1, vector) + if err != nil { + t.Fatalf("I8 Add failed: %v", err) } - ind_length, err := ind.Len() + keys, _, err := index.SearchI8(vector, 1) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("I8 Search failed: %v", err) + } + + if len(keys) == 0 || keys[0] != 1 { + t.Fatalf("I8 search results incorrect") } + }) +} + +func TestUnsafeOperations(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Unsafe pointer operations", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + if err := index.Reserve(1); err != nil { + t.Fatalf("Failed to reserve capacity: %v", err) + } + vector := generateTestVector(64) + ptr := unsafe.Pointer(&vector[0]) - // TODO: Add invalid save and loads? - buffer_size := uint(1 * 1024 * 1024) - buf := make([]byte, buffer_size) - err = ind.SaveBuffer(buf, buffer_size) + // Test AddUnsafe + err := index.AddUnsafe(100, ptr) if err != nil { - t.Fatalf("Failed to save the index to a buffer: %s", err) + t.Fatalf("AddUnsafe failed: %v", err) } - err = ind2.LoadBuffer(buf, buffer_size) + // Verify vector was added + size, err := index.Len() if err != nil { - t.Fatalf("Failed to load the index from a buffer: %s", err) + t.Fatalf("Failed to get size after AddUnsafe: %v", err) + } + if size != 1 { + t.Fatalf("Expected size 1 after AddUnsafe, got %d", size) } - ind2_length, err := ind2.Len() + // Test SearchUnsafe + keys, distances, err := index.SearchUnsafe(ptr, 5) if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("SearchUnsafe failed: %v", err) } - if ind_length != ind2_length { - t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length) + + if len(keys) == 0 || keys[0] != 100 { + t.Fatalf("SearchUnsafe returned incorrect results") } - // TODO: Check some values - err = indView.ViewBuffer(buf, buffer_size) + if math.Abs(float64(distances[0])) > distanceTolerance { + t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0]) + } + }) +} + +func TestConcurrentInsertions(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Parallelized insertions via internal threads", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + const totalVectors = 1000 + + err := index.Reserve(totalVectors) if err != nil { - t.Fatalf("Failed to load the view from a buffer: %s", err) + t.Fatalf("Failed to reserve capacity: %v", err) + } + + // Let the library parallelize inserts internally + _ = index.ChangeThreadsAdd(uint(runtime.NumCPU())) + + for i := 0; i < totalVectors; i++ { + vector := generateTestVector(64) + vector[0] = float32(i) + if err := index.Add(Key(i), vector); err != nil { + t.Fatalf("Insertion failed at %d: %v", i, err) + } + } + + // Verify final count + finalSize, err := index.Len() + if err != nil { + t.Fatalf("Failed to get final size: %v", err) + } + + if finalSize != totalVectors { + t.Fatalf("Expected %d vectors after concurrent insertions, got %d", + totalVectors, finalSize) + } + }) +} + +func TestConcurrentSearches(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + t.Run("Multiple concurrent searches", func(t *testing.T) { + index := createTestIndex(t, 64, F32) + defer func() { + if err := index.Destroy(); err != nil { + t.Errorf("Failed to destroy index: %v", err) + } + }() + + // Pre-populate with data + testVectors := populateIndex(t, index, 200) + + // Let the library parallelize search internally as well + _ = index.ChangeThreadsSearch(uint(runtime.NumCPU())) + + const numGoroutines = 30 + const searchesPerGoroutine = 50 + + var wg sync.WaitGroup + errorChan := make(chan error, numGoroutines) + + // Only concurrent searches - no mixed operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + for j := 0; j < searchesPerGoroutine; j++ { + // Use different query vectors + queryIndex := (goroutineID*searchesPerGoroutine + j) % len(testVectors) + query := testVectors[queryIndex] + + keys, distances, err := index.Search(query, 10) + if err != nil { + errorChan <- err + return + } + + // Basic validation - should find at least the exact match + if len(keys) == 0 || len(distances) == 0 { + errorChan <- fmt.Errorf("search returned empty results") + return + } + + // First result should be the exact match + if keys[0] != Key(queryIndex) || math.Abs(float64(distances[0])) > distanceTolerance { + errorChan <- fmt.Errorf("search results inconsistent: expected key %d, got %d", queryIndex, keys[0]) + return + } + } + }(i) + } + + wg.Wait() + close(errorChan) + + // Check for any errors + for err := range errorChan { + t.Fatalf("Concurrent search failed: %v", err) } + }) +} + +func TestExactSearch(t *testing.T) { + t.Run("Float32 exact search", func(t *testing.T) { + // Create dataset and queries + const datasetSize = 100 + const querySize = 10 + const vectorDims = 32 + + dataset := make([]float32, datasetSize*vectorDims) + queries := make([]float32, querySize*vectorDims) + + // Fill with test data + for i := 0; i < len(dataset); i++ { + dataset[i] = float32(i%100) + 0.1 + } + + for i := 0; i < len(queries); i++ { + queries[i] = float32(i%50) + 0.1 + } + + keys, distances, err := ExactSearch( + dataset, queries, + datasetSize, querySize, + vectorDims*4, vectorDims*4, // Stride in bytes for float32 + vectorDims, Cosine, + 5, 0, // maxResults=5, numThreads=0 (auto) + 8, 4, // resultKeysStride, resultDistancesStride + ) - indView_length, err := indView.Len() if err != nil { - t.Fatalf("Failed to retrieve size: %s", err) + t.Fatalf("ExactSearch failed: %v", err) } - if ind_length != indView_length { - t.Fatalf("Loaded view length %d doesn't match original of %d ", indView_length, ind_length) + + if len(keys) != 5 || len(distances) != 5 { + t.Fatalf("Expected 5 results from ExactSearch, got %d keys and %d distances", + len(keys), len(distances)) + } + }) + + t.Run("I8 exact search", func(t *testing.T) { + const datasetSize = 50 + const querySize = 5 + const vectorDims = 16 + + dataset := make([]int8, datasetSize*vectorDims) + queries := make([]int8, querySize*vectorDims) + + // Fill with test data + for i := 0; i < len(dataset); i++ { + dataset[i] = int8((i % 100) + 1) + } + + for i := 0; i < len(queries); i++ { + queries[i] = int8((i % 50) + 1) } - conf, err = MetadataBuffer(buf, buffer_size) + keys, distances, err := ExactSearchI8( + dataset, queries, + datasetSize, querySize, + vectorDims, vectorDims, // Stride in bytes for int8 + vectorDims, L2sq, + 3, 0, // maxResults=3, numThreads=0 (auto) + 8, 4, // resultKeysStride, resultDistancesStride + ) + if err != nil { - t.Fatalf("Failed to load the metadata from a buffer: %s", err) + t.Fatalf("ExactSearchI8 failed: %v", err) } - if conf != ind.config { - t.Fatalf("Loaded metadata doesn't match the index metadata") + + if len(keys) != 3 || len(distances) != 3 { + t.Fatalf("Expected 3 results from ExactSearchI8, got %d keys and %d distances", + len(keys), len(distances)) } + }) +} + +func TestDistanceCalculations(t *testing.T) { + t.Run("Float32 distance calculations", func(t *testing.T) { + vec1 := []float32{1.0, 0.0, 0.0} + vec2 := []float32{0.0, 1.0, 0.0} - // TODO: Check file save/load/metadata + // Test different metrics + metrics := []struct { + metric Metric + expected float32 + tolerance float32 + }{ + {Cosine, 1.0, 0.01}, // Perpendicular vectors + {L2sq, 2.0, 0.01}, // Squared Euclidean distance + } + + for _, test := range metrics { + distance, err := Distance(vec1, vec2, 3, test.metric) + if err != nil { + t.Fatalf("Distance calculation failed for %v: %v", test.metric, err) + } + + if math.Abs(float64(distance-test.expected)) > float64(test.tolerance) { + t.Fatalf("Distance mismatch for %v: expected %f, got %f", + test.metric, test.expected, distance) + } + } + }) + + t.Run("I8 distance calculations", func(t *testing.T) { + vec1 := []int8{10, 0, 0} + vec2 := []int8{0, 10, 0} + + distance, err := DistanceI8(vec1, vec2, 3, L2sq) + if err != nil { + t.Fatalf("DistanceI8 failed: %v", err) + } + + expected := float32(200.0) // 10^2 + 10^2 = 200 + if math.Abs(float64(distance-expected)) > 0.1 { + t.Fatalf("I8 distance mismatch: expected %f, got %f", expected, distance) + } }) } diff --git a/java/cloud/unum/usearch/Index.java b/java/cloud/unum/usearch/Index.java index f942aaf56..c15958437 100644 --- a/java/cloud/unum/usearch/Index.java +++ b/java/cloud/unum/usearch/Index.java @@ -274,7 +274,34 @@ public void reserve(long capacity) { if (c_ptr == 0) { throw new IllegalStateException("Index already closed"); } - c_reserve(c_ptr, capacity); + // Pass zeros to use all available contexts on the device + c_reserve(c_ptr, capacity, 0, 0); + } + + /** + * Reserves memory and configures thread contexts. + * Use this to explicitly control concurrent add/search capacities. + * + * @param capacity desired total capacity + * @param threadsAdd maximum concurrent add contexts + * @param threadsSearch maximum concurrent search contexts + */ + public void reserve(long capacity, long threadsAdd, long threadsSearch) { + if (c_ptr == 0) { + throw new IllegalStateException("Index already closed"); + } + c_reserve(c_ptr, capacity, threadsAdd, threadsSearch); + } + + /** + * Reserves memory and sets the same number of contexts + * for both add and search operations. + * + * @param capacity desired total capacity + * @param threads number of contexts for both add and search + */ + public void reserve(long capacity, long threads) { + reserve(capacity, threads, threads); } /** @@ -1021,7 +1048,7 @@ private static native long c_create( private static native long c_capacity(long ptr); - private static native void c_reserve(long ptr, long capacity); + private static native void c_reserve(long ptr, long capacity, long threadsAdd, long threadsSearch); private static native void c_save(long ptr, String path); diff --git a/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp b/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp index bd6a8d42e..685bdc8d8 100644 --- a/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp +++ b/java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp @@ -143,8 +143,22 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1capacity(JNIEnv*, jclas return reinterpret_cast(c_ptr)->capacity(); } -JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity) { - if (!reinterpret_cast(c_ptr)->try_reserve(static_cast(capacity))) { +JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve(JNIEnv* env, jclass, jlong c_ptr, jlong capacity, + jlong threads_add, jlong threads_search) { + std::size_t t_add = static_cast(threads_add); + std::size_t t_search = static_cast(threads_search); + if (t_add == 0 || t_search == 0) { + std::size_t hc = std::thread::hardware_concurrency(); + if (hc == 0) + hc = 1; // fallback to 1 if the runtime can't report + if (t_add == 0) + t_add = hc; + if (t_search == 0) + t_search = hc; + } + index_limits_t limits(static_cast(capacity), t_add); + limits.threads_search = t_search; + if (!reinterpret_cast(c_ptr)->try_reserve(limits)) { jclass jc = (*env).FindClass("java/lang/Error"); if (jc) (*env).ThrowNew(jc, "Failed to grow vector index!"); diff --git a/java/cloud/unum/usearch/cloud_unum_usearch_Index.h b/java/cloud/unum/usearch/cloud_unum_usearch_Index.h index b2494c29f..16cbc2110 100644 --- a/java/cloud/unum/usearch/cloud_unum_usearch_Index.h +++ b/java/cloud/unum/usearch/cloud_unum_usearch_Index.h @@ -66,10 +66,10 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1capacity /* * Class: cloud_unum_usearch_Index * Method: c_reserve - * Signature: (JJ)V + * Signature: (JJJJ)V */ JNIEXPORT void JNICALL Java_cloud_unum_usearch_Index_c_1reserve - (JNIEnv *, jclass, jlong, jlong); + (JNIEnv *, jclass, jlong, jlong, jlong, jlong); /* * Class: cloud_unum_usearch_Index diff --git a/java/test/IndexTest.java b/java/test/IndexTest.java index 455c2cd2f..817ed6334 100644 --- a/java/test/IndexTest.java +++ b/java/test/IndexTest.java @@ -267,13 +267,14 @@ public void testGetIntoBufferMethods() { @Test public void testConcurrentAdd() throws Exception { try (Index index = new Index.Config().metric("cos").dimensions(4).build()) { - index.reserve(1000); + final int threadsCount = 10; + index.reserve(1000, threadsCount); - ExecutorService executor = Executors.newFixedThreadPool(10); + ExecutorService executor = Executors.newFixedThreadPool(threadsCount); @SuppressWarnings("unchecked") - CompletableFuture[] futures = new CompletableFuture[10]; + CompletableFuture[] futures = new CompletableFuture[threadsCount]; - for (int t = 0; t < 10; t++) { + for (int t = 0; t < threadsCount; t++) { final int threadId = t; futures[t] = CompletableFuture.runAsync( @@ -290,25 +291,26 @@ public void testConcurrentAdd() throws Exception { CompletableFuture.allOf(futures).get(10, TimeUnit.SECONDS); executor.shutdown(); - assertEquals(500, index.size()); + assertEquals(50L * threadsCount, index.size()); } } @Test public void testConcurrentSearch() throws Exception { try (Index index = new Index.Config().metric("cos").dimensions(4).build()) { - index.reserve(100); + final int threadsCount = 5; + index.reserve(100, threadsCount); // Add some vectors first for (int i = 0; i < 100; i++) { index.add(i, randomVector(4)); } - ExecutorService executor = Executors.newFixedThreadPool(5); + ExecutorService executor = Executors.newFixedThreadPool(threadsCount); @SuppressWarnings("unchecked") - CompletableFuture[] futures = new CompletableFuture[5]; + CompletableFuture[] futures = new CompletableFuture[threadsCount]; - for (int t = 0; t < 5; t++) { + for (int t = 0; t < threadsCount; t++) { futures[t] = CompletableFuture.supplyAsync( () -> { @@ -328,56 +330,6 @@ public void testConcurrentSearch() throws Exception { } } - @Test - public void testMixedConcurrency() throws Exception { - try (Index index = new Index.Config().metric("cos").dimensions(3).build()) { - index.reserve(200); - - ExecutorService executor = Executors.newFixedThreadPool(8); - @SuppressWarnings("unchecked") - CompletableFuture[] addFutures = new CompletableFuture[4]; - @SuppressWarnings("unchecked") - CompletableFuture[] searchFutures = new CompletableFuture[4]; - - // Add operations - for (int t = 0; t < 4; t++) { - final int threadId = t; - addFutures[t] - = CompletableFuture.runAsync( - () -> { - for (int i = 0; i < 30; i++) { - long key = threadId * 30L + i; - index.add(key, randomVector(3)); - } - }, - executor); - } - - // Wait for some adds to complete, then start searches - Thread.sleep(100); - - // Search operations - for (int t = 0; t < 4; t++) { - searchFutures[t] - = CompletableFuture.runAsync( - () -> { - for (int i = 0; i < 10; i++) { - float[] queryVector = randomVector(3); - long[] results = index.search(queryVector, 5); - assertTrue(results.length >= 0); - } - }, - executor); - } - - CompletableFuture.allOf(addFutures).get(15, TimeUnit.SECONDS); - CompletableFuture.allOf(searchFutures).get(15, TimeUnit.SECONDS); - executor.shutdown(); - - assertEquals(120, index.size()); - } - } - @Test public void testBatchAdd() { try (Index index = new Index.Config().metric("cos").dimensions(2).build()) { @@ -751,30 +703,30 @@ public void testPlatformCapabilities() { String[] available = Index.hardwareAccelerationAvailable(); assertNotEquals("Available capabilities should not be null", null, available); assertTrue("Platform should have at least serial capability", available.length > 0); - + // Test compile-time capabilities String[] compiled = Index.hardwareAccelerationCompiled(); assertNotEquals("Compiled capabilities should not be null", null, compiled); assertTrue("Should have at least serial compiled", compiled.length > 0); - + // Should always include serial as baseline in both boolean hasAvailableSerial = false; boolean hasCompiledSerial = false; - + for (String cap : available) { if ("serial".equals(cap)) { hasAvailableSerial = true; break; } } - + for (String cap : compiled) { if ("serial".equals(cap)) { hasCompiledSerial = true; break; } } - + assertTrue("Platform should always support serial capability", hasAvailableSerial); assertTrue("Serial should always be compiled", hasCompiledSerial); diff --git a/javascript/usearch.ts b/javascript/usearch.ts index 63dba1ca9..109a07bd3 100644 --- a/javascript/usearch.ts +++ b/javascript/usearch.ts @@ -386,8 +386,8 @@ export class Index { * @throws Will throw an error if `k` is not a positive integer or if the size of the vectors is not a multiple of dimensions. * @throws Will throw an error if `vectors` is not a valid input type (TypedArray or an array of TypedArray) or if its flattened size is not a multiple of dimensions. */ - search(vectors: Vector, k: number, threads: number = 0): Matches; - search(vectors: Matrix, k: number, threads: number = 0): BatchMatches; + search(vectors: Vector, k: number, threads: number): Matches; + search(vectors: Matrix, k: number, threads: number): BatchMatches; search(vectors: VectorOrMatrix, k: number, threads: number = 0): Matches | BatchMatches { if ((!Number.isNaN(k) && typeof k !== "number") || k <= 0) { throw new Error(