66
77import cupy as cp
88import numpy as np
9+ import pylibraft
910from cuml .manifold .simpl_set import fuzzy_simplicial_set
1011from cupyx .scipy import sparse as cp_sparse
12+ from packaging .version import parse as parse_version
1113from pylibraft .common import DeviceResources
1214from scipy import sparse as sc_sparse
1315
5961_Metrics = _MetricsDense | _MetricsSparse
6062
6163
64+ def _cuvs_switch ():
65+ return parse_version (pylibraft .__version__ ) > parse_version ("24.10" )
66+
67+
6268def _brute_knn (
6369 X : cp_sparse .spmatrix | cp .ndarray ,
6470 Y : cp_sparse .spmatrix | cp .ndarray ,
@@ -83,81 +89,114 @@ def _brute_knn(
8389def _cagra_knn (
8490 X : cp .ndarray , Y : cp .ndarray , k : int , metric : _Metrics , metric_kwds : Mapping
8591) -> tuple [cp .ndarray , cp .ndarray ]:
86- try :
87- from pylibraft .neighbors import cagra
88- except ImportError :
89- raise ImportError (
90- "The 'cagra' module is not available in your current RAFT installation. "
91- "Please update RAFT to a version that supports 'cagra'."
92- )
92+ if not _cuvs_switch ():
93+ try :
94+ from pylibraft .neighbors import cagra
95+ except ImportError :
96+ raise ImportError (
97+ "The 'cagra' module is not available in your current RAFT installation. "
98+ "Please update RAFT to a version that supports 'cagra'."
99+ )
100+ resources = DeviceResources ()
101+ build_kwargs = {"handle" : resources }
102+ search_kwargs = {"handle" : resources }
103+ else :
104+ from cuvs .neighbors import cagra
105+
106+ resources = None
107+ build_kwargs = {}
108+ search_kwargs = {}
93109
94- handle = DeviceResources ()
95110 build_params = cagra .IndexParams (metric = "sqeuclidean" , build_algo = "nn_descent" )
96- index = cagra .build (build_params , X , handle = handle )
111+ index = cagra .build (build_params , X , ** build_kwargs )
97112
98113 n_samples = Y .shape [0 ]
99114 all_neighbors = cp .zeros ((n_samples , k ), dtype = cp .int32 )
100115 all_distances = cp .zeros ((n_samples , k ), dtype = cp .float32 )
101116
102117 batchsize = 65000
103- n_batches = math .ceil (Y . shape [ 0 ] / batchsize )
118+ n_batches = math .ceil (n_samples / batchsize )
104119 for batch in range (n_batches ):
105120 start_idx = batch * batchsize
106- stop_idx = min (batch * batchsize + batchsize , Y . shape [ 0 ] )
121+ stop_idx = min (( batch + 1 ) * batchsize , n_samples )
107122 batch_Y = Y [start_idx :stop_idx , :]
123+
108124 search_params = cagra .SearchParams ()
109125 distances , neighbors = cagra .search (
110- search_params , index , batch_Y , k , handle = handle
126+ search_params , index , batch_Y , k , ** search_kwargs
111127 )
112128 all_neighbors [start_idx :stop_idx , :] = cp .asarray (neighbors )
113129 all_distances [start_idx :stop_idx , :] = cp .asarray (distances )
114- handle .sync ()
130+
131+ if resources is not None :
132+ resources .sync ()
133+
115134 if metric == "euclidean" :
116135 all_distances = cp .sqrt (all_distances )
136+
117137 return all_neighbors , all_distances
118138
119139
120140def _ivf_flat_knn (
121141 X : cp .ndarray , Y : cp .ndarray , k : int , metric : _Metrics , metric_kwds : Mapping
122142) -> tuple [cp .ndarray , cp .ndarray ]:
123- from pylibraft .neighbors import ivf_flat
143+ if not _cuvs_switch ():
144+ from pylibraft .neighbors import ivf_flat
124145
125- handle = DeviceResources ()
126- if X . shape [ 0 ] < 2048 :
127- n_lists = X . shape [ 0 ] // 2
146+ resources = DeviceResources ()
147+ build_kwargs = { "handle" : resources } # pylibraft uses 'handle'
148+ search_kwargs = { "handle" : resources }
128149 else :
129- n_lists = 1024
130- index_params = ivf_flat .IndexParams (n_lists = n_lists , metric = metric )
131- index = ivf_flat .build (index_params , X , handle = handle )
150+ from cuvs .neighbors import ivf_flat
132151
152+ resources = None
153+ build_kwargs = {} # cuvs does not need handle/resources
154+ search_kwargs = {}
155+
156+ n_lists = int (math .sqrt (X .shape [0 ]))
157+ index_params = ivf_flat .IndexParams (n_lists = n_lists , metric = metric )
158+ index = ivf_flat .build (index_params , X , ** build_kwargs )
133159 distances , neighbors = ivf_flat .search (
134- ivf_flat .SearchParams (), index , Y , k , handle = handle
160+ ivf_flat .SearchParams (), index , Y , k , ** search_kwargs
135161 )
162+
163+ if resources is not None :
164+ resources .sync ()
165+
136166 distances = cp .asarray (distances )
137167 neighbors = cp .asarray (neighbors )
138- handle . sync ()
168+
139169 return neighbors , distances
140170
141171
142172def _ivf_pq_knn (
143173 X : cp .ndarray , Y : cp .ndarray , k : int , metric : _Metrics , metric_kwds : Mapping
144174) -> tuple [cp .ndarray , cp .ndarray ]:
145- from pylibraft .neighbors import ivf_pq
175+ if not _cuvs_switch ():
176+ from pylibraft .neighbors import ivf_pq
146177
147- handle = DeviceResources ()
148- if X . shape [ 0 ] < 2048 :
149- n_lists = X . shape [ 0 ] // 2
178+ resources = DeviceResources ()
179+ build_kwargs = { "handle" : resources }
180+ search_kwargs = { "handle" : resources }
150181 else :
151- n_lists = 1024
152- index_params = ivf_pq .IndexParams (n_lists = n_lists , metric = metric )
153- index = ivf_pq .build (index_params , X , handle = handle )
182+ from cuvs .neighbors import ivf_pq
183+
184+ resources = None
185+ build_kwargs = {}
186+ search_kwargs = {}
154187
188+ n_lists = int (math .sqrt (X .shape [0 ]))
189+ index_params = ivf_pq .IndexParams (n_lists = n_lists , metric = metric )
190+ index = ivf_pq .build (index_params , X , ** build_kwargs )
155191 distances , neighbors = ivf_pq .search (
156- ivf_pq .SearchParams (), index , Y , k , handle = handle
192+ ivf_pq .SearchParams (), index , Y , k , ** search_kwargs
157193 )
194+ if resources is not None :
195+ resources .sync ()
196+
158197 distances = cp .asarray (distances )
159198 neighbors = cp .asarray (neighbors )
160- handle . sync ()
199+
161200 return neighbors , distances
162201
163202
0 commit comments