|
1 | 1 | from typing import Literal |
2 | 2 |
|
| 3 | +import array_api_extra as xpx |
3 | 4 | import numpy as np |
4 | 5 | import pytest |
5 | | -from numpy.testing import assert_allclose |
6 | | -from scipy.special import jv, jvp, spherical_jn, spherical_yn, yv, yvp |
| 6 | +from array_api._2024_12 import ArrayNamespaceFull |
| 7 | +from scipy.special import hankel1, hankel2, jv, jvp, spherical_jn, spherical_yn, yv, yvp |
7 | 8 |
|
8 | | -from ultrasphere.special._bessel import sjv, syv |
| 9 | +from ultrasphere.special._bessel import shn1, shn2, sjv, syv |
9 | 10 |
|
10 | 11 |
|
11 | 12 | @pytest.mark.parametrize("derivative", [True, False]) |
12 | 13 | @pytest.mark.parametrize("d", [2, 3]) |
13 | | -@pytest.mark.parametrize("type", ["j", "y"]) |
14 | | -def test_sjyn(d: Literal[2, 3], type: Literal["j", "y"], derivative: bool) -> None: |
| 14 | +@pytest.mark.parametrize("type", ["j", "y", "h1", "h2"]) |
| 15 | +def test_sjyn( |
| 16 | + d: Literal[2, 3], |
| 17 | + type: Literal["j", "y", "h1", "h2"], |
| 18 | + derivative: bool, |
| 19 | + xp: ArrayNamespaceFull, |
| 20 | +) -> None: |
15 | 21 | n = np.random.randint(5, size=(10, 10)) |
16 | 22 | z = np.random.random((10, 10)) |
17 | 23 | if type == "j": |
18 | | - actual = sjv(n, np.array(d), z, derivative=derivative) |
19 | 24 | if d == 2: |
20 | 25 | expected = np.sqrt(np.pi / 2) * (jvp(n, z) if derivative else jv(n, z)) |
21 | 26 | elif d == 3: |
22 | 27 | expected = spherical_jn(n, z, derivative=derivative) |
23 | 28 | else: |
24 | 29 | raise ValueError("d must be 2 or 3") |
| 30 | + expected = xp.asarray(expected) |
| 31 | + actual = sjv(xp.asarray(n), xp.asarray(d), xp.asarray(z), derivative=derivative) |
25 | 32 | elif type == "y": |
26 | | - actual = syv(n, np.array(d), z, derivative=derivative) |
27 | 33 | if d == 2: |
28 | 34 | expected = np.sqrt(np.pi / 2) * (yvp(n, z) if derivative else yv(n, z)) |
29 | 35 | elif d == 3: |
30 | 36 | expected = spherical_yn(n, z, derivative=derivative) |
31 | 37 | else: |
32 | 38 | raise ValueError("d must be 2 or 3") |
| 39 | + expected = xp.asarray(expected) |
| 40 | + actual = syv(xp.asarray(n), xp.asarray(d), xp.asarray(z), derivative=derivative) |
| 41 | + elif type == "h1": |
| 42 | + if derivative: |
| 43 | + pytest.skip("derivative of hankel1 not implemented") |
| 44 | + if d == 2: |
| 45 | + expected = np.sqrt(np.pi / 2) * hankel1(n, z) |
| 46 | + elif d == 3: |
| 47 | + expected = spherical_jn(n, z) + 1j * spherical_yn(n, z) |
| 48 | + expected = xp.asarray(expected) |
| 49 | + actual = shn1(xp.asarray(n), xp.asarray(d), xp.asarray(z)) |
| 50 | + elif type == "h2": |
| 51 | + if derivative: |
| 52 | + pytest.skip("derivative of hankel2 not implemented") |
| 53 | + if d == 2: |
| 54 | + expected = np.sqrt(np.pi / 2) * hankel2(n, z) |
| 55 | + elif d == 3: |
| 56 | + expected = spherical_jn(n, z) - 1j * spherical_yn(n, z) |
| 57 | + expected = xp.asarray(expected) |
| 58 | + actual = shn2(xp.asarray(n), xp.asarray(d), xp.asarray(z)) |
33 | 59 | else: |
34 | 60 | raise ValueError("type must be 'j' or 'y'") |
35 | | - assert_allclose(actual, expected) |
| 61 | + assert xp.all(xpx.isclose(actual, expected, rtol=1e-6, atol=1e-6)) |
0 commit comments