|
#include <cuda_runtime.h>
#include <nanobind/nanobind.h>

#include <cstdint>
#include <stdexcept>
#include <string>

#include "kernels_outer.cuh"
| 6 | + |
| 7 | +namespace nb = nanobind; |
| 8 | + |
// Host-side launcher for outer_kernel<T>.
//
// E, Pr_b and R_sum arrive as integer addresses and are reinterpreted as T* /
// const T* / const T* before being handed to the kernel unchanged (presumably
// device pointers -- confirm against the Python caller). The launch uses 1-D
// blocks of 256 threads on the default stream, with enough blocks to cover
// N = n_cats * n_pcs indices.
//
// Does nothing when N <= 0. Throws std::runtime_error when the required block
// count exceeds the 1-D grid limit or when the CUDA runtime rejects the launch.
template <typename T>
static inline void launch_outer(std::uintptr_t E, std::uintptr_t Pr_b, std::uintptr_t R_sum,
                                long long n_cats, long long n_pcs, long long switcher) {
    constexpr unsigned int kThreadsPerBlock = 256;
    constexpr long long kMaxGridX = 2147483647LL;  // documented upper bound for gridDim.x

    const long long N = n_cats * n_pcs;
    if (N <= 0) {
        // A zero-block grid is an invalid launch configuration; there is no work anyway.
        return;
    }

    const long long n_blocks = (N + kThreadsPerBlock - 1) / kThreadsPerBlock;
    if (n_blocks > kMaxGridX) {
        // Casting to unsigned would silently wrap and launch too few blocks.
        throw std::runtime_error("outer: problem size exceeds the maximum 1-D CUDA grid");
    }

    const dim3 block(kThreadsPerBlock);
    const dim3 grid(static_cast<unsigned int>(n_blocks));
    outer_kernel<T><<<grid, block>>>(reinterpret_cast<T*>(E), reinterpret_cast<const T*>(Pr_b),
                                     reinterpret_cast<const T*>(R_sum), n_cats, n_pcs, switcher);

    // Kernel launches do not return a status; launch failures must be queried explicitly.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("outer_kernel launch failed: ") +
                                 cudaGetErrorString(err));
    }
}
| 18 | + |
// Host-side launcher for harmony_correction_kernel<T>.
//
// Z, W and R arrive as integer addresses and are reinterpreted as T* / const T* /
// const T*; cats is reinterpreted as const int*. All four are handed to the kernel
// unchanged (presumably device pointers -- confirm against the Python caller).
// The launch uses 1-D blocks of 256 threads on the default stream, with enough
// blocks to cover N = n_cells * n_pcs indices.
//
// Does nothing when N <= 0. Throws std::runtime_error when the required block
// count exceeds the 1-D grid limit or when the CUDA runtime rejects the launch.
template <typename T>
static inline void launch_harmony_corr(std::uintptr_t Z, std::uintptr_t W, std::uintptr_t cats,
                                       std::uintptr_t R, long long n_cells, long long n_pcs) {
    constexpr unsigned int kThreadsPerBlock = 256;
    constexpr long long kMaxGridX = 2147483647LL;  // documented upper bound for gridDim.x

    const long long N = n_cells * n_pcs;
    if (N <= 0) {
        // A zero-block grid is an invalid launch configuration; there is no work anyway.
        return;
    }

    const long long n_blocks = (N + kThreadsPerBlock - 1) / kThreadsPerBlock;
    if (n_blocks > kMaxGridX) {
        // Casting to unsigned would silently wrap and launch too few blocks.
        throw std::runtime_error("harmony_corr: problem size exceeds the maximum 1-D CUDA grid");
    }

    const dim3 block(kThreadsPerBlock);
    const dim3 grid(static_cast<unsigned int>(n_blocks));
    harmony_correction_kernel<T><<<grid, block>>>(
        reinterpret_cast<T*>(Z), reinterpret_cast<const T*>(W), reinterpret_cast<const int*>(cats),
        reinterpret_cast<const T*>(R), n_cells, n_pcs);

    // Kernel launches do not return a status; launch failures must be queried explicitly.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("harmony_correction_kernel launch failed: ") +
                                 cudaGetErrorString(err));
    }
}
| 29 | + |
// Python extension module `_harmony_outer_cuda`.
//
// Each binding takes buffer addresses as plain integers together with an
// `itemsize` argument: 4 dispatches to the float instantiation of the launcher,
// 8 to the double instantiation, and any other value raises ValueError.
NB_MODULE(_harmony_outer_cuda, m) {
    // outer(E, Pr_b, R_sum, n_cats, n_pcs, switcher, itemsize)
    m.def("outer", [](std::uintptr_t E, std::uintptr_t Pr_b, std::uintptr_t R_sum, long long n_cats,
                      long long n_pcs, long long switcher, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_outer<float>(E, Pr_b, R_sum, n_cats, n_pcs, switcher);
                break;
            case 8:
                launch_outer<double>(E, Pr_b, R_sum, n_cats, n_pcs, switcher);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });

    // harmony_corr(Z, W, cats, R, n_cells, n_pcs, itemsize)
    m.def("harmony_corr", [](std::uintptr_t Z, std::uintptr_t W, std::uintptr_t cats,
                             std::uintptr_t R, long long n_cells, long long n_pcs, int itemsize) {
        switch (itemsize) {
            case 4:
                launch_harmony_corr<float>(Z, W, cats, R, n_cells, n_pcs);
                break;
            case 8:
                launch_harmony_corr<double>(Z, W, cats, R, n_cells, n_pcs);
                break;
            default:
                throw nb::value_error("Unsupported itemsize (expected 4 or 8)");
        }
    });
}
0 commit comments