Commit 153174d

avoid register pressure on old hardware
1 parent ab3d5a8 commit 153174d

2 files changed: 38 additions & 21 deletions


src/rapids_singlecell/pertpy_gpu/_metrics/_edistance_metric.py

Lines changed: 8 additions & 8 deletions
@@ -485,13 +485,13 @@ def _pairwise_means(
         pairwise_sums = cp.zeros((k, k), dtype=embedding.dtype)

         if num_pairs > 0:
+            kernel, shared_mem, block_size = get_compute_group_distances_kernel(
+                embedding.dtype, n_features
+            )
             blocks_per_pair = self._calculate_blocks_per_pair(num_pairs)
             grid = (num_pairs, blocks_per_pair)
-            block = (1024,)
+            block = (block_size,)

-            kernel, shared_mem = get_compute_group_distances_kernel(
-                embedding.dtype, n_features
-            )
             kernel(
                 grid,
                 block,
@@ -574,13 +574,13 @@ def _onesided_means(
         onesided_sums = cp.zeros((k, k), dtype=embedding.dtype)

         if num_pairs > 0:
+            kernel, shared_mem, block_size = get_compute_group_distances_kernel(
+                embedding.dtype, n_features
+            )
             blocks_per_pair = self._calculate_blocks_per_pair(num_pairs)
             grid = (num_pairs, blocks_per_pair)
-            block = (1024,)
+            block = (block_size,)

-            kernel, shared_mem = get_compute_group_distances_kernel(
-                embedding.dtype, n_features
-            )
             kernel(
                 grid,
                 block,
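For readers unfamiliar with the launch convention used at these call sites, here is a minimal, self-contained sketch (not repository code; it only assumes CuPy is installed and a GPU is visible) of how a cp.RawKernel launch consumes a device-dependent block size the way the updated code does:

import cupy as cp

# Trivial kernel used only to illustrate the launch shape kernel(grid, block, args).
_code = r"""
extern "C" __global__ void scale(const float* x, float* y, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = 2.0f * x[i];
}
"""
kernel = cp.RawKernel(_code, "scale")

n = 1 << 20
x = cp.arange(n, dtype=cp.float32)
y = cp.empty_like(x)

block_size = 256                               # e.g. the reduced pre-Ampere/float64 value
grid = ((n + block_size - 1) // block_size,)   # one block per block_size-sized chunk
block = (block_size,)                          # was hard-coded to (1024,) before this commit
kernel(grid, block, (x, y, cp.int32(n)))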

src/rapids_singlecell/pertpy_gpu/_metrics/_kernels/_edistance_kernel.py

Lines changed: 30 additions & 13 deletions
@@ -15,17 +15,23 @@
 # Larger tiles = fewer iterations but more shared memory
 TILE_SIZES = [32, 50, 64]

-# Cache for device shared memory limit (lazy initialization)
-_MAX_SHARED_MEM: int | None = None
+# Cache for device attributes (lazy initialization)
+_DEVICE_ATTRS: dict | None = None


-def _get_max_shared_mem() -> int:
-    """Get maximum shared memory per block for the current device (cached)."""
-    global _MAX_SHARED_MEM
-    if _MAX_SHARED_MEM is None:
+def _get_device_attrs() -> dict:
+    """Get device attributes for the current device (cached)."""
+    global _DEVICE_ATTRS
+    if _DEVICE_ATTRS is None:
         device = cp.cuda.Device()
-        _MAX_SHARED_MEM = device.attributes["MaxSharedMemoryPerBlock"]
-    return _MAX_SHARED_MEM
+        # compute_capability is a string like "120" for CC 12.0, or "86" for CC 8.6
+        cc_str = str(device.compute_capability)
+        cc_major = int(cc_str[:-1]) if len(cc_str) > 1 else int(cc_str)
+        _DEVICE_ATTRS = {
+            "max_shared_mem": device.attributes["MaxSharedMemoryPerBlock"],
+            "cc_major": cc_major,
+        }
+    return _DEVICE_ATTRS


 def _choose_feat_tile(
@@ -93,7 +99,7 @@ def _choose_feat_tile(

 def get_compute_group_distances_kernel(
     dtype: np.dtype, n_features: int
-) -> tuple[object, int]:
+) -> tuple[object, int, int]:
     """
     Compile GPU kernel for computing pairwise group distances.

@@ -110,19 +116,30 @@ def get_compute_group_distances_kernel(
         Compiled CUDA kernel
     shared_mem_bytes
         Required shared memory in bytes
+    block_size
+        Recommended block size (threads per block)
     """
     dtype = np.dtype(dtype)
     is_double = dtype == np.float64
     sqrt_fn = "sqrt" if is_double else "sqrtf"
     dtype_size = dtype.itemsize

+    device_attrs = _get_device_attrs()
+    max_shared = device_attrs["max_shared_mem"]
+
     # Cell tile based on dtype to avoid register pressure
     cell_tile = 16 if is_double else 32

     # Auto-select feat_tile based on n_features and available shared memory
-    feat_tile = _choose_feat_tile(
-        n_features, _get_max_shared_mem(), cell_tile, dtype_size
-    )
+    feat_tile = _choose_feat_tile(n_features, max_shared, cell_tile, dtype_size)
+
+    # Default block size for Ampere+ (3090, A100, H100, etc.)
+    block_size = 1024
+
+    # For pre-Ampere GPUs (CC < 8.0, e.g., T4) with float64, reduce block size
+    # to avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES due to register pressure
+    if is_double and device_attrs["cc_major"] < 8:
+        block_size = 256

     kernel_code = """
     (const {0}* __restrict__ embedding,
@@ -270,4 +287,4 @@ def get_compute_group_distances_kernel(

     shared_mem_bytes = cell_tile * feat_tile * dtype_size

-    return kernel, shared_mem_bytes
+    return kernel, shared_mem_bytes, block_size
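As a standalone illustration of the device-attribute logic introduced above (plain Python, no GPU required; the helper names cc_major_from_string and pick_block_size are invented for this sketch and are not part of the module):

def cc_major_from_string(cc_str: str) -> int:
    # CuPy reports compute capability without a dot, e.g. "86" for CC 8.6
    # or "120" for CC 12.0, so the major version is everything but the last digit.
    return int(cc_str[:-1]) if len(cc_str) > 1 else int(cc_str)

def pick_block_size(is_double: bool, cc_major: int) -> int:
    # float64 on pre-Ampere parts (CC < 8.0, e.g. T4) is register-hungry enough
    # that a 1024-thread block can fail with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES,
    # so drop to 256 threads per block there.
    return 256 if (is_double and cc_major < 8) else 1024

assert cc_major_from_string("75") == 7      # T4 (Turing)
assert cc_major_from_string("86") == 8      # RTX 3090 (Ampere)
assert cc_major_from_string("120") == 12    # CC 12.0
assert pick_block_size(True, 7) == 256      # float64 on Turing -> reduced block
assert pick_block_size(True, 8) == 1024     # float64 on Ampere -> full block
assert pick_block_size(False, 7) == 1024    # float32 is unaffected

Note that the shared-memory footprint is untouched by this change: shared_mem_bytes = cell_tile * feat_tile * dtype_size is at most 16 * 64 * 8 = 8192 bytes in the float64 path, well under the 48 KiB that MaxSharedMemoryPerBlock typically reports, so only the thread count per block is reduced on older hardware.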
