Skip to content

Commit 94c76d8

Browse files
authored
improve L2 cache (#607)
* improve L2 cache * add release note
1 parent 7b014a6 commit 94c76d8

File tree

4 files changed

+40
-28
lines changed

4 files changed

+40
-28
lines changed

docs/release-notes/0.15.0.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
```{rubric} Features
44
```
55
* Allow multiple control groups in ``onesided_distances`` for computing energy distances against several references in a single kernel launch {pr}`601` {smaller}`S Dicks`
6+
* Add ``contrast_distances`` to ``EDistanceMetric`` for computing energy distances directly from a contrasts DataFrame {pr}`603` {smaller}`S Dicks`
7+
* Improve L2 cache efficiency in ``edistance`` and ``co_occurrence`` kernels by always tiling the smaller group into shared memory, yielding up to 5x speedup for datasets with unequal group sizes {pr}`607` {smaller}`S Dicks`
68

79
```{rubric} Bug fixes
810
```

src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,18 @@ __global__ void occur_count_kernel_csr_catpairs_tiled(
298298
const int a = pair_left[pair_id];
299299
const int b = pair_right[pair_id];
300300

301-
const int start_a = cat_offsets[a];
302-
const int end_a = cat_offsets[a + 1];
303-
const int start_b = cat_offsets[b];
304-
const int end_b = cat_offsets[b + 1];
301+
const int start_oa = cat_offsets[a];
302+
const int end_oa = cat_offsets[a + 1];
303+
const int start_ob = cat_offsets[b];
304+
const int end_ob = cat_offsets[b + 1];
305+
306+
// Always iterate over the larger group (A) and tile the smaller group (B)
307+
// into shared memory. Small B stays hot in L2 across many A iterations.
308+
const bool do_swap = (end_oa - start_oa) < (end_ob - start_ob);
309+
const int start_a = do_swap ? start_ob : start_oa;
310+
const int end_a = do_swap ? end_ob : end_oa;
311+
const int start_b = do_swap ? start_oa : start_ob;
312+
const int end_b = do_swap ? end_oa : end_ob;
305313

306314
const int n_a = end_a - start_a;
307315
const int n_b = end_b - start_b;

src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@ __global__ void edistance_kernel(const T* __restrict__ embedding,
2626

2727
T local_sum = T(0.0);
2828

29-
const int a = pair_left[pair_id];
30-
const int b = pair_right[pair_id];
31-
32-
const int start_a = cat_offsets[a];
33-
const int end_a = cat_offsets[a + 1];
34-
const int start_b = cat_offsets[b];
35-
const int end_b = cat_offsets[b + 1];
29+
const int pair_a = pair_left[pair_id];
30+
const int pair_b = pair_right[pair_id];
31+
32+
const int start_pa = cat_offsets[pair_a];
33+
const int end_pa = cat_offsets[pair_a + 1];
34+
const int start_pb = cat_offsets[pair_b];
35+
const int end_pb = cat_offsets[pair_b + 1];
36+
37+
// Always iterate over the larger group (A) and tile the smaller group (B)
38+
// into shared memory. Small B stays hot in L2 across many A iterations.
39+
const bool swap = (end_pa - start_pa) < (end_pb - start_pb);
40+
const int start_a = swap ? start_pb : start_pa;
41+
const int end_a = swap ? end_pb : end_pa;
42+
const int start_b = swap ? start_pa : start_pb;
43+
const int end_b = swap ? end_pa : end_pb;
3644

3745
const int n_a = end_a - start_a;
3846
const int n_b = end_b - start_b;
@@ -109,7 +117,7 @@ __global__ void edistance_kernel(const T* __restrict__ embedding,
109117
int j_local = jb_base + c;
110118

111119
// Skip lower triangle for diagonal blocks
112-
if (a == b && i_local >= j_local) continue;
120+
if (pair_a == pair_b && i_local >= j_local) continue;
113121

114122
local_sum += sqrt(dist_sq[c]);
115123
}

src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _subset_to_groups(
5959
self,
6060
adata: AnnData,
6161
groupby: str,
62-
needed_groups: Sequence[str],
62+
needed_groups: Sequence[str] | None,
6363
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list[str]]:
6464
"""Subset embedding and category mapping to only the needed groups.
6565
@@ -85,22 +85,16 @@ def _subset_to_groups(
8585
"""
8686
obs_col = adata.obs[groupby]
8787
embedding_raw = self._get_embedding(adata)
88-
if needed_groups is None:
89-
groups_list = list(obs_col.cat.categories.values)
88+
89+
if needed_groups is not None:
90+
mask = obs_col.isin(needed_groups).values
91+
obs_col = obs_col[mask].cat.remove_unused_categories()
92+
embedding = cp.asarray(embedding_raw[mask])
93+
else:
9094
embedding = cp.asarray(embedding_raw)
91-
group_map = {v: i for i, v in enumerate(groups_list)}
92-
group_labels = cp.array([group_map[c] for c in obs_col], dtype=cp.int32)
93-
k = len(groups_list)
94-
cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
95-
return embedding, cat_offsets, cell_indices, groups_list
96-
97-
# Subset before GPU transfer (CPU subset avoids full GPU allocation)
98-
needed_set = set(needed_groups)
99-
groups_list = [g for g in obs_col.cat.categories.values if g in needed_set]
100-
group_map = {v: i for i, v in enumerate(groups_list)}
101-
mask = obs_col.isin(groups_list).values
102-
embedding = cp.asarray(embedding_raw[mask])
103-
group_labels = cp.array([group_map[c] for c in obs_col[mask]], dtype=cp.int32)
95+
96+
groups_list = list(obs_col.cat.categories)
97+
group_labels = cp.array(obs_col.cat.codes.values, dtype=cp.int32)
10498
k = len(groups_list)
10599
cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
106100
return embedding, cat_offsets, cell_indices, groups_list

0 commit comments

Comments (0)