Skip to content

Commit d45d6bf

Browse files
committed
add ligrec
1 parent d46ab83 commit d45d6bf

File tree

4 files changed

+384
-305
lines changed

4 files changed

+384
-305
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ if (RSC_BUILD_EXTENSIONS)
5555
add_nb_cuda_module(_cooc_cuda src/rapids_singlecell/_cuda/cooc/cooc.cu)
5656
add_nb_cuda_module(_aggr_cuda src/rapids_singlecell/_cuda/aggr/aggr.cu)
5757
add_nb_cuda_module(_spca_cuda src/rapids_singlecell/_cuda/spca/spca.cu)
58+
add_nb_cuda_module(_ligrec_cuda src/rapids_singlecell/_cuda/ligrec/ligrec.cu)
5859
# Harmony CUDA modules
5960
add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu)
6061
add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#pragma once
2+
3+
#include <cuda_runtime.h>
4+
5+
/**
 * Accumulates, for every (column, cluster) pair, the sum and the number of
 * strictly positive entries of a dense matrix.
 *
 * Launch layout: 2D. The x dimension indexes rows of `data`, the y dimension
 * indexes columns; threads that fall outside the matrix return immediately.
 *
 * @param data       matrix read as data[row * num_cols + col] (row-major)
 * @param clusters   one cluster id per row, used as an offset into the outputs
 * @param sum_gt0    output, indexed [col * n_cls + cluster]; updated with atomicAdd
 * @param count_gt0  output, same indexing as sum_gt0; incremented with atomicAdd
 * @param num_rows   number of rows of `data`
 * @param num_cols   number of columns of `data`
 * @param n_cls      stride of the cluster axis in the outputs
 *
 * The outputs are only added to, never cleared here.
 * NOTE(review): presumably the caller zero-initialises them and guarantees
 * 0 <= clusters[row] < n_cls — confirm against the Python caller.
 * atomicAdd on double requires SM60+.
 */
template <typename T>
__global__ void sum_and_count_dense_kernel(const T* __restrict__ data,
                                           const int* __restrict__ clusters,
                                           T* __restrict__ sum_gt0, int* __restrict__ count_gt0,
                                           int num_rows, int num_cols, int n_cls) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= num_rows || j >= num_cols) return;

    // The flat offset is computed in 64 bits: num_rows * num_cols can exceed
    // INT_MAX for large matrices, and a 32-bit product would wrap around.
    const size_t offset = static_cast<size_t>(i) * static_cast<size_t>(num_cols) +
                          static_cast<size_t>(j);
    const int cluster = clusters[i];
    const T value = data[offset];
    if (value > (T)0) {
        const int out = j * n_cls + cluster;
        atomicAdd(&sum_gt0[out], value);
        atomicAdd(&count_gt0[out], 1);
    }
}
20+
21+
/**
 * Compressed-row counterpart of the dense sum/count kernel: one thread walks
 * one row and, for each stored entry that is strictly positive, atomically
 * adds its value to `sum_gt0` and 1 to `count_gt0`.
 *
 * Launch layout: 1D over rows; threads with an index >= nrows do nothing.
 *
 * @param indptr     row pointers; row r owns entries [indptr[r], indptr[r + 1])
 * @param index      column index of each stored entry
 * @param data       value of each stored entry
 * @param clusters   one cluster id per row
 * @param sum_gt0    output, indexed [column * n_cls + cluster]
 * @param count_gt0  output, same indexing as sum_gt0
 * @param nrows      number of rows
 * @param n_cls      stride of the cluster axis in the outputs
 *
 * NOTE(review): the outputs are only accumulated into — presumably zeroed by
 * the caller; confirm.
 */
template <typename T>
__global__ void sum_and_count_sparse_kernel(const int* __restrict__ indptr,
                                            const int* __restrict__ index,
                                            const T* __restrict__ data,
                                            const int* __restrict__ clusters,
                                            T* __restrict__ sum_gt0, int* __restrict__ count_gt0,
                                            int nrows, int n_cls) {
    const int row = blockDim.x * blockIdx.x + threadIdx.x;
    if (row < nrows) {
        const int row_begin = indptr[row];
        const int row_end = indptr[row + 1];
        const int label = clusters[row];
        for (int nz = row_begin; nz < row_end; ++nz) {
            const T v = data[nz];
            const int col = index[nz];
            // Entries that are not strictly positive contribute nothing.
            if (!(v > (T)0)) continue;
            const int slot = col * n_cls + label;
            atomicAdd(sum_gt0 + slot, v);
            atomicAdd(count_gt0 + slot, 1);
        }
    }
}
42+
43+
/**
 * Adds every entry of a dense matrix into a per-(column, cluster) accumulator.
 * Despite the name, no division happens here: the kernel only sums.
 *
 * Launch layout: 2D. The x dimension indexes rows of `data`, the y dimension
 * indexes columns; out-of-range threads return immediately.
 *
 * @param data       matrix read as data[row * num_cols + col] (row-major)
 * @param clusters   one cluster id per row
 * @param g_cluster  output, indexed [col * n_cls + cluster]; updated with atomicAdd
 * @param num_rows   number of rows of `data`
 * @param num_cols   number of columns of `data`
 * @param n_cls      stride of the cluster axis in the output
 *
 * NOTE(review): `g_cluster` is only accumulated into — presumably zeroed by
 * the caller; confirm. atomicAdd on double requires SM60+.
 */
template <typename T>
__global__ void mean_dense_kernel(const T* __restrict__ data, const int* __restrict__ clusters,
                                  T* __restrict__ g_cluster, int num_rows, int num_cols,
                                  int n_cls) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= num_rows || j >= num_cols) return;

    // 64-bit flat offset: num_rows * num_cols may not fit in a 32-bit int.
    const size_t offset = static_cast<size_t>(i) * static_cast<size_t>(num_cols) +
                          static_cast<size_t>(j);
    atomicAdd(&g_cluster[j * n_cls + clusters[i]], data[offset]);
}
52+
53+
/**
 * Compressed-row accumulation: one thread per row adds each strictly positive
 * stored value into the per-(column, cluster) accumulator `sum_gt0`.
 * No division is performed in this kernel.
 *
 * Launch layout: 1D over rows; threads with an index >= nrows do nothing.
 *
 * @param indptr    row pointers; row r owns entries [indptr[r], indptr[r + 1])
 * @param index     column index of each stored entry
 * @param data      value of each stored entry
 * @param clusters  one cluster id per row
 * @param sum_gt0   output, indexed [column * n_cls + cluster]; updated with atomicAdd
 * @param nrows     number of rows
 * @param n_cls     stride of the cluster axis in the output
 *
 * NOTE(review): unlike the dense variant, stored values <= 0 are skipped here;
 * the two agree only for non-negative data — confirm that is the intended input.
 */
template <typename T>
__global__ void mean_sparse_kernel(const int* __restrict__ indptr, const int* __restrict__ index,
                                   const T* __restrict__ data, const int* __restrict__ clusters,
                                   T* __restrict__ sum_gt0, int nrows, int n_cls) {
    const int row = blockDim.x * blockIdx.x + threadIdx.x;
    if (row < nrows) {
        const int row_begin = indptr[row];
        const int row_end = indptr[row + 1];
        const int label = clusters[row];
        for (int nz = row_begin; nz < row_end; ++nz) {
            const T v = data[nz];
            const int col = index[nz];
            if (!(v > (T)0)) continue;
            atomicAdd(sum_gt0 + (col * n_cls + label), v);
        }
    }
}
70+
71+
/**
 * Divides each entry of `g_cluster` in place by the per-cluster value in
 * `total_counts`. (The name says "diff", but the operation is a division.)
 *
 * Launch layout: 2D. The x dimension indexes genes, the y dimension indexes
 * clusters; out-of-range threads do nothing.
 *
 * @param g_cluster     in/out, indexed [gene * num_clusters + cluster]
 * @param total_counts  divisor per cluster, indexed [cluster]
 * @param num_genes     extent of the gene axis
 * @param num_clusters  extent of the cluster axis
 *
 * No check is made for a zero divisor.
 */
template <typename T>
__global__ void elementwise_diff_kernel(T* __restrict__ g_cluster,
                                        const T* __restrict__ total_counts, int num_genes,
                                        int num_clusters) {
    const int gene = blockIdx.x * blockDim.x + threadIdx.x;
    const int cls = blockIdx.y * blockDim.y + threadIdx.y;
    if (gene < num_genes && cls < num_clusters) {
        T& cell = g_cluster[gene * num_clusters + cls];
        cell = cell / total_counts[cls];
    }
}
80+
81+
/**
 * Updates one cell of `res` per (interaction, cluster-pair) thread.
 *
 * For interaction i = (rec, lig) and cluster pair j = (c1, c2):
 *   - if res[i, j] is already NaN, it is left untouched;
 *   - if either mean[rec, c1] or mean[lig, c2] is not strictly positive, or
 *     either corresponding `mask` entry is false, res[i, j] is set to NaN;
 *   - otherwise res[i, j] is incremented by 1 when
 *     g[rec, c1] + g[lig, c2] > mean[rec, c1] + mean[lig, c2], and by 0 if not.
 *
 * Launch layout: 2D. The x dimension indexes interactions (bounded by
 * `n_iter`), the y dimension indexes cluster pairs (bounded by `n_inter_clust`).
 *
 * @param interactions          pairs stored as [2 * i] = rec, [2 * i + 1] = lig
 * @param interaction_clusters  pairs stored as [2 * j] = c1, [2 * j + 1] = c2
 * @param mean                  indexed [gene * n_cls + cluster]
 * @param res                   in/out, indexed [i * n_inter_clust + j]
 * @param mask                  indexed like `mean`
 * @param g                     indexed like `mean`
 * @param n_iter                bound on the x index
 * @param n_inter_clust         bound on the y index, and row stride of `res`
 * @param n_cls                 stride of the cluster axis in mean / mask / g
 *
 * NOTE(review): `g` looks like a per-permutation counterpart of `mean`, with
 * `res` accumulating across repeated launches — confirm with the caller.
 */
template <typename T>
__global__ void interaction_kernel(const int* __restrict__ interactions,
                                   const int* __restrict__ interaction_clusters,
                                   const T* __restrict__ mean, T* __restrict__ res,
                                   const bool* __restrict__ mask, const T* __restrict__ g,
                                   int n_iter, int n_inter_clust, int n_cls) {
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    const int combo = blockIdx.y * blockDim.y + threadIdx.y;
    if (pair >= n_iter || combo >= n_inter_clust) return;

    // Flat offsets of the receptor-side and ligand-side entries.
    const int rec_off = interactions[pair * 2] * n_cls + interaction_clusters[combo * 2];
    const int lig_off = interactions[pair * 2 + 1] * n_cls + interaction_clusters[combo * 2 + 1];
    const T m_rec = mean[rec_off];
    const T m_lig = mean[lig_off];

    const int out = pair * n_inter_clust + combo;
    // A cell already marked NaN stays NaN.
    if (isnan(res[out])) return;

    // The mask is consulted only once both means are known to be positive.
    const bool usable = (m_rec > (T)0 && m_lig > (T)0) && (mask[rec_off] && mask[lig_off]);
    if (!usable) {
        res[out] = nan("");
        return;
    }

    const T g_sum = g[rec_off] + g[lig_off];
    res[out] += (g_sum > (m_rec + m_lig));
}
109+
110+
/**
 * Writes, for each (interaction, cluster-pair) cell, the average of the two
 * looked-up means — but only when both are strictly positive. Cells that do
 * not meet that condition are left as they were.
 *
 * Launch layout: 2D. The x dimension indexes interactions, the y dimension
 * indexes cluster pairs; out-of-range threads return immediately.
 *
 * @param interactions          pairs stored as [2 * i] = rec, [2 * i + 1] = lig
 * @param interaction_clusters  pairs stored as [2 * j] = c1, [2 * j + 1] = c2
 * @param mean                  indexed [gene * n_cls + cluster]
 * @param res_mean              output, indexed [i * n_inter_clust + j]
 * @param n_inter               number of interactions
 * @param n_inter_clust         number of cluster pairs, and row stride of `res_mean`
 * @param n_cls                 stride of the cluster axis in `mean`
 */
template <typename T>
__global__ void res_mean_kernel(const int* __restrict__ interactions,
                                const int* __restrict__ interaction_clusters,
                                const T* __restrict__ mean, T* __restrict__ res_mean, int n_inter,
                                int n_inter_clust, int n_cls) {
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    const int combo = blockIdx.y * blockDim.y + threadIdx.y;
    if (pair >= n_inter || combo >= n_inter_clust) return;

    const int rec_off = interactions[pair * 2] * n_cls + interaction_clusters[combo * 2];
    const int lig_off = interactions[pair * 2 + 1] * n_cls + interaction_clusters[combo * 2 + 1];
    const T m_rec = mean[rec_off];
    const T m_lig = mean[lig_off];

    const bool both_positive = m_rec > (T)0 && m_lig > (T)0;
    if (both_positive) {
        res_mean[pair * n_inter_clust + combo] = (m_rec + m_lig) / (T)2;
    }
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#include <cuda_runtime.h>
2+
#include <nanobind/nanobind.h>
3+
#include <cstdint>
4+
5+
#include "kernels_ligrec.cuh"
6+
7+
namespace nb = nanobind;
8+
9+
/**
 * Launches sum_and_count_dense_kernel<T> over a rows x cols matrix.
 *
 * All buffers arrive as integer addresses and are reinterpreted as typed
 * pointers; nothing here validates them. The launch uses 32x32 thread blocks
 * with rows on the x grid axis and cols on the y grid axis, passes no stream
 * argument, and does not check the launch status.
 *
 * @param data      address of the T matrix (rows x cols)
 * @param clusters  address of the int per-row cluster ids
 * @param sum       address of the T output accumulator
 * @param count     address of the int output counter
 * @param rows      number of matrix rows
 * @param cols      number of matrix columns
 * @param ncls      cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_sum_count_dense(std::uintptr_t data, std::uintptr_t clusters,
                                          std::uintptr_t sum, std::uintptr_t count, int rows,
                                          int cols, int ncls) {
    // An empty matrix would produce a grid with a zero dimension, which the
    // runtime rejects as an invalid configuration; there is nothing to do anyway.
    if (rows <= 0 || cols <= 0) return;

    const dim3 block(32, 32);
    const dim3 grid((rows + block.x - 1) / block.x, (cols + block.y - 1) / block.y);
    sum_and_count_dense_kernel<T>
        <<<grid, block>>>(reinterpret_cast<const T*>(data), reinterpret_cast<const int*>(clusters),
                          reinterpret_cast<T*>(sum), reinterpret_cast<int*>(count), rows, cols, ncls);
}
19+
20+
/**
 * Launches sum_and_count_sparse_kernel<T> with one thread per row.
 *
 * All buffers arrive as integer addresses and are reinterpreted as typed
 * pointers; nothing here validates them. The launch uses 1D blocks of 32
 * threads, passes no stream argument, and does not check the launch status.
 *
 * @param indptr    address of the int row-pointer array
 * @param index     address of the int column-index array
 * @param data      address of the T value array
 * @param clusters  address of the int per-row cluster ids
 * @param sum       address of the T output accumulator
 * @param count     address of the int output counter
 * @param rows      number of rows
 * @param ncls      cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_sum_count_sparse(std::uintptr_t indptr, std::uintptr_t index,
                                           std::uintptr_t data, std::uintptr_t clusters,
                                           std::uintptr_t sum, std::uintptr_t count, int rows,
                                           int ncls) {
    // Zero rows would mean a zero-sized grid, which is an invalid launch
    // configuration; skip the launch since there is no work.
    if (rows <= 0) return;

    const dim3 block(32);
    const dim3 grid((rows + block.x - 1) / block.x);
    sum_and_count_sparse_kernel<T>
        <<<grid, block>>>(reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
                          reinterpret_cast<const T*>(data), reinterpret_cast<const int*>(clusters),
                          reinterpret_cast<T*>(sum), reinterpret_cast<int*>(count), rows, ncls);
}
32+
33+
/**
 * Launches mean_dense_kernel<T> over a rows x cols matrix.
 *
 * Buffers arrive as integer addresses and are reinterpreted as typed pointers
 * without validation. The launch uses 32x32 thread blocks with rows on the x
 * grid axis and cols on the y grid axis, passes no stream argument, and does
 * not check the launch status.
 *
 * @param data      address of the T matrix (rows x cols)
 * @param clusters  address of the int per-row cluster ids
 * @param g         address of the T output accumulator
 * @param rows      number of matrix rows
 * @param cols      number of matrix columns
 * @param ncls      cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_mean_dense(std::uintptr_t data, std::uintptr_t clusters, std::uintptr_t g,
                                     int rows, int cols, int ncls) {
    // A zero grid dimension is an invalid launch configuration; an empty
    // matrix has no work to do, so return before launching.
    if (rows <= 0 || cols <= 0) return;

    const dim3 block(32, 32);
    const dim3 grid((rows + block.x - 1) / block.x, (cols + block.y - 1) / block.y);
    mean_dense_kernel<T><<<grid, block>>>(reinterpret_cast<const T*>(data),
                                          reinterpret_cast<const int*>(clusters),
                                          reinterpret_cast<T*>(g), rows, cols, ncls);
}
42+
43+
/**
 * Launches mean_sparse_kernel<T> with one thread per row.
 *
 * Buffers arrive as integer addresses and are reinterpreted as typed pointers
 * without validation. The launch uses 1D blocks of 32 threads, passes no
 * stream argument, and does not check the launch status.
 *
 * @param indptr    address of the int row-pointer array
 * @param index     address of the int column-index array
 * @param data      address of the T value array
 * @param clusters  address of the int per-row cluster ids
 * @param g         address of the T output accumulator
 * @param rows      number of rows
 * @param ncls      cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_mean_sparse(std::uintptr_t indptr, std::uintptr_t index,
                                      std::uintptr_t data, std::uintptr_t clusters,
                                      std::uintptr_t g, int rows, int ncls) {
    // Zero rows would give a zero-sized grid (invalid launch configuration).
    if (rows <= 0) return;

    const dim3 block(32);
    const dim3 grid((rows + block.x - 1) / block.x);
    mean_sparse_kernel<T>
        <<<grid, block>>>(reinterpret_cast<const int*>(indptr), reinterpret_cast<const int*>(index),
                          reinterpret_cast<const T*>(data), reinterpret_cast<const int*>(clusters),
                          reinterpret_cast<T*>(g), rows, ncls);
}
54+
55+
/**
 * Launches elementwise_diff_kernel<T> over an n_genes x n_clusters array.
 *
 * Buffers arrive as integer addresses and are reinterpreted as typed pointers
 * without validation. The launch uses 32x32 thread blocks with genes on the x
 * grid axis and clusters on the y grid axis, passes no stream argument, and
 * does not check the launch status.
 *
 * @param g             address of the T array modified in place
 * @param total_counts  address of the T per-cluster divisors
 * @param n_genes       extent of the gene axis
 * @param n_clusters    extent of the cluster axis
 */
template <typename T>
static inline void launch_elementwise_diff(std::uintptr_t g, std::uintptr_t total_counts,
                                           int n_genes, int n_clusters) {
    // Either extent being zero would yield a zero grid dimension, which is an
    // invalid launch configuration; there is nothing to divide in that case.
    if (n_genes <= 0 || n_clusters <= 0) return;

    const dim3 block(32, 32);
    const dim3 grid((n_genes + block.x - 1) / block.x, (n_clusters + block.y - 1) / block.y);
    elementwise_diff_kernel<T><<<grid, block>>>(
        reinterpret_cast<T*>(g), reinterpret_cast<const T*>(total_counts), n_genes, n_clusters);
}
63+
64+
/**
 * Launches interaction_kernel<T> over an n_iter x n_inter_clust grid of cells.
 *
 * Buffers arrive as integer addresses and are reinterpreted as typed pointers
 * without validation. The launch uses 32x32 thread blocks with `n_iter` on the
 * x grid axis and `n_inter_clust` on the y grid axis, passes no stream
 * argument, and does not check the launch status.
 *
 * @param interactions          address of the int (rec, lig) pairs
 * @param interaction_clusters  address of the int (c1, c2) pairs
 * @param mean                  address of the T mean array
 * @param res                   address of the T result array, updated in place
 * @param mask                  address of the bool mask array
 * @param g                     address of the T array compared against `mean`
 * @param n_iter                bound on the kernel's x index
 * @param n_inter_clust         bound on the kernel's y index
 * @param ncls                  cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_interaction(std::uintptr_t interactions,
                                      std::uintptr_t interaction_clusters, std::uintptr_t mean,
                                      std::uintptr_t res, std::uintptr_t mask, std::uintptr_t g,
                                      int n_iter, int n_inter_clust, int ncls) {
    // A zero grid dimension is an invalid launch configuration; with no cells
    // to update, skip the launch.
    if (n_iter <= 0 || n_inter_clust <= 0) return;

    const dim3 block(32, 32);
    const dim3 grid((n_iter + block.x - 1) / block.x, (n_inter_clust + block.y - 1) / block.y);
    interaction_kernel<T><<<grid, block>>>(
        reinterpret_cast<const int*>(interactions),
        reinterpret_cast<const int*>(interaction_clusters), reinterpret_cast<const T*>(mean),
        reinterpret_cast<T*>(res), reinterpret_cast<const bool*>(mask), reinterpret_cast<const T*>(g),
        n_iter, n_inter_clust, ncls);
}
77+
78+
/**
 * Launches res_mean_kernel<T> over an n_inter x n_inter_clust grid of cells.
 *
 * Buffers arrive as integer addresses and are reinterpreted as typed pointers
 * without validation. The launch uses 32x32 thread blocks with interactions on
 * the x grid axis and cluster pairs on the y grid axis, passes no stream
 * argument, and does not check the launch status.
 *
 * @param interactions          address of the int (rec, lig) pairs
 * @param interaction_clusters  address of the int (c1, c2) pairs
 * @param mean                  address of the T mean array
 * @param res_mean              address of the T output array
 * @param n_inter               number of interactions
 * @param n_inter_clust         number of cluster pairs
 * @param ncls                  cluster stride forwarded to the kernel
 */
template <typename T>
static inline void launch_res_mean(std::uintptr_t interactions, std::uintptr_t interaction_clusters,
                                   std::uintptr_t mean, std::uintptr_t res_mean, int n_inter,
                                   int n_inter_clust, int ncls) {
    // A zero grid dimension is an invalid launch configuration; with no cells
    // to write, skip the launch.
    if (n_inter <= 0 || n_inter_clust <= 0) return;

    const dim3 block(32, 32);
    const dim3 grid((n_inter + block.x - 1) / block.x, (n_inter_clust + block.y - 1) / block.y);
    res_mean_kernel<T><<<grid, block>>>(reinterpret_cast<const int*>(interactions),
                                        reinterpret_cast<const int*>(interaction_clusters),
                                        reinterpret_cast<const T*>(mean),
                                        reinterpret_cast<T*>(res_mean), n_inter, n_inter_clust, ncls);
}
89+
90+
// Python bindings for the ligand-receptor kernels.
//
// Every function takes its buffers as integer addresses plus a trailing
// `itemsize` that selects the floating-point instantiation: 4 dispatches to
// float, 8 to double, and any other value throws nb::value_error.
NB_MODULE(_ligrec_cuda, m) {
    m.def("sum_count_dense", [](std::uintptr_t data, std::uintptr_t clusters, std::uintptr_t sum,
                                std::uintptr_t count, int rows, int cols, int ncls, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_sum_count_dense<float>(data, clusters, sum, count, rows, cols, ncls);
                break;
            case 8:
                launch_sum_count_dense<double>(data, clusters, sum, count, rows, cols, ncls);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    m.def("sum_count_sparse", [](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
                                 std::uintptr_t clusters, std::uintptr_t sum, std::uintptr_t count,
                                 int rows, int ncls, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_sum_count_sparse<float>(indptr, index, data, clusters, sum, count, rows,
                                               ncls);
                break;
            case 8:
                launch_sum_count_sparse<double>(indptr, index, data, clusters, sum, count, rows,
                                                ncls);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    m.def("mean_dense", [](std::uintptr_t data, std::uintptr_t clusters, std::uintptr_t g, int rows,
                           int cols, int ncls, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_mean_dense<float>(data, clusters, g, rows, cols, ncls);
                break;
            case 8:
                launch_mean_dense<double>(data, clusters, g, rows, cols, ncls);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    m.def("mean_sparse",
          [](std::uintptr_t indptr, std::uintptr_t index, std::uintptr_t data,
             std::uintptr_t clusters, std::uintptr_t g, int rows, int ncls, int itemsize) {
              switch (itemsize) {
                  case 4:
                      launch_mean_sparse<float>(indptr, index, data, clusters, g, rows, ncls);
                      break;
                  case 8:
                      launch_mean_sparse<double>(indptr, index, data, clusters, g, rows, ncls);
                      break;
                  default:
                      throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
              }
          });

    m.def("elementwise_diff", [](std::uintptr_t g, std::uintptr_t total_counts, int n_genes,
                                 int n_clusters, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_elementwise_diff<float>(g, total_counts, n_genes, n_clusters);
                break;
            case 8:
                launch_elementwise_diff<double>(g, total_counts, n_genes, n_clusters);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    m.def("interaction", [](std::uintptr_t interactions, std::uintptr_t interaction_clusters,
                            std::uintptr_t mean, std::uintptr_t res, std::uintptr_t mask,
                            std::uintptr_t g, int n_iter, int n_inter_clust, int ncls, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_interaction<float>(interactions, interaction_clusters, mean, res, mask, g,
                                          n_iter, n_inter_clust, ncls);
                break;
            case 8:
                launch_interaction<double>(interactions, interaction_clusters, mean, res, mask, g,
                                           n_iter, n_inter_clust, ncls);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    m.def("res_mean",
          [](std::uintptr_t interactions, std::uintptr_t interaction_clusters, std::uintptr_t mean,
             std::uintptr_t res_mean, int n_inter, int n_inter_clust, int ncls, int itemsize) {
              switch (itemsize) {
                  case 4:
                      launch_res_mean<float>(interactions, interaction_clusters, mean, res_mean,
                                             n_inter, n_inter_clust, ncls);
                      break;
                  case 8:
                      launch_res_mean<double>(interactions, interaction_clusters, mean, res_mean,
                                              n_inter, n_inter_clust, ncls);
                      break;
                  default:
                      throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
              }
          });
}

0 commit comments

Comments
 (0)