Skip to content

Replace uses of long with more portable types #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/xsf/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@
#define M_SQRT1_2 0.707106781186547524401
#endif

namespace xsf {

template<typename T>
constexpr bool is_supported_int_v =
std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>;

template<typename T>
using enable_if_supported_int_t = std::enable_if_t<is_supported_int_v<T>, int>;

}

#ifdef __CUDACC__
#define XSF_HOST_DEVICE __host__ __device__

Expand Down
8 changes: 5 additions & 3 deletions include/xsf/lambertw.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace detail {
return z * cevalpoly(num, 2, z) / cevalpoly(denom, 2, z);
}

XSF_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, long k) {
XSF_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, double k) {
/* Compute the W function using the first two terms of the
* asymptotic series. See 4.20 in [1].
*/
Expand All @@ -64,7 +64,8 @@ namespace detail {

} // namespace detail

XSF_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, long k, double tol) {
template <typename K, enable_if_supported_int_t<K> = 0>
XSF_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, K k, double tol) {
double absz;
std::complex<double> w;
std::complex<double> ew, wew, wewz, wn;
Expand Down Expand Up @@ -142,7 +143,8 @@ XSF_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, lon
return {std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::quiet_NaN()};
}

XSF_HOST_DEVICE inline std::complex<float> lambertw(std::complex<float> z, long k, float tol) {
template <typename K, enable_if_supported_int_t<K> = 0>
XSF_HOST_DEVICE inline std::complex<float> lambertw(std::complex<float> z, K k, float tol) {
return static_cast<std::complex<float>>(lambertw(static_cast<std::complex<double>>(z), k, static_cast<double>(tol))
);
}
Expand Down
89 changes: 55 additions & 34 deletions include/xsf/sph_bessel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ Translated to C++ by SciPy developers in 2024.

namespace xsf {

template <typename T>
T sph_bessel_j(long n, T x) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_j(N n, T x) {
if (std::isnan(x)) {
return x;
}
Expand Down Expand Up @@ -71,7 +71,7 @@ T sph_bessel_j(long n, T x) {
}

T sn;
for (int i = 0; i < n - 1; ++i) {
for (N i = 0; i < n - 1; ++i) {
sn = (2 * i + 3) * s1 / x - s0;
s0 = s1;
s1 = sn;
Expand All @@ -84,8 +84,8 @@ T sph_bessel_j(long n, T x) {
return sn;
}

template <typename T>
std::complex<T> sph_bessel_j(long n, std::complex<T> z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
std::complex<T> sph_bessel_j(N n, std::complex<T> z) {
if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) {
return z;
}
Expand Down Expand Up @@ -120,8 +120,13 @@ std::complex<T> sph_bessel_j(long n, std::complex<T> z) {
return out;
}

template <typename T>
T sph_bessel_j_jac(long n, T z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_j_jac(N n, T z) {
if (n < 0) {
set_error("spherical_j_jac", SF_ERROR_DOMAIN, nullptr);
return std::numeric_limits<T>::quiet_NaN();
}

if (n == 0) {
return -sph_bessel_j(1, z);
}
Expand All @@ -136,13 +141,13 @@ T sph_bessel_j_jac(long n, T z) {
}

// DLMF 10.51.2
return sph_bessel_j(n - 1, z) - static_cast<T>(n + 1) * sph_bessel_j(n, z) / z;
return sph_bessel_j(n - 1, z) - (static_cast<T>(n) + 1.0) * sph_bessel_j(n, z) / z;
}

template <typename T>
T sph_bessel_y(long n, T x) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_y(N n, T x) {
T s0, s1, sn;
int idx;
N idx;

if (std::isnan(x)) {
return x;
Expand All @@ -154,7 +159,7 @@ T sph_bessel_y(long n, T x) {
}

if (x < 0) {
return std::pow(-1, n + 1) * sph_bessel_y(n, -x);
return n % 2 == 0 ? -sph_bessel_y(n, -x) : sph_bessel_y(n, -x);
}

if (x == std::numeric_limits<T>::infinity() || x == -std::numeric_limits<T>::infinity()) {
Expand Down Expand Up @@ -188,10 +193,11 @@ T sph_bessel_y(long n, T x) {
return sn;
}

inline float sph_bessel_y(long n, float x) { return sph_bessel_y(n, static_cast<double>(x)); }
template <typename N, enable_if_supported_int_t<N> = 0>
inline float sph_bessel_y(N n, float x) { return sph_bessel_y(n, static_cast<double>(x)); }

template <typename T>
std::complex<T> sph_bessel_y(long n, std::complex<T> z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
std::complex<T> sph_bessel_y(N n, std::complex<T> z) {
if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) {
return z;
}
Expand All @@ -218,17 +224,22 @@ std::complex<T> sph_bessel_y(long n, std::complex<T> z) {
return std::sqrt(static_cast<T>(M_PI_2) / z) * cyl_bessel_y(n + 1 / static_cast<T>(2), z);
}

template <typename T>
T sph_bessel_y_jac(long n, T x) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_y_jac(N n, T x) {
if (n < 0) {
set_error("spherical_y_jac", SF_ERROR_DOMAIN, nullptr);
return std::numeric_limits<T>::quiet_NaN();
}

if (n == 0) {
return -sph_bessel_y(1, x);
}

return sph_bessel_y(n - 1, x) - static_cast<T>(n + 1) * sph_bessel_y(n, x) / x;
return sph_bessel_y(n - 1, x) - (static_cast<T>(n) + 1.0) * sph_bessel_y(n, x) / x;
}

template <typename T>
T sph_bessel_i(long n, T x) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_i(N n, T x) {
if (std::isnan(x)) {
return x;
}
Expand All @@ -249,7 +260,7 @@ T sph_bessel_i(long n, T x) {
if (std::isinf(x)) {
// https://dlmf.nist.gov/10.49.E8
if (x == -std::numeric_limits<T>::infinity()) {
return std::pow(-1, n) * std::numeric_limits<T>::infinity();
return n % 2 == 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
}

return std::numeric_limits<T>::infinity();
Expand All @@ -258,8 +269,8 @@ T sph_bessel_i(long n, T x) {
return sqrt(static_cast<T>(M_PI_2) / x) * cyl_bessel_i(n + 1 / static_cast<T>(2), x);
}

template <typename T>
std::complex<T> sph_bessel_i(long n, std::complex<T> z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
std::complex<T> sph_bessel_i(N n, std::complex<T> z) {
if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) {
return z;
}
Expand All @@ -282,7 +293,7 @@ std::complex<T> sph_bessel_i(long n, std::complex<T> z) {
// https://dlmf.nist.gov/10.52.E5
if (std::imag(z) == 0) {
if (std::real(z) == -std::numeric_limits<T>::infinity()) {
return std::pow(-1, n) * std::numeric_limits<T>::infinity();
return n % 2 == 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
}

return std::numeric_limits<T>::infinity();
Expand All @@ -294,8 +305,13 @@ std::complex<T> sph_bessel_i(long n, std::complex<T> z) {
return std::sqrt(static_cast<T>(M_PI_2) / z) * cyl_bessel_i(n + 1 / static_cast<T>(2), z);
}

template <typename T>
T sph_bessel_i_jac(long n, T z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_i_jac(N n, T z) {
if (n < 0) {
set_error("spherical_i_jac", SF_ERROR_DOMAIN, nullptr);
return std::numeric_limits<T>::quiet_NaN();
}

if (n == 0) {
return sph_bessel_i(1, z);
}
Expand All @@ -308,11 +324,11 @@ T sph_bessel_i_jac(long n, T z) {
}
}

return sph_bessel_i(n - 1, z) - static_cast<T>(n + 1) * sph_bessel_i(n, z) / z;
return sph_bessel_i(n - 1, z) - (static_cast<T>(n) + 1.0) * sph_bessel_i(n, z) / z;
}

template <typename T>
T sph_bessel_k(long n, T z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_k(N n, T z) {
if (std::isnan(z)) {
return z;
}
Expand All @@ -338,8 +354,8 @@ T sph_bessel_k(long n, T z) {
return std::sqrt(M_PI_2 / z) * cyl_bessel_k(n + 1 / static_cast<T>(2), z);
}

template <typename T>
std::complex<T> sph_bessel_k(long n, std::complex<T> z) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
std::complex<T> sph_bessel_k(N n, std::complex<T> z) {
if (std::isnan(std::real(z)) || std::isnan(std::imag(z))) {
return z;
}
Expand Down Expand Up @@ -369,13 +385,18 @@ std::complex<T> sph_bessel_k(long n, std::complex<T> z) {
return std::sqrt(static_cast<T>(M_PI_2) / z) * cyl_bessel_k(n + 1 / static_cast<T>(2), z);
}

template <typename T>
T sph_bessel_k_jac(long n, T x) {
template <typename T, typename N, enable_if_supported_int_t<N> = 0>
T sph_bessel_k_jac(N n, T x) {
if (n < 0) {
set_error("spherical_k_jac", SF_ERROR_DOMAIN, nullptr);
return std::numeric_limits<T>::quiet_NaN();
}

if (n == 0) {
return -sph_bessel_k(1, x);
}

return -sph_bessel_k(n - 1, x) - static_cast<T>(n + 1) * sph_bessel_k(n, x) / x;
return -sph_bessel_k(n - 1, x) - (static_cast<T>(n) + 1.0) * sph_bessel_k(n, x) / x;
}

} // namespace xsf
Loading