Skip to content

Commit add1b63

Browse files
authored
fix: fix special._bessel not working (#60)
1 parent 25cdb8a commit add1b63

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

src/ultrasphere/special/_bessel.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,21 @@ def szv(
6161
zv = yv
6262
elif type == "h1":
6363
if derivative:
64-
zv = hankel1
65-
else:
6664
raise AssertionError()
65+
else:
66+
zv = hankel1
6767
elif type == "h2":
6868
if derivative:
69-
zv = hankel2
70-
else:
7169
raise AssertionError()
70+
else:
71+
zv = hankel2
72+
73+
dtype = xp.result_type(v, d, z)
74+
if type in ("h1", "h2"):
75+
dtype = xp.result_type(dtype, xp.complex64)
7276
return (
73-
xp.sqrt(xp.pi / 2)
74-
* xp.asarray(zv(v + d_half_minus_1, z), device=z.device, dtype=z.dtype)
77+
xp.sqrt(xp.asarray(xp.pi / 2, device=z.device, dtype=dtype))
78+
* xp.asarray(zv(v + d_half_minus_1, z), device=z.device, dtype=dtype)
7579
/ (z**d_half_minus_1)
7680
)
7781

tests/test_special.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,61 @@
11
from typing import Literal
22

3+
import array_api_extra as xpx
34
import numpy as np
45
import pytest
5-
from numpy.testing import assert_allclose
6-
from scipy.special import jv, jvp, spherical_jn, spherical_yn, yv, yvp
6+
from array_api._2024_12 import ArrayNamespaceFull
7+
from scipy.special import hankel1, hankel2, jv, jvp, spherical_jn, spherical_yn, yv, yvp
78

8-
from ultrasphere.special._bessel import sjv, syv
9+
from ultrasphere.special._bessel import shn1, shn2, sjv, syv
910

1011

1112
@pytest.mark.parametrize("derivative", [True, False])
1213
@pytest.mark.parametrize("d", [2, 3])
13-
@pytest.mark.parametrize("type", ["j", "y"])
14-
def test_sjyn(d: Literal[2, 3], type: Literal["j", "y"], derivative: bool) -> None:
14+
@pytest.mark.parametrize("type", ["j", "y", "h1", "h2"])
15+
def test_sjyn(
16+
d: Literal[2, 3],
17+
type: Literal["j", "y", "h1", "h2"],
18+
derivative: bool,
19+
xp: ArrayNamespaceFull,
20+
) -> None:
1521
n = np.random.randint(5, size=(10, 10))
1622
z = np.random.random((10, 10))
1723
if type == "j":
18-
actual = sjv(n, np.array(d), z, derivative=derivative)
1924
if d == 2:
2025
expected = np.sqrt(np.pi / 2) * (jvp(n, z) if derivative else jv(n, z))
2126
elif d == 3:
2227
expected = spherical_jn(n, z, derivative=derivative)
2328
else:
2429
raise ValueError("d must be 2 or 3")
30+
expected = xp.asarray(expected)
31+
actual = sjv(xp.asarray(n), xp.asarray(d), xp.asarray(z), derivative=derivative)
2532
elif type == "y":
26-
actual = syv(n, np.array(d), z, derivative=derivative)
2733
if d == 2:
2834
expected = np.sqrt(np.pi / 2) * (yvp(n, z) if derivative else yv(n, z))
2935
elif d == 3:
3036
expected = spherical_yn(n, z, derivative=derivative)
3137
else:
3238
raise ValueError("d must be 2 or 3")
39+
expected = xp.asarray(expected)
40+
actual = syv(xp.asarray(n), xp.asarray(d), xp.asarray(z), derivative=derivative)
41+
elif type == "h1":
42+
if derivative:
43+
pytest.skip("derivative of hankel1 not implemented")
44+
if d == 2:
45+
expected = np.sqrt(np.pi / 2) * hankel1(n, z)
46+
elif d == 3:
47+
expected = spherical_jn(n, z) + 1j * spherical_yn(n, z)
48+
expected = xp.asarray(expected)
49+
actual = shn1(xp.asarray(n), xp.asarray(d), xp.asarray(z))
50+
elif type == "h2":
51+
if derivative:
52+
pytest.skip("derivative of hankel2 not implemented")
53+
if d == 2:
54+
expected = np.sqrt(np.pi / 2) * hankel2(n, z)
55+
elif d == 3:
56+
expected = spherical_jn(n, z) - 1j * spherical_yn(n, z)
57+
expected = xp.asarray(expected)
58+
actual = shn2(xp.asarray(n), xp.asarray(d), xp.asarray(z))
3359
else:
3460
raise ValueError("type must be 'j' or 'y'")
35-
assert_allclose(actual, expected)
61+
assert xp.all(xpx.isclose(actual, expected, rtol=1e-6, atol=1e-6))

0 commit comments

Comments
 (0)