Skip to content

Commit d197e5c

Browse files
author
Frankie Robertson
committed
Add extra utils get_sparse_indices(...) and trunc_csr(...)
1 parent 6227e09 commit d197e5c

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

sklearn_ann/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,37 @@ def check_metric(metric, metrics):
77
raise ValueError(f"Unknown metric {metric!r}. Valid metrics are {metrics!r}")
88

99

10+
def get_sparse_indices(mat, idx):
11+
start_idx = mat.indptr[idx]
12+
end_idx = mat.indptr[idx + 1]
13+
return mat.indices[start_idx:end_idx]
14+
15+
1016
def get_sparse_row(mat, idx):
1117
start_idx = mat.indptr[idx]
1218
end_idx = mat.indptr[idx + 1]
1319
return zip(mat.indices[start_idx:end_idx], mat.data[start_idx:end_idx])
1420

1521

22+
def trunc_csr(csr, k):
23+
indptr = np.empty_like(csr.indptr)
24+
num_rows = len(csr.indptr) - 1
25+
indices = [None] * num_rows
26+
data = [None] * num_rows
27+
cur_indptr = 0
28+
for row_idx in range(num_rows):
29+
indptr[row_idx] = cur_indptr
30+
start_idx = csr.indptr[row_idx]
31+
old_end_idx = csr.indptr[row_idx + 1]
32+
end_idx = min(old_end_idx, start_idx + k)
33+
data[row_idx] = csr.data[start_idx:end_idx]
34+
indices[row_idx] = csr.indices[start_idx:end_idx]
35+
ptr_inc = min(k, old_end_idx - start_idx)
36+
cur_indptr = cur_indptr + ptr_inc
37+
indptr[-1] = cur_indptr
38+
return csr_matrix((np.concatenate(data), np.concatenate(indices), indptr))
39+
40+
1641
def or_else_csrs(csr1, csr2):
1742
# Possible TODO: Could use numba/Cython to speed this up?
1843
if csr1.shape != csr2.shape:

0 commit comments

Comments
 (0)