Skip to content

Commit e02e25a

Browse files
authored
Merge branch 'main' into test-uv-docs
2 parents 8b2c816 + 6962073 commit e02e25a

File tree

3 files changed

+72
-32
lines changed

3 files changed

+72
-32
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363
autosummary_generate = True
6464
autodoc_member_order = "bysource"
65-
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft"]
65+
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft", "cuvs"]
6666
default_role = "literal"
6767
napoleon_google_docstring = False
6868
napoleon_numpy_docstring = True

docs/release-notes/0.10.12.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
```{rubric} Features
44
```
5+
* use `cuvs` over `raft` for `pp.neighbors` for `rapids>=24.12` {pr}`304` {smaller}`S Dicks`
56
```{rubric} Performance
67
```
78
```{rubric} Bug fixes

src/rapids_singlecell/preprocessing/_neighbors.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
import cupy as cp
88
import numpy as np
9+
import pylibraft
910
from cuml.manifold.simpl_set import fuzzy_simplicial_set
1011
from cupyx.scipy import sparse as cp_sparse
12+
from packaging.version import parse as parse_version
1113
from pylibraft.common import DeviceResources
1214
from scipy import sparse as sc_sparse
1315

@@ -59,6 +61,10 @@
5961
_Metrics = _MetricsDense | _MetricsSparse
6062

6163

64+
def _cuvs_switch():
    """Tell whether the neighbor-search backends should come from ``cuvs``.

    Returns
    -------
    bool
        ``True`` when the installed ``pylibraft`` belongs to the 24.12 release
        series or a later one, ``False`` otherwise (in which case the callers
        keep importing from ``pylibraft.neighbors``).
    """
    # Only the (year, month) prefix of the release is compared. A plain
    # ``> "24.10"`` comparison is already true for a 24.10 patch release such
    # as "24.10.01"; looking at the release tuple keeps every 24.10.x on
    # pylibraft while still treating 24.12 pre-releases as 24.12.
    release = parse_version(pylibraft.__version__).release
    return release[:2] >= (24, 12)
66+
67+
6268
def _brute_knn(
6369
X: cp_sparse.spmatrix | cp.ndarray,
6470
Y: cp_sparse.spmatrix | cp.ndarray,
@@ -83,81 +89,114 @@ def _brute_knn(
8389
def _cagra_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Search the ``k`` nearest rows of ``X`` for every row of ``Y`` with CAGRA.

    The index is always built with the ``"sqeuclidean"`` metric and the
    ``"nn_descent"`` build algorithm; queries are processed in batches of
    65000 rows.

    Parameters
    ----------
    X
        Data the CAGRA index is built on.
    Y
        Query rows.
    k
        Number of neighbors returned per query row.
    metric
        Only inspected after the search: for ``"euclidean"`` the square root
        of the returned distances is taken.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; not used.

    Returns
    -------
    tuple[cp.ndarray, cp.ndarray]
        ``(neighbors, distances)`` with shape ``(Y.shape[0], k)`` and dtypes
        ``int32`` and ``float32`` respectively.

    Raises
    ------
    ImportError
        If the pylibraft backend is selected but ``pylibraft.neighbors.cagra``
        cannot be imported.
    """
    if _cuvs_switch():
        from cuvs.neighbors import cagra

        # cuvs is called without an explicit resources object.
        resources = None
        backend_kwargs = {}
    else:
        try:
            from pylibraft.neighbors import cagra
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "The 'cagra' module is not available in your current RAFT installation. "
                "Please update RAFT to a version that supports 'cagra'."
            ) from err
        # pylibraft takes its resources through the ``handle`` keyword.
        resources = DeviceResources()
        backend_kwargs = {"handle": resources}

    build_params = cagra.IndexParams(metric="sqeuclidean", build_algo="nn_descent")
    index = cagra.build(build_params, X, **backend_kwargs)

    n_samples = Y.shape[0]
    all_neighbors = cp.zeros((n_samples, k), dtype=cp.int32)
    all_distances = cp.zeros((n_samples, k), dtype=cp.float32)

    batchsize = 65000
    # The search parameters do not depend on the batch, so build them once.
    search_params = cagra.SearchParams()
    for start_idx in range(0, n_samples, batchsize):
        stop_idx = min(start_idx + batchsize, n_samples)
        batch_Y = Y[start_idx:stop_idx, :]

        distances, neighbors = cagra.search(
            search_params, index, batch_Y, k, **backend_kwargs
        )
        # Synchronize the handle before the results are read, matching the
        # order used by the IVF helpers (sync first, then ``cp.asarray``).
        if resources is not None:
            resources.sync()
        all_neighbors[start_idx:stop_idx, :] = cp.asarray(neighbors)
        all_distances[start_idx:stop_idx, :] = cp.asarray(distances)

    # The index works on squared distances; convert for plain euclidean.
    if metric == "euclidean":
        all_distances = cp.sqrt(all_distances)

    return all_neighbors, all_distances
118138

119139

120140
def _ivf_flat_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Search the ``k`` nearest rows of ``X`` for every row of ``Y`` with IVF-Flat.

    Parameters
    ----------
    X
        Data the IVF-Flat index is built on.
    Y
        Query rows.
    k
        Number of neighbors returned per query row.
    metric
        Forwarded to ``ivf_flat.IndexParams``.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; not used.

    Returns
    -------
    tuple[cp.ndarray, cp.ndarray]
        ``(neighbors, distances)`` converted to CuPy arrays.
    """
    if _cuvs_switch():
        from cuvs.neighbors import ivf_flat

        handle = None
    else:
        from pylibraft.neighbors import ivf_flat

        handle = DeviceResources()

    # pylibraft receives its resources via ``handle=``; cuvs gets no extra
    # keyword arguments.
    call_kwargs = {} if handle is None else {"handle": handle}

    # Number of inverted lists: integer part of sqrt(number of rows in X).
    params = ivf_flat.IndexParams(n_lists=int(math.sqrt(X.shape[0])), metric=metric)
    built_index = ivf_flat.build(params, X, **call_kwargs)
    raw_distances, raw_neighbors = ivf_flat.search(
        ivf_flat.SearchParams(), built_index, Y, k, **call_kwargs
    )

    # Synchronize the pylibraft handle before the outputs are converted.
    if handle is not None:
        handle.sync()

    return cp.asarray(raw_neighbors), cp.asarray(raw_distances)
140170

141171

142172
def _ivf_pq_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Search the ``k`` nearest rows of ``X`` for every row of ``Y`` with IVF-PQ.

    Parameters
    ----------
    X
        Data the IVF-PQ index is built on.
    Y
        Query rows.
    k
        Number of neighbors returned per query row.
    metric
        Forwarded to ``ivf_pq.IndexParams``.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; not used.

    Returns
    -------
    tuple[cp.ndarray, cp.ndarray]
        ``(neighbors, distances)`` converted to CuPy arrays.
    """
    if _cuvs_switch():
        from cuvs.neighbors import ivf_pq

        handle = None
    else:
        from pylibraft.neighbors import ivf_pq

        handle = DeviceResources()

    # pylibraft receives its resources via ``handle=``; cuvs gets no extra
    # keyword arguments.
    call_kwargs = {} if handle is None else {"handle": handle}

    # Number of inverted lists: integer part of sqrt(number of rows in X).
    params = ivf_pq.IndexParams(n_lists=int(math.sqrt(X.shape[0])), metric=metric)
    built_index = ivf_pq.build(params, X, **call_kwargs)
    raw_distances, raw_neighbors = ivf_pq.search(
        ivf_pq.SearchParams(), built_index, Y, k, **call_kwargs
    )

    # Synchronize the pylibraft handle before the outputs are converted.
    if handle is not None:
        handle.sync()

    return cp.asarray(raw_neighbors), cp.asarray(raw_distances)
162201

163202

0 commit comments

Comments
 (0)