diff --git a/include/xsf/config.h b/include/xsf/config.h index ec7f0bb..645bde5 100644 --- a/include/xsf/config.h +++ b/include/xsf/config.h @@ -53,6 +53,17 @@ #define M_SQRT1_2 0.707106781186547524401 #endif +namespace xsf { + +template +constexpr bool is_supported_int_v = + std::is_same_v || std::is_same_v; + +template +using enable_if_supported_int_t = std::enable_if_t, int>; + +} + #ifdef __CUDACC__ #define XSF_HOST_DEVICE __host__ __device__ diff --git a/include/xsf/lambertw.h b/include/xsf/lambertw.h index eeeca8a..39ca1ba 100644 --- a/include/xsf/lambertw.h +++ b/include/xsf/lambertw.h @@ -54,7 +54,7 @@ namespace detail { return z * cevalpoly(num, 2, z) / cevalpoly(denom, 2, z); } - XSF_HOST_DEVICE inline std::complex lambertw_asy(std::complex z, long k) { + XSF_HOST_DEVICE inline std::complex lambertw_asy(std::complex z, double k) { /* Compute the W function using the first two terms of the * asymptotic series. See 4.20 in [1]. */ @@ -64,7 +64,8 @@ namespace detail { } // namespace detail -XSF_HOST_DEVICE inline std::complex lambertw(std::complex z, long k, double tol) { +template = 0> +XSF_HOST_DEVICE inline std::complex lambertw(std::complex z, K k, double tol) { double absz; std::complex w; std::complex ew, wew, wewz, wn; @@ -142,7 +143,8 @@ XSF_HOST_DEVICE inline std::complex lambertw(std::complex z, lon return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; } -XSF_HOST_DEVICE inline std::complex lambertw(std::complex z, long k, float tol) { +template = 0> +XSF_HOST_DEVICE inline std::complex lambertw(std::complex z, K k, float tol) { return static_cast>(lambertw(static_cast>(z), k, static_cast(tol)) ); } diff --git a/include/xsf/sph_bessel.h b/include/xsf/sph_bessel.h index ea1105c..40cc1fd 100644 --- a/include/xsf/sph_bessel.h +++ b/include/xsf/sph_bessel.h @@ -33,8 +33,8 @@ Translated to C++ by SciPy developers in 2024. namespace xsf { -template -T sph_bessel_j(long n, T x) { +template = 0> +T sph_bessel_j(N n, T x) { if (std::isnan(x)) { return x; } @@ -71,7 +71,7 @@ T sph_bessel_j(long n, T x) { } T sn; - for (int i = 0; i < n - 1; ++i) { + for (N i = 0; i < n - 1; ++i) { sn = (2 * i + 3) * s1 / x - s0; s0 = s1; s1 = sn; @@ -84,8 +84,8 @@ T sph_bessel_j(long n, T x) { return sn; } -template -std::complex sph_bessel_j(long n, std::complex z) { +template = 0> +std::complex sph_bessel_j(N n, std::complex z) { if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) { return z; } @@ -120,8 +120,13 @@ std::complex sph_bessel_j(long n, std::complex z) { return out; } -template -T sph_bessel_j_jac(long n, T z) { +template = 0> +T sph_bessel_j_jac(N n, T z) { + if (n < 0) { + set_error("spherical_j_jac", SF_ERROR_DOMAIN, nullptr); + return std::numeric_limits::quiet_NaN(); + } + if (n == 0) { return -sph_bessel_j(1, z); } @@ -136,13 +141,13 @@ T sph_bessel_j_jac(long n, T z) { } // DLMF 10.51.2 - return sph_bessel_j(n - 1, z) - static_cast(n + 1) * sph_bessel_j(n, z) / z; + return sph_bessel_j(n - 1, z) - (static_cast(n) + 1.0) * sph_bessel_j(n, z) / z; } -template -T sph_bessel_y(long n, T x) { +template = 0> +T sph_bessel_y(N n, T x) { T s0, s1, sn; - int idx; + N idx; if (std::isnan(x)) { return x; @@ -154,7 +159,7 @@ T sph_bessel_y(long n, T x) { } if (x < 0) { - return std::pow(-1, n + 1) * sph_bessel_y(n, -x); + return n % 2 == 0 ? -sph_bessel_y(n, -x) : sph_bessel_y(n, -x); } if (x == std::numeric_limits::infinity() || x == -std::numeric_limits::infinity()) { @@ -188,10 +193,11 @@ T sph_bessel_y(long n, T x) { return sn; } -inline float sph_bessel_y(long n, float x) { return sph_bessel_y(n, static_cast(x)); } +template = 0> +inline float sph_bessel_y(N n, float x) { return sph_bessel_y(n, static_cast(x)); } -template -std::complex sph_bessel_y(long n, std::complex z) { +template = 0> +std::complex sph_bessel_y(N n, std::complex z) { if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) { return z; } @@ -218,17 +224,22 @@ std::complex sph_bessel_y(long n, std::complex z) { return std::sqrt(static_cast(M_PI_2) / z) * cyl_bessel_y(n + 1 / static_cast(2), z); } -template -T sph_bessel_y_jac(long n, T x) { +template = 0> +T sph_bessel_y_jac(N n, T x) { + if (n < 0) { + set_error("spherical_y_jac", SF_ERROR_DOMAIN, nullptr); + return std::numeric_limits::quiet_NaN(); + } + if (n == 0) { return -sph_bessel_y(1, x); } - return sph_bessel_y(n - 1, x) - static_cast(n + 1) * sph_bessel_y(n, x) / x; + return sph_bessel_y(n - 1, x) - (static_cast(n) + 1.0) * sph_bessel_y(n, x) / x; } -template -T sph_bessel_i(long n, T x) { +template = 0> +T sph_bessel_i(N n, T x) { if (std::isnan(x)) { return x; } @@ -249,7 +260,7 @@ T sph_bessel_i(long n, T x) { if (std::isinf(x)) { // https://dlmf.nist.gov/10.49.E8 if (x == -std::numeric_limits::infinity()) { - return std::pow(-1, n) * std::numeric_limits::infinity(); + return n % 2 == 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); } return std::numeric_limits::infinity(); @@ -258,8 +269,8 @@ T sph_bessel_i(long n, T x) { return sqrt(static_cast(M_PI_2) / x) * cyl_bessel_i(n + 1 / static_cast(2), x); } -template -std::complex sph_bessel_i(long n, std::complex z) { +template = 0> +std::complex sph_bessel_i(N n, std::complex z) { if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) { return z; } @@ -282,7 +293,7 @@ std::complex sph_bessel_i(long n, std::complex z) { // https://dlmf.nist.gov/10.52.E5 if (std::imag(z) == 0) { if (std::real(z) == -std::numeric_limits::infinity()) { - return std::pow(-1, n) * std::numeric_limits::infinity(); + return n % 2 == 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); } return std::numeric_limits::infinity(); @@ -294,8 +305,13 @@ std::complex sph_bessel_i(long n, std::complex z) { return std::sqrt(static_cast(M_PI_2) / z) * cyl_bessel_i(n + 1 / static_cast(2), z); } -template -T sph_bessel_i_jac(long n, T z) { +template = 0> +T sph_bessel_i_jac(N n, T z) { + if (n < 0) { + set_error("spherical_i_jac", SF_ERROR_DOMAIN, nullptr); + return std::numeric_limits::quiet_NaN(); + } + if (n == 0) { return sph_bessel_i(1, z); } @@ -308,11 +324,11 @@ T sph_bessel_i_jac(long n, T z) { } } - return sph_bessel_i(n - 1, z) - static_cast(n + 1) * sph_bessel_i(n, z) / z; + return sph_bessel_i(n - 1, z) - (static_cast(n) + 1.0) * sph_bessel_i(n, z) / z; } -template -T sph_bessel_k(long n, T z) { +template = 0> +T sph_bessel_k(N n, T z) { if (std::isnan(z)) { return z; } @@ -338,8 +354,8 @@ T sph_bessel_k(long n, T z) { return std::sqrt(M_PI_2 / z) * cyl_bessel_k(n + 1 / static_cast(2), z); } -template -std::complex sph_bessel_k(long n, std::complex z) { +template = 0> +std::complex sph_bessel_k(N n, std::complex z) { if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) { return z; } @@ -369,13 +385,18 @@ std::complex sph_bessel_k(long n, std::complex z) { return std::sqrt(static_cast(M_PI_2) / z) * cyl_bessel_k(n + 1 / static_cast(2), z); } -template -T sph_bessel_k_jac(long n, T x) { +template = 0> +T sph_bessel_k_jac(N n, T x) { + if (n < 0) { + set_error("spherical_k_jac", SF_ERROR_DOMAIN, nullptr); + return std::numeric_limits::quiet_NaN(); + } + if (n == 0) { return -sph_bessel_k(1, x); } - return -sph_bessel_k(n - 1, x) - static_cast(n + 1) * sph_bessel_k(n, x) / x; + return -sph_bessel_k(n - 1, x) - (static_cast(n) + 1.0) * sph_bessel_k(n, x) / x; } } // namespace xsf