@@ -7,12 +7,37 @@ def check_metric(metric, metrics):
7
7
raise ValueError (f"Unknown metric { metric !r} . Valid metrics are { metrics !r} " )
8
8
9
9
10
def get_sparse_indices(mat, idx):
    """Return the slice of ``mat.indices`` that belongs to row ``idx``.

    The row's extent is read from two consecutive ``mat.indptr`` entries:
    ``mat.indptr[idx]`` (inclusive start) and ``mat.indptr[idx + 1]``
    (exclusive end).
    """
    lo, hi = mat.indptr[idx], mat.indptr[idx + 1]
    return mat.indices[lo:hi]
14
+
15
+
10
16
def get_sparse_row(mat, idx):
    """Return an iterator of ``(index, value)`` pairs for row ``idx``.

    The pairs are taken from ``mat.indices`` and ``mat.data`` over the range
    ``mat.indptr[idx]`` (inclusive) to ``mat.indptr[idx + 1]`` (exclusive).
    The result is a lazy ``zip`` object, so it can be consumed only once.
    """
    # One slice object is shared so both arrays are cut at the same bounds.
    row_span = slice(mat.indptr[idx], mat.indptr[idx + 1])
    return zip(mat.indices[row_span], mat.data[row_span])
14
20
15
21
22
def trunc_csr(csr, k):
    """Keep at most the first ``k`` stored entries of every row of ``csr``.

    "First" means storage order: for each row, the leading ``k`` positions of
    that row's span in ``csr.indices`` / ``csr.data`` are kept and the rest
    are dropped. Entries are not ranked by value.

    Parameters
    ----------
    csr : CSR matrix exposing ``indptr``, ``indices``, ``data`` and ``shape``.
    k : int
        Maximum number of stored entries to keep per row; must be >= 0.

    Returns
    -------
    csr_matrix
        A new matrix with the same shape as ``csr``.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k!r}")

    old_indptr = np.asarray(csr.indptr)
    row_lengths = np.diff(old_indptr)

    # Cap k at the longest row: the result is the same, and the bound is then
    # guaranteed to fit the (possibly 32-bit) index dtype.
    limit = min(k, int(row_lengths.max(initial=0)))
    kept_per_row = np.minimum(row_lengths, limit)

    indptr = np.zeros(len(old_indptr), dtype=old_indptr.dtype)
    indptr[1:] = np.cumsum(kept_per_row)

    # For every stored entry, compute its offset inside its own row and keep
    # the ones whose offset is below the limit.
    positions = np.arange(old_indptr[0], old_indptr[-1])
    row_starts = np.repeat(old_indptr[:-1], row_lengths)
    keep = positions[(positions - row_starts) < limit]

    # Pass the shape explicitly: otherwise the column count is re-inferred
    # from the largest surviving index and the matrix can silently shrink.
    return csr_matrix(
        (csr.data[keep], csr.indices[keep], indptr), shape=csr.shape
    )
39
+
40
+
16
41
def or_else_csrs (csr1 , csr2 ):
17
42
# Possible TODO: Could use numba/Cython to speed this up?
18
43
if csr1 .shape != csr2 .shape :
0 commit comments