 # Larger tiles = fewer iterations but more shared memory
 TILE_SIZES = [32, 50, 64]

-# Cache for device shared memory limit (lazy initialization)
-_MAX_SHARED_MEM: int | None = None
+# Cache for device attributes (lazy initialization)
+_DEVICE_ATTRS: dict | None = None


-def _get_max_shared_mem() -> int:
-    """Get maximum shared memory per block for the current device (cached)."""
-    global _MAX_SHARED_MEM
-    if _MAX_SHARED_MEM is None:
+def _get_device_attrs() -> dict:
+    """Get device attributes for the current device (cached)."""
+    global _DEVICE_ATTRS
+    if _DEVICE_ATTRS is None:
         device = cp.cuda.Device()
-        _MAX_SHARED_MEM = device.attributes["MaxSharedMemoryPerBlock"]
-    return _MAX_SHARED_MEM
+        # compute_capability is a string like "120" for CC 12.0, or "86" for CC 8.6
+        cc_str = str(device.compute_capability)
+        cc_major = int(cc_str[:-1]) if len(cc_str) > 1 else int(cc_str)
+        _DEVICE_ATTRS = {
+            "max_shared_mem": device.attributes["MaxSharedMemoryPerBlock"],
+            "cc_major": cc_major,
+        }
+    return _DEVICE_ATTRS


 def _choose_feat_tile(
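Note on the compute-capability parsing above: CuPy's Device.compute_capability is a dot-free string, and dropping its last character as the minor digit keeps multi-digit majors intact, assuming the minor version is always a single digit. A minimal illustration with example strings (not read from a device):

# Example values only; mirrors the cc_major parsing in _get_device_attrs.
for cc_str in ["75", "86", "90", "120"]:
    cc_major = int(cc_str[:-1]) if len(cc_str) > 1 else int(cc_str)
    print(cc_str, "->", cc_major)  # 75 -> 7, 86 -> 8, 90 -> 9, 120 -> 12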
@@ -93,7 +99,7 @@ def _choose_feat_tile(


 def get_compute_group_distances_kernel(
     dtype: np.dtype, n_features: int
-) -> tuple[object, int]:
+) -> tuple[object, int, int]:
     """
     Compile GPU kernel for computing pairwise group distances.

@@ -110,19 +116,30 @@ def get_compute_group_distances_kernel(
         Compiled CUDA kernel
     shared_mem_bytes
         Required shared memory in bytes
+    block_size
+        Recommended block size (threads per block)
     """
     dtype = np.dtype(dtype)
     is_double = dtype == np.float64
     sqrt_fn = "sqrt" if is_double else "sqrtf"
     dtype_size = dtype.itemsize

+    device_attrs = _get_device_attrs()
+    max_shared = device_attrs["max_shared_mem"]
+
     # Cell tile based on dtype to avoid register pressure
     cell_tile = 16 if is_double else 32

     # Auto-select feat_tile based on n_features and available shared memory
-    feat_tile = _choose_feat_tile(
-        n_features, _get_max_shared_mem(), cell_tile, dtype_size
-    )
+    feat_tile = _choose_feat_tile(n_features, max_shared, cell_tile, dtype_size)
+
+    # Default block size for Ampere+ (3090, A100, H100, etc.)
+    block_size = 1024
+
+    # For pre-Ampere GPUs (CC < 8.0, e.g., T4) with float64, reduce block size
+    # to avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES due to register pressure
+    if is_double and device_attrs["cc_major"] < 8:
+        block_size = 256

     kernel_code = """
 (const {0}* __restrict__ embedding,
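For intuition on the tile sizing in this hunk (the actual selection logic lives in _choose_feat_tile, which is not shown here), the shared-memory footprint follows the formula used at the end of this function, so even the largest candidate tile stays well under a typical 48 KiB per-block limit. A rough check, assuming the TILE_SIZES values above are the candidate feat_tile widths:

# Illustrative only; _choose_feat_tile owns the real selection logic.
max_shared = 48 * 1024  # common MaxSharedMemoryPerBlock default
for dtype_size, cell_tile in [(4, 32), (8, 16)]:  # float32, float64
    for feat_tile in [32, 50, 64]:  # TILE_SIZES candidates (assumed)
        shared_mem_bytes = cell_tile * feat_tile * dtype_size
        assert shared_mem_bytes <= max_shared  # peaks at 8192 bytes here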
@@ -270,4 +287,4 @@ def get_compute_group_distances_kernel

     shared_mem_bytes = cell_tile * feat_tile * dtype_size

-    return kernel, shared_mem_bytes
+    return kernel, shared_mem_bytes, block_size
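Caller-side sketch of the new three-value return (a hypothetical call site, not part of this diff): assuming the returned kernel is a CuPy RawKernel, the extra block_size feeds straight into the launch configuration alongside shared_mem_bytes. The grid size and kernel arguments below are placeholders, since the kernel's parameter list and launch geometry are defined elsewhere in this module.

import numpy as np

kernel, shared_mem_bytes, block_size = get_compute_group_distances_kernel(
    np.float64, n_features=50
)
grid_size = 128   # placeholder; the real grid depends on the problem size
kernel_args = ()  # placeholder; must match the parameter list in kernel_code
kernel((grid_size,), (block_size,), kernel_args, shared_mem=shared_mem_bytes)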