Skip to content

Commit 69acde4

Browse files
authored
Merge pull request #5137 from chrishavlin/fix_integrate_kernel_typing
update testing.integrate_kernel type hints
1 parent cc76fcf commit 69acde4

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

yt/testing.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from functools import wraps
1111
from importlib.util import find_spec
1212
from shutil import which
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, TypeVar
1414
from unittest import SkipTest
1515

1616
import matplotlib
1717
import numpy as np
18+
import numpy.typing as npt
1819
from more_itertools import always_iterable
1920
from numpy.random import RandomState
2021
from unyt.exceptions import UnitOperationError
@@ -91,8 +92,8 @@ def assert_rel_equal(a1, a2, decimals, err_msg="", verbose=True):
9192

9293
# tested: volume integral is 1.
9394
def cubicspline_python(
94-
x: float | np.ndarray,
95-
) -> np.ndarray:
95+
x: float | npt.NDArray[np.floating],
96+
) -> npt.NDArray[np.floating]:
9697
"""
9798
cubic spline SPH kernel function for testing against more
9899
effiecient cython methods
@@ -118,8 +119,12 @@ def cubicspline_python(
118119

119120

120121
def integrate_kernel(
121-
kernelfunc: Callable[[float], float], b: float, hsml: float
122-
) -> float:
122+
kernelfunc: Callable[
123+
[float | npt.NDArray[np.floating]], float | npt.NDArray[np.floating]
124+
],
125+
b: float | npt.NDArray[np.floating],
126+
hsml: float | npt.NDArray[np.floating],
127+
) -> npt.NDArray[np.floating]:
123128
"""
124129
integrates a kernel function over a line passing entirely
125130
through it
@@ -147,18 +152,21 @@ def integrate_kernel(
147152
dx = np.diff(xe, axis=0)
148153
spv = kernelfunc(np.sqrt(xc**2 + x**2))
149154
integral = np.sum(spv * dx, axis=0)
150-
return pre * integral
155+
return np.atleast_1d(pre * integral)
151156

152157

153158
_zeroperiods = np.array([0.0, 0.0, 0.0])
154159

155160

161+
_FloatingT = TypeVar("_FloatingT", bound=np.floating)
162+
163+
156164
def distancematrix(
157-
pos3_i0: np.ndarray,
158-
pos3_i1: np.ndarray,
165+
pos3_i0: npt.NDArray[_FloatingT],
166+
pos3_i1: npt.NDArray[_FloatingT],
159167
periodic: tuple[bool, bool, bool] = (True,) * 3,
160-
periods: np.ndarray = _zeroperiods,
161-
) -> np.ndarray:
168+
periods: npt.NDArray[_FloatingT] = _zeroperiods,
169+
) -> npt.NDArray[_FloatingT]:
162170
"""
163171
Calculates the distances between two arrays of points.
164172

0 commit comments

Comments
 (0)